diff --git a/functioncall/code/local_verify.py b/functioncall/code/local_verify.py index d8c9ff6..a174ebb 100644 --- a/functioncall/code/local_verify.py +++ b/functioncall/code/local_verify.py @@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT): return result["result"], result["info"] -def code_verify(id2info, generateds, query_ids, debug=False): +def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False): assert len(generateds) == len(query_ids) problems = [id2info[qid] for qid in query_ids] @@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False): infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT)) run_results = [] - num_process = max(1, os.cpu_count() // 8) - with concurrent.futures.ProcessPoolExecutor(num_process) as executor: + if max_workers is None: + max_workers = max(1, os.cpu_count() // 8) + + with concurrent.futures.ProcessPoolExecutor(max_workers) as executor: run_results = executor.map(call_verify, *zip(*infer_args)) for run_result in run_results: diff --git a/realhf/impl/dataset/math_parser.py b/realhf/impl/dataset/math_parser.py index 5189c5d..30e8bd1 100644 --- a/realhf/impl/dataset/math_parser.py +++ b/realhf/impl/dataset/math_parser.py @@ -7,7 +7,7 @@ from typing import List, Union import regex from latex2sympy2 import latex2sympy -from pebble import ProcessPool +from pebble import ProcessExpired, ProcessPool from sympy import N, simplify from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr @@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False): # remove percentage string = string.replace("\\%", "") + string = string.replace("\%", "") string = string.replace("%", "") # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string @@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True): pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip() else: # use the last number if use_last_number: - pattern = r"-?\d*\.?\d+" + pattern = "-?\d*\.?\d+" pred = re.findall(pattern, pred_str.replace(",", "")) if len(pred) >= 1: pred = pred[-1] @@ -836,6 +837,12 @@ def parse_lines_in_parallel( # print("[debug: timeout]") logger.warning(f"Timeout occurred while justifying the math answer.") x = (0, "timeout", "timeout") + except ProcessExpired as e: + logger.warning(f"Process terminated abnormally: {e}") + x = (0, "error", "error") + except Exception as e: + logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}") + x = (0, "error", "error") label = label or x[0] labels.append(label) return labels diff --git a/realhf/impl/environment/math_code_single_step_env.py b/realhf/impl/environment/math_code_single_step_env.py index c07278f..ce9061e 100644 --- a/realhf/impl/environment/math_code_single_step_env.py +++ b/realhf/impl/environment/math_code_single_step_env.py @@ -57,6 +57,7 @@ class MathCodeSingleStepEnv(EnvironmentService): self.id2info, answers, [qid for _ in range(group_size)], + max_workers=1, ) elif cur_task == "code": answers = [extract_code(x) for x in answers] @@ -65,6 +66,7 @@ class MathCodeSingleStepEnv(EnvironmentService): self.id2info, answers, [qid for _ in range(group_size)], + max_workers=1, ) else: raise NotImplementedError()