mirror of https://github.com/inclusionAI/AReaL
0731_2
This commit is contained in:
parent
c5cd21d5db
commit
78d0367ff2
|
@ -637,6 +637,9 @@ class DatasetConfig:
|
|||
default=0, metadata={"help": "Number of worker processes for data loading"}
|
||||
)
|
||||
drop_last: bool = field(default=True)
|
||||
reward_fn: Optional[str] = field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -99,6 +99,7 @@ train_dataset:
|
|||
num_workers: 4
|
||||
path: hiyouga/geometry3k
|
||||
type: rl
|
||||
reward_fn: ${train_dataset.path}
|
||||
|
||||
valid_dataset:
|
||||
batch_size: 32
|
||||
|
@ -108,8 +109,8 @@ valid_dataset:
|
|||
path: ${train_dataset.path}
|
||||
type: rl
|
||||
|
||||
reward_fn:
|
||||
path: ${train_dataset.path}
|
||||
|
||||
|
||||
|
||||
# Utilities
|
||||
saver:
|
||||
|
@ -140,6 +141,5 @@ stats_logger:
|
|||
experiment_name: ${experiment_name}
|
||||
trial_name: ${trial_name}
|
||||
fileroot: ${cluster.fileroot}
|
||||
|
||||
wandb:
|
||||
project: geometry3k-grpo
|
||||
wandb:
|
||||
project: geometry3k-grpo
|
|
@ -13,19 +13,20 @@ def geometry3k_reward_fn(
|
|||
):
|
||||
sol = extract_answer(completions, data_name="") # str number
|
||||
ans = answer
|
||||
|
||||
sol = sol.replace(" ", "")
|
||||
ans= ans.replace(" ", "")
|
||||
if sol is None:
|
||||
return 0
|
||||
if ans is None:
|
||||
return 0
|
||||
|
||||
is_numeric = sol.replace('.', '', 1).isdigit() # Allows for decimal check
|
||||
is_latex = sol.startswith("\\frac") or '\\sqrt' in sol
|
||||
print(f"completions: {completions}, answer: {answer}")
|
||||
is_numeric = sol.replace('.', '', 1).isdigit() or ans.replace('.', '', 1).isdigit() # Allows for decimal check
|
||||
is_latex = sol.startswith("\\frac") or '\\sqrt' in sol or ans.startswith("\\frac") or '\\sqrt' in ans
|
||||
print(f"sol: {sol}, ans: {ans}")
|
||||
# Exact answer matching
|
||||
if sol == ans :
|
||||
reward = 1
|
||||
elif is_numeric and abs(float(sol) - float(ans)) < 1e-4:
|
||||
elif is_numeric and not is_latex and abs(float(sol) - float(ans)) < 1e-4:
|
||||
reward = 0.8 # Reward for correct numerical approximation
|
||||
elif is_latex:
|
||||
# Check if numbers in LaTeX are correct
|
||||
|
|
|
@ -28,11 +28,13 @@ from realhf.base import stats_tracker
|
|||
|
||||
def main(args):
|
||||
|
||||
wandb.init(project=config.wandb.project)
|
||||
|
||||
|
||||
config, _ = load_expr_config(args, GRPOConfig)
|
||||
config: GRPOConfig
|
||||
|
||||
wandb.init(project=config.stats_logger.wandb.project)
|
||||
|
||||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
processor, tokenizer = load_hf_processor_and_tokenizer(config.tokenizer_path)
|
||||
|
@ -114,7 +116,7 @@ def main(args):
|
|||
config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
|
||||
|
||||
reward_fn = custom_reward_fn(
|
||||
path=config.reward_fn.path,
|
||||
path=config.train_dataset.reward_fn,
|
||||
)
|
||||
|
||||
workflow = VisionRLVRWorkflow(
|
||||
|
|
Loading…
Reference in New Issue