Merge branch 'main' of https://github.com/inclusionAI/AReaL into lite

This commit is contained in:
博惟 2025-07-10 12:53:30 +08:00
commit d48bf007cf
4 changed files with 16 additions and 5 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 163 KiB

After

Width:  |  Height:  |  Size: 2.7 KiB

View File

@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
return result["result"], result["info"]
def code_verify(id2info, generateds, query_ids, debug=False):
def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False):
assert len(generateds) == len(query_ids)
problems = [id2info[qid] for qid in query_ids]
@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False):
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))
run_results = []
num_process = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
if max_workers is None:
max_workers = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
run_results = executor.map(call_verify, *zip(*infer_args))
for run_result in run_results:

View File

@ -7,7 +7,7 @@ from typing import List, Union
import regex
from latex2sympy2 import latex2sympy
from pebble import ProcessPool
from pebble import ProcessExpired, ProcessPool
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False):
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
string = string.replace("%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
else: # use the last number
if use_last_number:
pattern = r"-?\d*\.?\d+"
pattern = "-?\d*\.?\d+"
pred = re.findall(pattern, pred_str.replace(",", ""))
if len(pred) >= 1:
pred = pred[-1]
@ -836,6 +837,12 @@ def parse_lines_in_parallel(
# print("[debug: timeout]")
logger.warning(f"Timeout occurred while justifying the math answer.")
x = (0, "timeout", "timeout")
except ProcessExpired as e:
logger.warning(f"Process terminated abnormally: {e}")
x = (0, "error", "error")
except Exception as e:
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
x = (0, "error", "error")
label = label or x[0]
labels.append(label)
return labels

View File

@ -57,6 +57,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
self.id2info,
answers,
[qid for _ in range(group_size)],
max_workers=1,
)
elif cur_task == "code":
answers = [extract_code(x) for x in answers]
@ -65,6 +66,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
self.id2info,
answers,
[qid for _ in range(group_size)],
max_workers=1,
)
else:
raise NotImplementedError()