diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py index a25c841..5da0087 100644 --- a/arealite/api/cli_args.py +++ b/arealite/api/cli_args.py @@ -637,6 +637,9 @@ class DatasetConfig: default=0, metadata={"help": "Number of worker processes for data loading"} ) drop_last: bool = field(default=True) + reward_fn: Optional[str] = field( + default=None, + ) @dataclass diff --git a/examples/arealite/configs/geometry3k_grpo.yaml b/examples/arealite/configs/geometry3k_grpo.yaml index 59e2bb3..21b9509 100644 --- a/examples/arealite/configs/geometry3k_grpo.yaml +++ b/examples/arealite/configs/geometry3k_grpo.yaml @@ -99,6 +99,7 @@ train_dataset: num_workers: 4 path: hiyouga/geometry3k type: rl + reward_fn: ${train_dataset.path} valid_dataset: batch_size: 32 @@ -108,8 +109,8 @@ valid_dataset: path: ${train_dataset.path} type: rl -reward_fn: - path: ${train_dataset.path} + + # Utilities saver: @@ -140,6 +141,5 @@ stats_logger: experiment_name: ${experiment_name} trial_name: ${trial_name} fileroot: ${cluster.fileroot} - -wandb: - project: geometry3k-grpo \ No newline at end of file + wandb: + project: geometry3k-grpo \ No newline at end of file diff --git a/examples/arealite/reward/geometry3k.py b/examples/arealite/reward/geometry3k.py index 0978e23..4862a3f 100644 --- a/examples/arealite/reward/geometry3k.py +++ b/examples/arealite/reward/geometry3k.py @@ -13,19 +13,20 @@ def geometry3k_reward_fn( ): sol = extract_answer(completions, data_name="") # str number ans = answer - + sol = sol.replace(" ", "") + ans= ans.replace(" ", "") if sol is None: return 0 if ans is None: return 0 - is_numeric = sol.replace('.', '', 1).isdigit() # Allows for decimal check - is_latex = sol.startswith("\\frac") or '\\sqrt' in sol - print(f"completions: {completions}, answer: {answer}") + is_numeric = sol.replace('.', '', 1).isdigit() or ans.replace('.', '', 1).isdigit() # Allows for decimal check + is_latex = sol.startswith("\\frac") or '\\sqrt' in sol or ans.startswith("\\frac") or '\\sqrt' in ans + print(f"sol: {sol}, ans: {ans}") # Exact answer matching if sol == ans : reward = 1 - elif is_numeric and abs(float(sol) - float(ans)) < 1e-4: + elif is_numeric and not is_latex and abs(float(sol) - float(ans)) < 1e-4: reward = 0.8 # Reward for correct numerical approximation elif is_latex: # Check if numbers in LaTeX are correct diff --git a/examples/arealite/vision_grpo.py b/examples/arealite/vision_grpo.py index 71c8c05..dba161e 100644 --- a/examples/arealite/vision_grpo.py +++ b/examples/arealite/vision_grpo.py @@ -28,11 +28,13 @@ from realhf.base import stats_tracker def main(args): - wandb.init(project=config.wandb.project) + config, _ = load_expr_config(args, GRPOConfig) config: GRPOConfig + wandb.init(project=config.stats_logger.wandb.project) + rank = int(os.getenv("RANK")) world_size = int(os.getenv("WORLD_SIZE")) processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path) @@ -114,7 +116,7 @@ def main(args): config.gconfig.stop_token_ids.append(tokenizer.eos_token_id) reward_fn = custom_reward_fn( - path=config.reward_fn.path, + path=config.train_dataset.reward_fn, ) workflow = VisionRLVRWorkflow(