mirror of https://github.com/inclusionAI/AReaL
fix math reward verifier (#156)
* PullRequest: 293 fix get_param_realloc_path Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/293 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * fix get_param_realloc_path * PullRequest: 297 bugfix: reward is always -5 Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/297 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * bugfix: reward is always -5 * PullRequest: 321 fix checkpoint save dir Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/321 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * fix checkpoint save dir * PullRequest: 328 [Doc] update installation Merge branch sxj/doc of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/328 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * [Doc] update installation * PullRequest: 329 bugfix: math verifier blocks the async training Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/329 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * bugfix: math verifier block the async training * format --------- Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com> Co-authored-by: garrett4wade <fuwth17@gmail.com>
This commit is contained in:
parent
5b7c83b5d9
commit
17ea7fe94d
|
@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
|
|||
return result["result"], result["info"]
|
||||
|
||||
|
||||
def code_verify(id2info, generateds, query_ids, debug=False):
|
||||
def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False):
|
||||
assert len(generateds) == len(query_ids)
|
||||
problems = [id2info[qid] for qid in query_ids]
|
||||
|
||||
|
@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False):
|
|||
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))
|
||||
|
||||
run_results = []
|
||||
num_process = max(1, os.cpu_count() // 8)
|
||||
with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
|
||||
if max_workers is None:
|
||||
max_workers = max(1, os.cpu_count() // 8)
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
|
||||
run_results = executor.map(call_verify, *zip(*infer_args))
|
||||
|
||||
for run_result in run_results:
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import List, Union
|
|||
|
||||
import regex
|
||||
from latex2sympy2 import latex2sympy
|
||||
from pebble import ProcessPool
|
||||
from pebble import ProcessExpired, ProcessPool
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False):
|
|||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace("\%", "")
|
||||
string = string.replace("%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
|
@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
|
|||
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
|
||||
else: # use the last number
|
||||
if use_last_number:
|
||||
pattern = r"-?\d*\.?\d+"
|
||||
pattern = "-?\d*\.?\d+"
|
||||
pred = re.findall(pattern, pred_str.replace(",", ""))
|
||||
if len(pred) >= 1:
|
||||
pred = pred[-1]
|
||||
|
@ -836,6 +837,12 @@ def parse_lines_in_parallel(
|
|||
# print("[debug: timeout]")
|
||||
logger.warning(f"Timeout occurred while justifying the math answer.")
|
||||
x = (0, "timeout", "timeout")
|
||||
except ProcessExpired as e:
|
||||
logger.warning(f"Process terminated abnormally: {e}")
|
||||
x = (0, "error", "error")
|
||||
except Exception as e:
|
||||
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
|
||||
x = (0, "error", "error")
|
||||
label = label or x[0]
|
||||
labels.append(label)
|
||||
return labels
|
||||
|
|
|
@ -57,6 +57,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
|
|||
self.id2info,
|
||||
answers,
|
||||
[qid for _ in range(group_size)],
|
||||
max_workers=1,
|
||||
)
|
||||
elif cur_task == "code":
|
||||
answers = [extract_code(x) for x in answers]
|
||||
|
@ -65,6 +66,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
|
|||
self.id2info,
|
||||
answers,
|
||||
[qid for _ in range(group_size)],
|
||||
max_workers=1,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
|
Loading…
Reference in New Issue