mirror of https://github.com/inclusionAI/AReaL
Merge branch 'main' of https://github.com/inclusionAI/AReaL into lite
This commit is contained in:
commit
d48bf007cf
Binary file not shown.
Before Width: | Height: | Size: 163 KiB After Width: | Height: | Size: 2.7 KiB |
|
@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
|
|||
return result["result"], result["info"]
|
||||
|
||||
|
||||
def code_verify(id2info, generateds, query_ids, debug=False):
|
||||
def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False):
|
||||
assert len(generateds) == len(query_ids)
|
||||
problems = [id2info[qid] for qid in query_ids]
|
||||
|
||||
|
@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False):
|
|||
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))
|
||||
|
||||
run_results = []
|
||||
num_process = max(1, os.cpu_count() // 8)
|
||||
with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
|
||||
if max_workers is None:
|
||||
max_workers = max(1, os.cpu_count() // 8)
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
|
||||
run_results = executor.map(call_verify, *zip(*infer_args))
|
||||
|
||||
for run_result in run_results:
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import List, Union
|
|||
|
||||
import regex
|
||||
from latex2sympy2 import latex2sympy
|
||||
from pebble import ProcessPool
|
||||
from pebble import ProcessExpired, ProcessPool
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False):
|
|||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace("\%", "")
|
||||
string = string.replace("%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
|
@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
|
|||
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
|
||||
else: # use the last number
|
||||
if use_last_number:
|
||||
pattern = r"-?\d*\.?\d+"
|
||||
pattern = "-?\d*\.?\d+"
|
||||
pred = re.findall(pattern, pred_str.replace(",", ""))
|
||||
if len(pred) >= 1:
|
||||
pred = pred[-1]
|
||||
|
@ -836,6 +837,12 @@ def parse_lines_in_parallel(
|
|||
# print("[debug: timeout]")
|
||||
logger.warning(f"Timeout occurred while justifying the math answer.")
|
||||
x = (0, "timeout", "timeout")
|
||||
except ProcessExpired as e:
|
||||
logger.warning(f"Process terminated abnormally: {e}")
|
||||
x = (0, "error", "error")
|
||||
except Exception as e:
|
||||
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
|
||||
x = (0, "error", "error")
|
||||
label = label or x[0]
|
||||
labels.append(label)
|
||||
return labels
|
||||
|
|
|
@ -57,6 +57,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
|
|||
self.id2info,
|
||||
answers,
|
||||
[qid for _ in range(group_size)],
|
||||
max_workers=1,
|
||||
)
|
||||
elif cur_task == "code":
|
||||
answers = [extract_code(x) for x in answers]
|
||||
|
@ -65,6 +66,7 @@ class MathCodeSingleStepEnv(EnvironmentService):
|
|||
self.id2info,
|
||||
answers,
|
||||
[qid for _ in range(group_size)],
|
||||
max_workers=1,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
|
Loading…
Reference in New Issue