mirror of https://github.com/inclusionAI/AReaL
0725_1
commit 4198cd695c (parent b8549ac48a)
@@ -156,7 +156,7 @@ class FSDPEngine(BaseHFEngine):
         update_name = names.update_weights_from_disk(
             self.config.experiment_name,
             self.config.trial_name,
-            self.model_version,
+            self.get_version(),
         )
         name_resolve.add(
             update_name, str(datetime.now().timestamp()), keepalive_ttl=120

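The swap from self.model_version to self.get_version() reads the version through the engine's accessor rather than a raw attribute. The surrounding code publishes that version under a discovery key with a 120-second keepalive. Below is a minimal sketch of the publish-with-TTL pattern; TTLStore and the key string are illustrative stand-ins, not AReaL's actual name_resolve backend.

import time

class TTLStore:
    """In-memory stand-in for a name-resolve store with keepalive TTLs."""

    def __init__(self):
        self._data = {}  # key -> (value, expiry timestamp)

    def add(self, key, value, keepalive_ttl=120):
        self._data[key] = (value, time.monotonic() + keepalive_ttl)

    def get(self, key):
        value, expiry = self._data.get(key, (None, 0.0))
        return value if time.monotonic() < expiry else None

store = TTLStore()
update_name = "update_weights_from_disk/exp/trial/3"  # hypothetical key layout
store.add(update_name, str(time.time()), keepalive_ttl=120)
assert store.get(update_name) is not None  # visible to consumers until the TTL lapses
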
@@ -582,6 +582,8 @@ def update_weights_from_distributed(
                 for addr in addresses
             ]
         )

+        logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
+
     return uvloop.run(_fn())

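The added log line reports wall-clock time for the fan-out of weight-update requests. A self-contained sketch of the same timing pattern, with asyncio.run standing in for uvloop.run and hypothetical addresses:

import asyncio
import time

async def _push_to(addr: str) -> None:
    await asyncio.sleep(0.01)  # stands in for one HTTP weight-update request

async def _fn() -> None:
    addresses = ["10.0.0.1:30000", "10.0.0.2:30000"]  # hypothetical endpoints
    # Issue all requests concurrently and wait for the slowest one.
    await asyncio.gather(*[_push_to(addr) for addr in addresses])

tik = time.perf_counter()
asyncio.run(_fn())
print(f"Distributed update weights done in {time.perf_counter() - tik}s")
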
@@ -17,7 +17,6 @@ from realhf.base import gpu_utils, logging, name_resolve, names
 from realhf.scheduler.client import JobException, JobInfo, JobState

 logger = logging.getLogger("Local Scheduler")

 JOB_STATE_TO_PROCESS_STATUS = {
     JobState.NOT_FOUND: [],
     JobState.PENDING: [psutil.STATUS_PARKED],

@@ -1,3 +1,4 @@
+import os
 import asyncio
 import uuid

@@ -21,8 +22,9 @@ class VisionRLVRWorkflow(RLVRWorkflow):
         tokenizer: PreTrainedTokenizerFast,
         processor: AutoProcessor,
         enable_thinking: bool,
+        dump_dir: str | None = None,
     ):
-        super().__init__(reward_fn, gconfig, tokenizer, enable_thinking)
+        super().__init__(reward_fn, gconfig, tokenizer, enable_thinking, dump_dir)
         self.processor = processor

     async def arun_episode(self, engine, data):

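The constructor change threads the new optional dump_dir through to the parent class instead of handling it locally, so dump handling lives in one place. A condensed sketch of the pattern (class names shortened and most arguments elided relative to the diff above):

class RLVRWorkflowBase:
    def __init__(self, reward_fn, enable_thinking, dump_dir: str | None = None):
        self.dump_dir = dump_dir  # the base class owns dump handling

class VisionWorkflow(RLVRWorkflowBase):
    def __init__(self, reward_fn, processor, enable_thinking: bool,
                 dump_dir: str | None = None):
        # Forward dump_dir instead of duplicating the attribute here.
        super().__init__(reward_fn, enable_thinking, dump_dir)
        self.processor = processor

wf = VisionWorkflow(reward_fn=None, processor=object(), enable_thinking=False)
assert wf.dump_dir is None  # the None default keeps existing call sites working
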
@@ -6,6 +6,7 @@ import torch
 import torch.distributed as dist
 import wandb
 from torchdata.stateful_dataloader import StatefulDataLoader
+from torch.utils.data import Subset

 from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
 from arealite.api.cli_args import GRPOConfig, load_expr_config

@@ -32,8 +33,6 @@ def extract_answer(pred_str, data_name, use_last_number=True):
 def clevr_count_70k_reward_fn(
     prompt, completions, prompt_ids, completion_ids, answer, **kwargs
 ):
-    is_thinking = "thinking" in completions.lower()
-
     sol = extract_answer(completions, data_name="")  # str number
     ans = answer

@@ -46,14 +45,11 @@ def clevr_count_70k_reward_fn(
         print(f"completions: {completions}, answer: {answer}")
         return 1

     if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
         return 0.05

     return 0


 def main(args):
     wandb.init(project="clevr_70k")

     config, _ = load_expr_config(args, GRPOConfig)

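For reference, the scoring branches that survive this edit: an exact match earns 1, a well-formed bracketed number earns partial credit of 0.05, and everything else earns 0. A standalone check (the score helper is illustrative, not the repo's function):

import re

def score(sol: str, ans: str) -> float:
    if sol == ans:
        return 1.0
    # Partial credit for answers that are at least a bracketed number.
    if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
        return 0.05
    return 0.0

print(score("3", "3"))      # 1.0
print(score("[3]", "4"))    # 0.05 (well-formed but not an exact match)
print(score("three", "3"))  # 0.0
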
@@ -70,6 +66,16 @@ def main(args):
         type=config.train_dataset.type,
         processor=processor,
     )
+
+    train_size = len(train_dataset)
+    subset_size = int(0.3 * train_size)
+
+    # Randomly select indices for 30% of the data
+    random_indices = torch.randperm(train_size).tolist()[:subset_size]
+
+    # Create a new subset dataset
+    subset_train_dataset = Subset(train_dataset, random_indices)
+
     valid_dataset = get_custom_dataset(
         path=config.valid_dataset.path,
         rank=rank,

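One caveat worth noting: torch.randperm here is unseeded, so in a multi-process run each rank would draw its own permutation and therefore its own 30% subset. Seeding a generator identically on every rank (or broadcasting the indices from rank 0) keeps the subsets aligned. A minimal sketch with a stand-in dataset:

import torch
from torch.utils.data import Subset, TensorDataset

train_dataset = TensorDataset(torch.arange(100))  # stand-in dataset
train_size = len(train_dataset)
subset_size = int(0.3 * train_size)

# Same seed on every rank, so all ranks agree on the selected subset.
g = torch.Generator().manual_seed(0)
random_indices = torch.randperm(train_size, generator=g).tolist()[:subset_size]
subset_train_dataset = Subset(train_dataset, random_indices)
print(len(subset_train_dataset))  # 30
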
@@ -80,7 +86,7 @@
     )
     # Create dataset and dataloaders
     train_dataloader = StatefulDataLoader(
-        train_dataset,
+        subset_train_dataset,
         batch_size=config.train_dataset.batch_size // world_size,
         shuffle=config.train_dataset.shuffle,
         num_workers=config.train_dataset.num_workers,

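StatefulDataLoader from torchdata is used here because, unlike the stock DataLoader, it can snapshot and restore its iteration position for mid-epoch recovery. A small self-contained demonstration:

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = TensorDataset(torch.arange(8))
loader = StatefulDataLoader(dataset, batch_size=2)

it = iter(loader)
next(it)                      # consume one batch
state = loader.state_dict()   # capture the in-flight position

resumed = StatefulDataLoader(dataset, batch_size=2)
resumed.load_state_dict(state)
print(next(iter(resumed)))    # continues from the second batch
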
@@ -122,9 +128,8 @@
     # due to `engine.get_param_specs()`.
     # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
     weight_update_meta = [
-        WeightUpdateMeta.from_fsdp_nccl(
-            AllocationMode.from_str(config.allocation_mode), actor
-        )
+        WeightUpdateMeta.from_disk(
+            config.saver)
     ]
     dist.broadcast_object_list(weight_update_meta, src=0)
     weight_update_meta = weight_update_meta[0]

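The broadcast follows the standard broadcast_object_list idiom: every rank allocates a same-length list, and rank 0's contents overwrite the rest. A single-process sketch with the gloo backend (the payload dict is hypothetical, not the real WeightUpdateMeta):

import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

meta = [{"type": "disk", "path": "/tmp/ckpt"}]  # hypothetical payload
dist.broadcast_object_list(meta, src=0)  # no-op with one rank; fans out otherwise
weight_update_meta = meta[0]
print(weight_update_meta)

dist.destroy_process_group()
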
@@ -141,11 +146,8 @@
         tokenizer=tokenizer,
         processor=processor,
         enable_thinking=False,
-        dump_dir=os.path.join(
-            StatsLogger.get_log_path(config.stats_logger), "generated"
-        ),
     )

     # Run training.
     saver = Saver(config.saver, ft_spec, for_recover=False)
     logger = StatsLogger(config.stats_logger, ft_spec)

@@ -198,7 +200,7 @@
                 stats_tracker.scope("grpo_actor"),
             ):
                 stats = actor.ppo_update(batch)
-                wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
+                wandb.log({"final_reward": stats[0]["grpo_actor/final_reward/avg"]})
+                wandb.log({"task_reward": stats[0]["grpo_actor/task_reward/avg"]})
                 actor.step_lr_scheduler()
                 log_gpu_stats("ppo update")

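A side effect to be aware of: each wandb.log call without an explicit step advances wandb's internal step counter, so the two back-to-back calls above land on different steps. Merging them into one call keeps both scalars aligned; a sketch assuming an active run:

import wandb

def log_actor_stats(stats: dict) -> None:
    # One wandb.log call -> both scalars recorded at the same step.
    wandb.log({
        "final_reward": stats["grpo_actor/final_reward/avg"],
        "task_reward": stats["grpo_actor/task_reward/avg"],
    })
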
@@ -230,14 +232,7 @@
                 cnt += 1
             batch = eval_rollout.wait(cnt, timeout=None)
             rewards = batch["rewards"].float().to(actor.device)
-            wandb.log(
-                {
-                    "eval_reward": rewards.mean().item(),
-                    "epoch": epoch,
-                    "step": step,
-                    "global_step": global_step,
-                }
-            )
+            wandb.log({"eval_reward": rewards.mean().item()})
             with stats_tracker.scope("grpo-eval"):
                 stats_tracker.denominator(
                     n_seqs=torch.ones(

@@ -38,7 +38,7 @@ gconfig:
 actor:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
-  path: /storage/openpsi/models/Qwen2.5-VL-7B-Instruct
+  path: /storage/openpsi/models/Qwen2.5-VL-3B-Instruct
   init_from_scratch: false
   disable_dropout: true
   gradient_checkpointing: false

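The ${experiment_name} and ${trial_name} values are OmegaConf-style interpolations that expand when the config is loaded. A minimal sketch of the resolution (the top-level values here are hypothetical):

from omegaconf import OmegaConf

cfg = OmegaConf.create("""
experiment_name: clevr-grpo
trial_name: trial0
actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/Qwen2.5-VL-3B-Instruct
""")
print(cfg.actor.experiment_name)  # clevr-grpo, resolved from the top level
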
@@ -89,7 +89,7 @@ sglang:
   dtype: ${actor.dtype}
   max_running_requests: null
   context_length: 32768
-  mem_fraction_static: 0.9
+  mem_fraction_static: 0.8

# datasets
train_dataset:

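Lowering mem_fraction_static from 0.9 to 0.8 tells SGLang to reserve a smaller fraction of GPU memory for its static pool (model weights plus KV cache), leaving more headroom for other processes sharing the GPU. The same knob exists on the standalone server launcher; a sketch reusing the model path from this config:

import subprocess

# --mem-fraction-static is the sglang.launch_server CLI equivalent of the
# YAML field above; adjust host/port flags as needed for your deployment.
subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "/storage/openpsi/models/Qwen2.5-VL-3B-Instruct",
    "--mem-fraction-static", "0.8",
])
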
@@ -2,8 +2,6 @@ experiment_name: gsm8k-grpo
 trial_name: trial0
 allocation_mode: sglang.d4p1t1+d4p1t1
 cluster:
   n_nodes: 1
   n_gpus_per_node: 8
   fileroot: /tmp/arealite/experiments
-n_nodes: 1
-n_gpus_per_node: 8