mirror of https://github.com/inclusionAI/AReaL
0724_merge5
This commit is contained in:
parent
84e2d75a0d
commit
1bc9310252
|
@ -185,8 +185,6 @@ def main(args):
|
|||
stats_tracker.scope("grpo_actor"),
|
||||
):
|
||||
stats = actor.ppo_update(batch)
|
||||
wandb.log({"actor_reward": stats[0]["grpo_actor/final_reward/avg"]})
|
||||
wandb.log({"task_reward": stats[0]["grpo_actor/task_reward/avg"]})
|
||||
actor.step_lr_scheduler()
|
||||
log_gpu_stats("ppo update")
|
||||
|
||||
|
|
|
@ -24,8 +24,6 @@ from realhf.base import logging, seeding, stats_tracker
|
|||
logger = logging.getLogger("GSM8K grpo")
|
||||
|
||||
|
||||
|
||||
|
||||
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
|
||||
from realhf.impl.dataset.math_parser import process_results
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ def main(args):
|
|||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
tokenizer = load_hf_tokenizer(config.tokenizer_path)
|
||||
|
||||
|
||||
train_dataset = get_custom_dataset(
|
||||
path=config.train_dataset.path,
|
||||
rank=rank,
|
||||
|
@ -52,7 +52,7 @@ def main(args):
|
|||
drop_last=config.train_dataset.drop_last,
|
||||
)
|
||||
valid_dataloader = StatefulDataLoader(
|
||||
valid_dataset ,
|
||||
valid_dataset,
|
||||
batch_size=config.valid_dataset.batch_size // world_size,
|
||||
shuffle=config.valid_dataset.shuffle,
|
||||
num_workers=config.valid_dataset.num_workers,
|
||||
|
|
Loading…
Reference in New Issue