# AReaL/examples/arealite/reward/geometry3k.py

import re
def extract_answer(pred_str, data_name, use_last_number=True):
    """Return the text enclosed by the last ``[...]`` pair in *pred_str*.

    The content must be non-empty and may not contain ``]``. Nested
    brackets are therefore not supported: ``"[a[b]c]"`` yields ``"a[b"``.

    Args:
        pred_str: Model output to search.
        data_name: Accepted for interface compatibility; not used here.
        use_last_number: Accepted for interface compatibility; not used here.

    Returns:
        The last bracketed substring, or ``""`` when none is found.
    """
    last_match = ""
    for hit in re.finditer(r"\[([^\]]+)\]", pred_str):
        # Keep overwriting so that only the final bracketed group survives.
        last_match = hit.group(1)
    return last_match
def geometry3k_reward_fn(
    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
    """Score a completion 1 if its bracketed answer matches *answer*, else 0.

    The predicted answer is the last ``[...]`` group in *completions*, as
    returned by ``extract_answer``. Spaces are removed from both the
    prediction and the ground truth before they are compared with
    ``math_equal``.

    Args:
        prompt: Unused.
        completions: Model output text to score. It is passed straight to
            ``extract_answer``, so it presumably must be a ``str``
            (TODO confirm with callers).
        prompt_ids: Unused.
        completion_ids: Unused.
        answer: Ground-truth answer. Non-string values are converted with
            ``str()``.
        **kwargs: Ignored.

    Returns:
        ``1`` when the extracted answer is judged equal to *answer*,
        otherwise ``0``. Missing (``None``) inputs, an empty ground truth,
        and a completion with no bracketed answer all score ``0``.
    """
    # Validate before touching string methods. Previously the None checks
    # ran after ``.replace()``, so a None answer crashed instead of scoring 0.
    if completions is None or answer is None:
        return 0

    ans = str(answer).replace(" ", "")
    if not ans:
        # An empty ground truth must never be "matched" by an empty prediction.
        return 0

    sol = extract_answer(completions, data_name="").replace(" ", "")
    if not sol:
        # No bracketed answer was found in the completion.
        return 0

    # Imported lazily, as in the original, so that loading this module does
    # not require realhf.
    from realhf.impl.dataset.math_parser import math_equal

    if math_equal(sol, ans):
        import logging

        # Debug-level only: emitting every correct completion at a default
        # level floods training logs.
        logging.getLogger(__name__).debug(
            "completions: %s, answer: %s", completions, answer
        )
        return 1
    return 0