mirror of https://github.com/inclusionAI/AReaL
0724_4
This commit is contained in:
parent
e705db12f4
commit
6aeeabf7b9
|
@ -44,10 +44,8 @@ def clevr_count_70k_reward_fn(
|
|||
|
||||
if sol.strip() == ans.strip():
|
||||
print(f"completions: {completions}, answer: {answer}")
|
||||
if is_thinking:
|
||||
return 1
|
||||
else:
|
||||
return 0.8
|
||||
return 1
|
||||
|
||||
|
||||
if re.match(r"^\[\d+(\.\d+)?\]$", sol.strip()):
|
||||
return 0.05
|
||||
|
@ -185,6 +183,8 @@ def main(args):
|
|||
stats_tracker.scope("grpo_actor"),
|
||||
):
|
||||
stats = actor.ppo_update(batch)
|
||||
wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
|
||||
wandb.log({"task_reward": stats[0]["grpo_actor/task_reward/avg"]})
|
||||
actor.step_lr_scheduler()
|
||||
log_gpu_stats("ppo update")
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ dependencies = [
|
|||
"huggingface_hub",
|
||||
"datasets",
|
||||
"accelerate",
|
||||
"transformers>=4.53.1",
|
||||
"transformers==4.53.1",
|
||||
|
||||
# Scientific computing
|
||||
"numpy<2.0.0",
|
||||
|
|
Loading…
Reference in New Issue