lichangye.lcy 2025-07-25 13:42:42 +08:00
parent b8549ac48a
commit 4198cd695c
7 changed files with 26 additions and 30 deletions

View File

@ -156,7 +156,7 @@ class FSDPEngine(BaseHFEngine):
update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
self.model_version,
self.get_version(),
)
name_resolve.add(
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
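Reading the version through `self.get_version()` instead of the raw `self.model_version` attribute keeps the announced key consistent with the engine's accessor. A minimal sketch of the announce pattern shown above, pulled out of the engine class (the argument values are placeholders):

from datetime import datetime

from realhf.base import name_resolve, names

def announce_weights_on_disk(experiment_name: str, trial_name: str, version: int) -> None:
    # Publish a fresh timestamp under the per-version key so that consumers
    # polling this name learn a new checkpoint has been written to disk.
    update_name = names.update_weights_from_disk(experiment_name, trial_name, version)
    name_resolve.add(update_name, str(datetime.now().timestamp()), keepalive_ttl=120)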

View File

@ -582,6 +582,8 @@ def update_weights_from_distributed(
for addr in addresses
]
)
logger.info(f"Distributed update weights done in {time.perf_counter() - tik}s")
return uvloop.run(_fn())
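The added log line times the whole fan-out with `time.perf_counter()`, with `tik` taken before the requests are issued. A standalone sketch of the same timing pattern around an async gather (the per-address coroutine is a placeholder; the real code drives `_fn` through `uvloop.run`):

import asyncio
import time

async def _update_one(addr: str) -> None:
    # Placeholder for the per-address weight-update request.
    await asyncio.sleep(0)

async def _fn(addresses):
    await asyncio.gather(*[_update_one(addr) for addr in addresses])

tik = time.perf_counter()
asyncio.run(_fn(["10.0.0.1:30000", "10.0.0.2:30000"]))
print(f"Distributed update weights done in {time.perf_counter() - tik}s")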

View File

@ -17,7 +17,6 @@ from realhf.base import gpu_utils, logging, name_resolve, names
from realhf.scheduler.client import JobException, JobInfo, JobState
logger = logging.getLogger("Local Scheduler")
JOB_STATE_TO_PROCESS_STATUS = {
JobState.NOT_FOUND: [],
JobState.PENDING: [psutil.STATUS_PARKED],
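For context, this table maps a scheduler-level `JobState` onto the psutil process statuses that count as that state. A minimal sketch, not from this repo, of how such a mapping can be inverted to classify a local process (assumes it lives in the same module as the table and the `JobState` import above; the helper name is illustrative):

import psutil

def classify_job_state(pid: int):
    # Look up the process's current psutil status in the table above;
    # a vanished process maps to NOT_FOUND.
    try:
        status = psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        return JobState.NOT_FOUND
    for job_state, allowed_statuses in JOB_STATE_TO_PROCESS_STATUS.items():
        if status in allowed_statuses:
            return job_state
    return JobState.NOT_FOUND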

View File

@ -1,3 +1,4 @@
import os
import asyncio
import uuid
@ -21,8 +22,9 @@ class VisionRLVRWorkflow(RLVRWorkflow):
tokenizer: PreTrainedTokenizerFast,
processor: AutoProcessor,
enable_thinking: bool,
dump_dir: str | None = None,
):
super().__init__(reward_fn, gconfig, tokenizer, enable_thinking)
super().__init__(reward_fn, gconfig, tokenizer, enable_thinking, dump_dir)
self.processor = processor
async def arun_episode(self, engine, data):
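The only change here is that the optional `dump_dir` is accepted and forwarded to the `RLVRWorkflow` base class; the `None` default keeps existing call sites working. A usage sketch, with the keyword names for `reward_fn` and `gconfig` inferred from the surrounding code and a placeholder dump path:

import os

workflow = VisionRLVRWorkflow(
    reward_fn=clevr_count_70k_reward_fn,  # reward function from the training script
    gconfig=config.gconfig,               # generation config (keyword name assumed)
    tokenizer=tokenizer,
    processor=processor,
    enable_thinking=False,
    dump_dir=os.path.join("/tmp/arealite/experiments", "generated"),  # placeholder path
)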

View File

@ -6,6 +6,7 @@ import torch
import torch.distributed as dist
import wandb
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import Subset
from arealite.workflow.vision_rlvr import VisionRLVRWorkflow
from arealite.api.cli_args import GRPOConfig, load_expr_config
@ -32,8 +33,6 @@ def extract_answer(pred_str, data_name, use_last_number=True):
def clevr_count_70k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
is_thinking = "thinking" in completions.lower()
sol = extract_answer(completions, data_name="") # str number
ans = answer
@ -46,14 +45,11 @@ def clevr_count_70k_reward_fn(
print(f"completions: {completions}, answer: {answer}")
return 1
if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
return 0.05
return 0
def main(args):
wandb.init(project="clevr_70k")
config, _ = load_expr_config(args, GRPOConfig)
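For reference, the reward scheme shown above is: an exact count match earns 1, a well-formed bracketed number that does not match earns 0.05, everything else 0. A simplified, self-contained sketch of that scheme (the real `extract_answer` is more elaborate; this regex-only extraction is an assumption):

import re

def count_reward(completion: str, answer: str) -> float:
    # Take the last number in the completion as the predicted count.
    numbers = re.findall(r"\d+(?:\.\d+)?", completion)
    pred = numbers[-1] if numbers else ""
    if pred and pred == str(answer).strip():
        return 1.0  # correct count
    if re.match(r"^\[\d+(\.\d+)?\]$", completion.strip()):
        return 0.05  # bracketed number in the expected format, wrong value
    return 0.0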
@ -70,6 +66,16 @@ def main(args):
type=config.train_dataset.type,
processor=processor,
)
train_size = len(train_dataset)
subset_size = int(0.3 * train_size)
# Randomly select indices for 30% of the data
random_indices = torch.randperm(train_size).tolist()[:subset_size]
# Build a new subset dataset from those indices
subset_train_dataset = Subset(train_dataset, random_indices)
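`torch.randperm` is not seeded here, so each run (and, in a multi-process launch, each rank) can draw a different 30% subset. A minimal sketch of a deterministic variant, assuming you want every data-parallel rank to select the same indices (the seed value is arbitrary):

import torch
from torch.utils.data import Subset

generator = torch.Generator().manual_seed(42)  # fixed seed: same subset on every rank and run
subset_size = int(0.3 * len(train_dataset))
random_indices = torch.randperm(len(train_dataset), generator=generator).tolist()[:subset_size]
subset_train_dataset = Subset(train_dataset, random_indices)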
valid_dataset = get_custom_dataset(
path=config.valid_dataset.path,
rank=rank,
@ -80,7 +86,7 @@ def main(args):
)
# Create dataset and dataloaders
train_dataloader = StatefulDataLoader(
train_dataset,
subset_train_dataset,
batch_size=config.train_dataset.batch_size // world_size,
shuffle=config.train_dataset.shuffle,
num_workers=config.train_dataset.num_workers,
@ -122,9 +128,8 @@ def main(args):
# due to `engine.get_param_specs()`.
# Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
weight_update_meta = [
WeightUpdateMeta.from_fsdp_nccl(
AllocationMode.from_str(config.allocation_mode), actor
)
WeightUpdateMeta.from_disk(config.saver)
]
dist.broadcast_object_list(weight_update_meta, src=0)
weight_update_meta = weight_update_meta[0]
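As the comment notes, every rank builds its own `WeightUpdateMeta` and then rank 0's copy is broadcast so all ranks end up with an identical object. A standalone sketch of that idiom with a generic picklable payload (assumes `torch.distributed` is already initialized):

import torch.distributed as dist

# Each rank constructs a candidate object; rank 0's version then overwrites
# the local one everywhere, so the list holds identical objects afterwards.
payload = [{"rank": dist.get_rank(), "version": 0}]
dist.broadcast_object_list(payload, src=0)
payload = payload[0]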
@ -141,11 +146,8 @@ def main(args):
tokenizer=tokenizer,
processor=processor,
enable_thinking=False,
dump_dir=os.path.join(
StatsLogger.get_log_path(config.stats_logger), "generated"
),
)
# Run training.
saver = Saver(config.saver, ft_spec, for_recover=False)
logger = StatsLogger(config.stats_logger, ft_spec)
@ -198,7 +200,7 @@ def main(args):
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
wandb.log({"final_reward": stats[0]["grpo_actor/final_reward/avg"]})
wandb.log({"task_reward": stats[0]["grpo_actor/task_reward/avg"]})
actor.step_lr_scheduler()
log_gpu_stats("ppo update")
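Each `wandb.log` call above advances wandb's internal step counter on its own. A small sketch that logs both scalars in a single call and pins them to the trainer's `global_step` counter (the variable appears later in this script; passing `step=` is optional):

wandb.log(
    {
        "final_reward": stats[0]["grpo_actor/final_reward/avg"],
        "task_reward": stats[0]["grpo_actor/task_reward/avg"],
    },
    step=global_step,
)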
@ -230,14 +232,7 @@ def main(args):
cnt += 1
batch = eval_rollout.wait(cnt, timeout=None)
rewards = batch["rewards"].float().to(actor.device)
wandb.log(
{
"eval_reward": rewards.mean().item(),
"epoch": epoch,
"step": step,
"global_step": global_step,
}
)
wandb.log({"eval_reward": rewards.mean().item()})
with stats_tracker.scope("grpo-eval"):
stats_tracker.denominator(
n_seqs=torch.ones(
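The simplified eval call drops the `epoch`/`step`/`global_step` values that were previously logged alongside the reward. If that context is still wanted, it can be re-attached without extra metric keys; a one-line sketch, reusing the `global_step` counter from elsewhere in this script:

wandb.log({"eval_reward": rewards.mean().item()}, step=global_step)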

View File

@ -38,7 +38,7 @@ gconfig:
actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: /storage/openpsi/models/Qwen2.5-VL-7B-Instruct
path: /storage/openpsi/models/Qwen2.5-VL-3B-Instruct
init_from_scratch: false
disable_dropout: true
gradient_checkpointing: false
@ -89,7 +89,7 @@ sglang:
dtype: ${actor.dtype}
max_running_requests: null
context_length: 32768
mem_fraction_static: 0.9
mem_fraction_static: 0.8
# datasets
train_dataset:

View File

@ -2,8 +2,6 @@ experiment_name: gsm8k-grpo
trial_name: trial0
allocation_mode: sglang.d4p1t1+d4p1t1
cluster:
n_nodes: 1
n_gpus_per_node: 8
fileroot: /tmp/arealite/experiments
n_nodes: 1
n_gpus_per_node: 8