0724_merge5

This commit is contained in:
朱晗 2025-07-24 15:38:24 +08:00
parent 84e2d75a0d
commit 1bc9310252
3 changed files with 2 additions and 6 deletions

View File

@ -185,8 +185,6 @@ def main(args):
stats_tracker.scope("grpo_actor"),
):
stats = actor.ppo_update(batch)
wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
wandb.log({"task_reward": stats[0]["grpo_actor/task_reward/avg"]})
actor.step_lr_scheduler()
log_gpu_stats("ppo update")

View File

@ -24,8 +24,6 @@ from realhf.base import logging, seeding, stats_tracker
logger = logging.getLogger("GSM8K grpo")
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
from realhf.impl.dataset.math_parser import process_results

View File

@ -24,7 +24,7 @@ def main(args):
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
tokenizer = load_hf_tokenizer(config.tokenizer_path)
train_dataset = get_custom_dataset(
path=config.train_dataset.path,
rank=rank,
@ -52,7 +52,7 @@ def main(args):
drop_last=config.train_dataset.drop_last,
)
valid_dataloader = StatefulDataLoader(
valid_dataset ,
valid_dataset,
batch_size=config.valid_dataset.batch_size // world_size,
shuffle=config.valid_dataset.shuffle,
num_workers=config.valid_dataset.num_workers,