This commit is contained in:
lichangye.lcy 2025-07-31 16:35:13 +08:00
parent c5cd21d5db
commit 78d0367ff2
4 changed files with 18 additions and 12 deletions

View File

@ -637,6 +637,9 @@ class DatasetConfig:
default=0, metadata={"help": "Number of worker processes for data loading"}
)
drop_last: bool = field(default=True)
reward_fn: Optional[str] = field(
default=None,
)
@dataclass

View File

@ -99,6 +99,7 @@ train_dataset:
num_workers: 4
path: hiyouga/geometry3k
type: rl
reward_fn: ${train_dataset.path}
valid_dataset:
batch_size: 32
@ -108,8 +109,8 @@ valid_dataset:
path: ${train_dataset.path}
type: rl
reward_fn:
path: ${train_dataset.path}
# Utilities
saver:
@ -140,6 +141,5 @@ stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
project: geometry3k-grpo
wandb:
project: geometry3k-grpo

View File

@ -13,19 +13,20 @@ def geometry3k_reward_fn(
):
sol = extract_answer(completions, data_name="") # str number
ans = answer
sol = sol.replace(" ", "")
ans= ans.replace(" ", "")
if sol is None:
return 0
if ans is None:
return 0
is_numeric = sol.replace('.', '', 1).isdigit() # Allows for decimal check
is_latex = sol.startswith("\\frac") or '\\sqrt' in sol
print(f"completions: {completions}, answer: {answer}")
is_numeric = sol.replace('.', '', 1).isdigit() or ans.replace('.', '', 1).isdigit() # Allows for decimal check
is_latex = sol.startswith("\\frac") or '\\sqrt' in sol or ans.startswith("\\frac") or '\\sqrt' in ans
print(f"sol: {sol}, ans: {ans}")
# Exact answer matching
if sol == ans :
reward = 1
elif is_numeric and abs(float(sol) - float(ans)) < 1e-4:
elif is_numeric and not is_latex and abs(float(sol) - float(ans)) < 1e-4:
reward = 0.8 # Reward for correct numerical approximation
elif is_latex:
# Check if numbers in LaTeX are correct

View File

@ -28,11 +28,13 @@ from realhf.base import stats_tracker
def main(args):
wandb.init(project=config.wandb.project)
config, _ = load_expr_config(args, GRPOConfig)
config: GRPOConfig
wandb.init(project=config.stats_logger.wandb.project)
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
@ -114,7 +116,7 @@ def main(args):
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
reward_fn = custom_reward_fn(
path=config.reward_fn.path,
path=config.train_dataset.reward_fn,
)
workflow = VisionRLVRWorkflow(