mirror of https://github.com/inclusionAI/AReaL
PullRequest: 297 bugfix: reward is always -5
Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/297 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * bugfix: reward is always -5
This commit is contained in:
parent
a5cabddcea
commit
623f7c7407
|
@ -3,11 +3,11 @@
|
|||
import json
|
||||
import multiprocessing
|
||||
import re
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import List, Union
|
||||
|
||||
import regex
|
||||
from latex2sympy2 import latex2sympy
|
||||
from pebble import ProcessPool
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
@ -786,23 +786,6 @@ def process_results(answer, solution):
|
|||
return 0, ("None", "None")
|
||||
|
||||
|
||||
def process_results_process(a, b, output_queue):
    """Run ``process_results(a, b)`` and hand the outcome back through a queue.

    Args:
        a: First argument forwarded unchanged to ``process_results``.
        b: Second argument forwarded unchanged to ``process_results``.
        output_queue: Any object exposing ``put``; it receives exactly one
            item, the value returned by ``process_results``.

    Returns:
        None. The result is delivered only via ``output_queue``.
    """
    # NOTE(review): presumably meant to run in a child process so the caller
    # can enforce a timeout -- confirm against call_with_timeout.
    output_queue.put(process_results(a, b))
|
||||
|
||||
|
||||
def verify_math_solution(answer: str, solution: str):
    """Score a model-generated answer against a ground-truth solution.

    Args:
        answer: Text generated by the model.
        solution: The ground-truth solution.

    Returns:
        ``0`` when ``call_with_timeout`` yields a ``bool``; otherwise the first
        element of whatever it yields.
    """
    outcome = call_with_timeout(process_results_process, answer, solution)
    # NOTE(review): a bool outcome presumably signals timeout/failure of the
    # wrapped call -- confirm against call_with_timeout's contract.
    return 0 if isinstance(outcome, bool) else outcome[0]
|
||||
|
||||
|
||||
def loadJson(dataDir):
|
||||
with open(dataDir, "r") as f:
|
||||
if dataDir.endswith(".jsonl"):
|
||||
|
@ -818,7 +801,7 @@ def parse_line(id2info, prompt_str, generated, query_id):
|
|||
|
||||
label = 0
|
||||
for sol in info["solutions"]:
|
||||
label = label or verify_math_solution(generated, sol)
|
||||
label = label or process_results(generated, sol)
|
||||
return label
|
||||
|
||||
|
||||
|
@ -835,21 +818,26 @@ def parse_lines_in_parallel(
|
|||
)
|
||||
|
||||
all_jobs = []
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
with ProcessPool(max_workers=max_workers) as executor:
|
||||
for qid, gen in zip(query_ids, generateds):
|
||||
info = id2info[qid.split("@idx:")[0]]
|
||||
jobs = []
|
||||
for sol in info["solutions"]:
|
||||
job = executor.submit(verify_math_solution, gen, sol)
|
||||
job = executor.schedule(process_results, args=[gen, sol], timeout=15)
|
||||
jobs.append(job)
|
||||
all_jobs.append(jobs)
|
||||
|
||||
labels = []
|
||||
for jobs in all_jobs:
|
||||
label = 0
|
||||
for job in as_completed(jobs):
|
||||
x = job.result()
|
||||
label = label or x
|
||||
for job in jobs:
|
||||
try:
|
||||
x = job.result()
|
||||
except TimeoutError:
|
||||
# print("[debug: timeout]")
|
||||
logger.warning(f"Timeout occurred while justifying the math answer.")
|
||||
x = (0, "timeout", "timeout")
|
||||
label = label or x[0]
|
||||
labels.append(label)
|
||||
return labels
|
||||
|
||||
|
|
Loading…
Reference in New Issue