fix math reward verifier (#156)

* PullRequest: 293 fix get_param_realloc_path

Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/293

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix get_param_realloc_path

* PullRequest: 297 bugfix: reward is always -5

Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/297

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* bugfix: reward is always -5

* PullRequest: 321 fix checkpoint save dir

Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/321

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix checkpoint save dir

* PullRequest: 328 [Doc] update installation

Merge branch sxj/doc of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/328

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* [Doc] update installation

* PullRequest: 329 bugfix: math verifier blocks the async training

Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/329

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* bugfix: math verifier blocks the async training

* format

---------

Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
Co-authored-by: garrett4wade <fuwth17@gmail.com>
This commit is contained in:
xssstory 2025-07-07 15:49:13 +08:00 committed by GitHub
parent 5b7c83b5d9
commit 17ea7fe94d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 5 deletions

View File

@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
return result["result"], result["info"] return result["result"], result["info"]
def code_verify(id2info, generateds, query_ids, debug=False): def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False):
assert len(generateds) == len(query_ids) assert len(generateds) == len(query_ids)
problems = [id2info[qid] for qid in query_ids] problems = [id2info[qid] for qid in query_ids]
@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False):
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT)) infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))
run_results = [] run_results = []
num_process = max(1, os.cpu_count() // 8) if max_workers is None:
with concurrent.futures.ProcessPoolExecutor(num_process) as executor: max_workers = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
run_results = executor.map(call_verify, *zip(*infer_args)) run_results = executor.map(call_verify, *zip(*infer_args))
for run_result in run_results: for run_result in run_results:

View File

@ -7,7 +7,7 @@ from typing import List, Union
import regex import regex
from latex2sympy2 import latex2sympy from latex2sympy2 import latex2sympy
from pebble import ProcessPool from pebble import ProcessExpired, ProcessPool
from sympy import N, simplify from sympy import N, simplify
from sympy.parsing.latex import parse_latex from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr from sympy.parsing.sympy_parser import parse_expr
@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False):
# remove percentage # remove percentage
string = string.replace("\\%", "") string = string.replace("\\%", "")
string = string.replace("\%", "")
string = string.replace("%", "") string = string.replace("%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip() pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
else: # use the last number else: # use the last number
if use_last_number: if use_last_number:
pattern = r"-?\d*\.?\d+" pattern = "-?\d*\.?\d+"
pred = re.findall(pattern, pred_str.replace(",", "")) pred = re.findall(pattern, pred_str.replace(",", ""))
if len(pred) >= 1: if len(pred) >= 1:
pred = pred[-1] pred = pred[-1]
@ -836,6 +837,12 @@ def parse_lines_in_parallel(
# print("[debug: timeout]") # print("[debug: timeout]")
logger.warning(f"Timeout occurred while justifying the math answer.") logger.warning(f"Timeout occurred while justifying the math answer.")
x = (0, "timeout", "timeout") x = (0, "timeout", "timeout")
except ProcessExpired as e:
logger.warning(f"Process terminated abnormally: {e}")
x = (0, "error", "error")
except Exception as e:
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
x = (0, "error", "error")
label = label or x[0] label = label or x[0]
labels.append(label) labels.append(label)
return labels return labels

View File

@ -57,6 +57,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
self.id2info, self.id2info,
answers, answers,
[qid for _ in range(group_size)], [qid for _ in range(group_size)],
max_workers=1,
) )
elif cur_task == "code": elif cur_task == "code":
answers = [extract_code(x) for x in answers] answers = [extract_code(x) for x in answers]
@ -65,6 +66,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
self.id2info, self.id2info,
answers, answers,
[qid for _ in range(group_size)], [qid for _ in range(group_size)],
max_workers=1,
) )
else: else:
raise NotImplementedError() raise NotImplementedError()