mirror of https://github.com/inclusionAI/AReaL
0724_merge6
This commit is contained in:
parent
1bc9310252
commit
13fc236c99
|
@ -163,7 +163,7 @@ def main(args):
|
|||
# Create barrier to synchronize all rollout processes.
|
||||
dist.barrier()
|
||||
torch.cuda.synchronize()
|
||||
# breakpoint()
|
||||
|
||||
if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
|
||||
with stats_tracker.record_timing("recompute_logp"):
|
||||
logp = actor.compute_logp(batch)
|
||||
|
@ -185,6 +185,8 @@ def main(args):
|
|||
stats_tracker.scope("grpo_actor"),
|
||||
):
|
||||
stats = actor.ppo_update(batch)
|
||||
wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
|
||||
wandb.log({"task_reward": stats[0]["grpo_actor/task_reward/avg"]})
|
||||
actor.step_lr_scheduler()
|
||||
log_gpu_stats("ppo update")
|
||||
|
||||
|
|
Loading…
Reference in New Issue