fix math reward verifier (#156)

* PullRequest: 293 fix get_param_realloc_path

Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/293

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix get_param_realloc_path

* PullRequest: 297 bugfix: reward is always -5

Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/297

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* bugfix: reward is always -5

* PullRequest: 321 fix checkpoint save dir

Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/321

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix checkpoint save dir

* PullRequest: 328 [Doc] update installation

Merge branch sxj/doc of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/328

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* [Doc] update installation

* PullRequest: 329 bugfix: math verifier blocks the async training

Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
https://code.alipay.com/inclusionAI/AReaL/pull_requests/329

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* bugfix: math verifier block the async training

* format

---------

Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
Co-authored-by: garrett4wade <fuwth17@gmail.com>
This commit is contained in:
xssstory 2025-07-07 15:49:13 +08:00 committed by GitHub
parent 5b7c83b5d9
commit 17ea7fe94d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 5 deletions

View File

@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
return result["result"], result["info"]
def code_verify(id2info, generateds, query_ids, debug=False):
def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False):
assert len(generateds) == len(query_ids)
problems = [id2info[qid] for qid in query_ids]
@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False):
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))
run_results = []
num_process = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
if max_workers is None:
max_workers = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
run_results = executor.map(call_verify, *zip(*infer_args))
for run_result in run_results:

View File

@ -7,7 +7,7 @@ from typing import List, Union
import regex
from latex2sympy2 import latex2sympy
from pebble import ProcessPool
from pebble import ProcessExpired, ProcessPool
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False):
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
string = string.replace("%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
else: # use the last number
if use_last_number:
pattern = r"-?\d*\.?\d+"
pattern = "-?\d*\.?\d+"
pred = re.findall(pattern, pred_str.replace(",", ""))
if len(pred) >= 1:
pred = pred[-1]
@ -836,6 +837,12 @@ def parse_lines_in_parallel(
# print("[debug: timeout]")
logger.warning(f"Timeout occurred while justifying the math answer.")
x = (0, "timeout", "timeout")
except ProcessExpired as e:
logger.warning(f"Process terminated abnormally: {e}")
x = (0, "error", "error")
except Exception as e:
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
x = (0, "error", "error")
label = label or x[0]
labels.append(label)
return labels

View File

@ -57,6 +57,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
self.id2info,
answers,
[qid for _ in range(group_size)],
max_workers=1,
)
elif cur_task == "code":
answers = [extract_code(x) for x in answers]
@ -65,6 +66,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
self.id2info,
answers,
[qid for _ in range(group_size)],
max_workers=1,
)
else:
raise NotImplementedError()