mirror of https://github.com/inclusionAI/AReaL
876 lines
24 KiB
Python
876 lines
24 KiB
Python
# Copyright 2025 Ant Group Inc.
|
|
|
|
import json
|
|
import multiprocessing
|
|
import re
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
from typing import List, Union
|
|
|
|
import regex
|
|
from latex2sympy2 import latex2sympy
|
|
from sympy import N, simplify
|
|
from sympy.parsing.latex import parse_latex
|
|
from sympy.parsing.sympy_parser import parse_expr
|
|
from word2number import w2n
|
|
|
|
from realhf.base import logging
|
|
|
|
# Module-level logger for the math answer parser (not referenced by the code
# shown in this module section).
logger = logging.getLogger("math parser")
|
|
|
|
# Unit / filler words (mainly collected from MathQA answers) that
# strip_string() removes from an answer unless it is called with
# skip_unit=True. strip_string() applies them one after another in list
# order, so the order is significant. Entries are kept verbatim, including
# duplicates ("kmph", "sec"), the spelling "m sqaure" and mojibake such as
# "v â € ™" -- presumably because they mirror text found in the source data.
unit_texts = [
    "east", "degree", "mph", "kmph", "ft", "m sqaure", " m east", "sq m",
    "deg", "mile", "q .", "monkey", "prime", "ratio", "profit of rs", "rd",
    "o", "gm", "p . m", "lb", "tile", "per", "dm", "lt", "gain", "ab", "way",
    "west", "a .", "b .", "c .", "d .", "e .", "f .", "g .", "h .", "t", "a",
    "h", "no change", "men", "soldier", "pie", "bc", "excess", "st", "inches",
    "noon", "percent", "by", "gal", "kmh", "c", "acre", "rise", "a . m", "th",
    "π r 2", "sq", "mark", "l", "toy", "coin", "sq . m", "gallon", "° f",
    "profit", "minw", "yr", "women", "feet", "am", "pm", "hr", "cu cm",
    "square", "v â € ™", "are", "rupee", "rounds", "cubic", "cc", "mtr", "s",
    "ohm", "number", "kmph", "day", "hour", "minute", "min", "second", "man",
    "woman", "sec", "cube", "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more",
    "sec", "unit", "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month", "km",
    "m", "cm", "mm", "apple", "liter", "loss", "yard", "pure", "year",
    "increase", "decrease", "d", "less", "Surface", "litre", "pi sq m", "s .",
    "metre", "meter", "inch",
]

# Also match naive plurals ("mile" -> "miles"). The comprehension is built
# from the list before extending, so each original entry gets exactly one
# "+s" variant (duplicated entries therefore yield duplicated plurals).
unit_texts.extend([t + "s" for t in unit_texts])
|
|
|
|
|
|
def _fix_fracs(string):
|
|
substrs = string.split("\\frac")
|
|
new_str = substrs[0]
|
|
if len(substrs) > 1:
|
|
substrs = substrs[1:]
|
|
for substr in substrs:
|
|
new_str += "\\frac"
|
|
if len(substr) > 0 and substr[0] == "{":
|
|
new_str += substr
|
|
else:
|
|
try:
|
|
assert len(substr) >= 2
|
|
except:
|
|
return string
|
|
a = substr[0]
|
|
b = substr[1]
|
|
if b != "{":
|
|
if len(substr) > 2:
|
|
post_substr = substr[2:]
|
|
new_str += "{" + a + "}{" + b + "}" + post_substr
|
|
else:
|
|
new_str += "{" + a + "}{" + b + "}"
|
|
else:
|
|
if len(substr) > 2:
|
|
post_substr = substr[2:]
|
|
new_str += "{" + a + "}" + b + post_substr
|
|
else:
|
|
new_str += "{" + a + "}" + b
|
|
string = new_str
|
|
return string
|
|
|
|
|
|
def _fix_a_slash_b(string):
|
|
if len(string.split("/")) != 2:
|
|
return string
|
|
a = string.split("/")[0]
|
|
b = string.split("/")[1]
|
|
try:
|
|
if "sqrt" not in a:
|
|
a = int(a)
|
|
if "sqrt" not in b:
|
|
b = int(b)
|
|
assert string == "{}/{}".format(a, b)
|
|
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
|
return new_string
|
|
except:
|
|
return string
|
|
|
|
|
|
def _fix_sqrt(string):
|
|
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
|
return _string
|
|
|
|
|
|
def convert_word_number(text: str) -> str:
    """Convert an English number phrase (e.g. "twenty one") to digits ("21").

    Text that ``w2n.word_to_num`` cannot interpret is returned unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Best-effort: any error from w2n means "not a number phrase", so keep
        # the original text. Catching Exception (not a bare ``except``) lets
        # KeyboardInterrupt / SystemExit propagate.
        return text
|
|
|
|
|
|
def strip_string(string, skip_unit=False):
    r"""Normalize a math answer string so that equivalent answers compare equal.

    Applied in order: whitespace/linebreak cleanup, LaTeX canonicalization
    (array/bmatrix -> pmatrix, ``\tfrac``/``\dfrac`` -> ``\frac``, relation
    macros, ``\left``/``\right``), unit removal, degree/dollar removal,
    word-number conversion, ``\text{}`` unwrapping, percent removal, bracket
    unwrapping, infinity normalization, quote removal, trailing-zero removal,
    and finally brace fixes for ``\sqrt``, ``\frac`` and ``a/b``.

    Args:
        string: the raw answer; non-string values are converted with ``str``.
        skip_unit: if True, do not remove the words listed in ``unit_texts``.

    Returns:
        The normalized answer (possibly the empty string).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # matrix: normalize array / bmatrix environments to pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (
        string.replace("\\neq", "\\ne")
        .replace("\\leq", "\\le")
        .replace("\\geq", "\\ge")
    )

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove a trailing \text{...} unit (miles, dollars, ...) if something remains.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    if not skip_unit:
        # Remove unit words; two passes, presumably to catch units exposed by
        # an earlier removal.
        for _ in range(2):
            for unit_text in unit_texts:
                # The unit must be delimited by the string boundary or a
                # non-word character on both sides. re.escape makes entries
                # such as "q ." or "p . m" match literally instead of letting
                # "." act as a wildcard.
                _string = re.sub(
                    r"(^|\W)" + re.escape(unit_text) + r"($|\W)", r"\1\2", string
                )
                if _string != "":
                    string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" to "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage ("\\%" first so that no stray backslash is left over)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Unwrap one pair of {}, () or [] around purely alphanumeric content,
    # e.g. "{5}" -> "5". isalnum() must be checked on the inner text: the
    # bracketed string itself is never alphanumeric.
    if string[:1] + string[-1:] in ("{}", "()", "[]") and string[1:-1].isalnum():
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Treat "+\infty" as "\infty".
    string = string.replace("+\\infty", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quote: str.replace returns a new string, so the result must be assigned
    string = string.replace("'", "")
    string = string.replace('"', "")

    # i, j: a lone "j" (presumably an engineering-style imaginary unit) -> "i"
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not number or b is end, with ab
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # get rid of e.g. "k = " or "q = " at beginning
    parts = string.split("=")
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with
    # \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in
    # case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
|
|
|
|
|
|
def extract_answer(pred_str, data_name, use_last_number=True):
    r"""Extract the final answer from a model completion or reference solution.

    Multiple-choice datasets (mmlu_stem, sat_math, aqua, gaokao2023) are
    delegated to choice_answer_clean(). Otherwise the first matching source
    wins: the Minerva "final answer is $...$. I hope" template, the last
    ``\boxed{...}``, "he answer is", "final answer is", the Chinese marker
    "答案是" ("the answer is"), and finally -- if ``use_last_number`` -- the
    last number in the text. The result is normalized with strip_string().

    Args:
        pred_str: the full text to extract from.
        data_name: dataset name; selects multiple-choice handling and whether
            units are kept (``carp_en`` and ``minerva_math`` keep units).
        use_last_number: fall back to the last number when no marker matches.

    Returns:
        The extracted, normalized answer ("" when nothing is found).
    """
    # Drop the Cyrillic "ки" sequence (presumably a step-tag token used by
    # some datasets).
    pred_str = pred_str.replace("\u043a\u0438", "")
    if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
        # TODO check multiple choice
        return choice_answer_clean(pred_str)

    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        ans = pred_str.split("boxed")[-1]
        if len(ans) == 0:
            return ""
        elif ans[0] == "{":
            # Content of the last \boxed{...}, honouring nested braces.
            depth = 1
            chars = []
            for c in ans[1:]:
                if c == "{":
                    depth += 1
                elif c == "}":
                    depth -= 1
                    if depth == 0:
                        break
                chars.append(c)
            pred = "".join(chars)
        else:
            # "\boxed 5$" style: take the text up to the next "$".
            pred = ans.split("$")[0].strip()
    elif "he answer is" in pred_str:
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        # Use the last number. Raw string: "\d" and "\." are invalid escape
        # sequences in a normal string literal.
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""

    # choice answer
    if data_name in ["sat_math", "aqua"] or "mmlu" in data_name:
        tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
        if tmp:
            pred = tmp[-1]
        else:
            pred = pred.strip().strip(".")

    # multiple line
    pred = re.sub(r"\n\s*", "", pred)
    if pred.startswith(":"):
        pred = pred[1:]
    if pred.endswith("."):
        pred = pred[:-1]
    if pred.endswith("/"):
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"])
    return pred
|
|
|
|
|
|
def str_to_pmatrix(input_str):
    r"""Convert brace-delimited comma lists into pmatrix column vectors.

    ``"{1,2}"`` becomes ``\begin{pmatrix}1\\2\end{pmatrix}``: every comma is
    turned into the LaTeX row separator ``\\`` (two backslashes), which is
    what math_equal() splits pmatrix rows on. Multiple matches are joined
    with ", "; an input without a ``{...,...}`` group yields "".
    """
    matrices = re.findall(r"\{.*,.*\}", input_str.strip())
    return ", ".join(
        r"\begin{pmatrix}" + m.strip("{}").replace(",", "\\\\") + r"\end{pmatrix}"
        for m in matrices
    )
|
|
|
|
|
|
def parse_digits(num):
    """Parse ``num`` as a float, understanding thousands separators and percents.

    Commas are removed first ("1,234" -> 1234.0). A trailing "%" or "\\%"
    divides the value by 100 ("50%" -> 0.5). Returns None when the text is
    not a number.
    """
    # Plain str.replace: the pattern is a literal comma, no regex needed.
    num = str(num).replace(",", "")
    try:
        return float(num)
    except ValueError:
        pass
    if num.endswith("%"):
        num = num[:-1]
        # Also accept the LaTeX-escaped form "50\%".
        if num.endswith("\\"):
            num = num[:-1]
        try:
            return float(num) / 100
        except ValueError:
            pass
    return None
|
|
|
|
|
|
def is_digit(num):
    """Return True when parse_digits() can turn ``num`` into a number."""
    parsed = parse_digits(num)
    return parsed is not None
|
|
|
|
|
|
def choice_answer_clean(pred: str):
    """Reduce a multiple-choice response to a single option letter.

    Returns the last standalone A-E letter found (case-insensitively) after
    trimming newlines, trailing "."/"/", spaces and a leading ":". When no
    letter is present, the trimmed text itself is returned.
    """
    cleaned = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    letters = re.findall(r"\b(A|B|C|D|E)\b", cleaned.upper())
    answer = letters[-1] if letters else cleaned.strip().strip(".")
    # Trailing punctuation may remain on the fallback text; trim it again.
    return answer.rstrip(".").rstrip("/")
|
|
|
|
|
|
def numeric_equal(prediction: float, reference: float):
    """Return True if the two numbers agree within a relative tolerance of 1e-4.

    Only a relative tolerance is used (no absolute one), so a nonzero value
    never equals 0. The tolerance choice noticeably affects results on
    synthesized datasets such as GSM-Hard.
    """
    import math

    return math.isclose(reference, prediction, rel_tol=1e-4)
|
|
|
|
|
|
def symbolic_equal_process(a, b, output_queue):
    """Subprocess entry point: put ``symbolic_equal(a, b)`` on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))
|
|
|
|
|
|
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Args:
        prediction: the candidate answer (normally an extracted string).
        reference: the ground-truth answer.
        include_percentage: for numeric answers, also accept the reference
            scaled by 1/100 or by 100.
        is_close: compare numbers with numeric_equal() (relative tolerance)
            instead of exact ``==``.
        timeout: run every sympy comparison -- including those made for
            equations and by recursive calls -- in a subprocess through
            call_with_timeout().

    Returns:
        True if the answers are judged equivalent. When both sides parse as
        numbers the numeric comparison is final.
    """
    if prediction is None or reference is None:
        return False
    # Convert with str() first: the signature allows floats/bools, which have
    # no .strip() method.
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if (
        reference in ["A", "B", "C", "D", "E"]
        and choice_answer_clean(str(prediction)) == reference
    ):
        return True

    def _equal(pred, ref):
        # Recursive comparison that keeps every option, including ``timeout``.
        return math_equal(pred, ref, include_percentage, is_close, timeout=timeout)

    def _symbolic(pred, ref):
        # sympy can hang on pathological input, so honour ``timeout`` here too.
        if timeout:
            return bool(call_with_timeout(symbolic_equal_process, pred, ref))
        return symbolic_equal(pred, ref)

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            _equal(p, r) for p, r in zip(pred_parts, ref_parts)
        ):
            return True

    ## matrices: same number of rows, same cells per row, all cells equal
    matrix_prefixes = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_suffixes = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_prefixes)
        and prediction.endswith(matrix_suffixes)
        and reference.startswith(matrix_prefixes)
        and reference.endswith(matrix_suffixes)
    ):

        def _rows(matrix):
            # pmatrix and bmatrix markers have equal length, so one slice fits both.
            body = matrix[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")]
            return [line.strip() for line in body.split("\\\\") if line.strip()]

        def _row_equal(pred_row, ref_row):
            pred_cells = pred_row.split("&")
            ref_cells = ref_row.split("&")
            return len(pred_cells) == len(ref_cells) and all(
                _equal(p, r) for p, r in zip(pred_cells, ref_cells)
            )

        pred_lines = _rows(prediction)
        ref_lines = _rows(reference)
        if len(pred_lines) == len(ref_lines) and all(
            _row_equal(p, r) for p, r in zip(pred_lines, ref_lines)
        ):
            return True

    ## equations
    if prediction.count("=") == 1 and reference.count("=") == 1:
        pred_lhs, _, pred_rhs = prediction.partition("=")
        pred = f"{pred_lhs.strip()} - ({pred_rhs.strip()})"
        ref_lhs, _, ref_rhs = reference.partition("=")
        ref = f"{ref_lhs.strip()} - ({ref_rhs.strip()})"
        # An equation is equal to its negation: a = b  <=>  b = a.
        if _symbolic(pred, ref) or _symbolic(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # "x = 5" vs "5": compare the right-hand side only.
        if _equal(prediction.split("=")[1], reference):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if _equal(prediction, reference.split("=")[1]):
            return True

    # symbolic equal with sympy
    if _symbolic(prediction, reference):
        return True

    return False
|
|
|
|
|
|
def call_with_timeout(func, *args, timeout=3, **kwargs):
    """Run ``func(*args, output_queue, **kwargs)`` in a child process.

    ``func`` must put its result on the queue it receives as its last
    positional argument. Returns that result, or False when no result arrives
    within ``timeout`` seconds (the child timed out, crashed, or exited
    without reporting). The child is always terminated/joined before return.
    """
    import queue

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    try:
        # Read the result *before* joining: a child that has put data on a
        # queue does not exit until that data is consumed, so join-then-get
        # deadlocks (and reports a timeout) for large results. A bounded get()
        # also avoids hanging forever when the child dies before put().
        return output_queue.get(timeout=timeout)
    except queue.Empty:
        return False
    finally:
        if process.is_alive():
            process.terminate()
        process.join()
|
|
|
|
|
|
def math_equal_process(param):
    """Pool-friendly wrapper: compare the last two items of ``param``."""
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
|
|
|
|
|
|
def symbolic_equal(a, b):
    """Return True if ``a`` and ``b`` denote the same mathematical object.

    Each input is parsed with parse_latex, then parse_expr, then latex2sympy
    (first success wins; unparseable input stays as-is). The parsed values are
    then compared, in order, by: string/structural equality, ``equals`` or
    ``simplify(a - b) == 0``, equation sides (``|lhs - rhs|``), numeric
    evaluation via numeric_equal(), and element-wise matrix comparison after
    rounding entries to 3 decimals.
    """

    def _parse(s):
        if isinstance(s, str):
            unescaped = s.replace("\\\\", "\\")
            # Retrying the original text only helps if the replace changed it.
            candidates = [unescaped] if unescaped == s else [unescaped, s]
        else:
            candidates = [s]
        for f in [parse_latex, parse_expr, latex2sympy]:
            for candidate in candidates:
                try:
                    return f(candidate)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    def _matrix_equal():
        # Element-wise comparison of same-shaped matrices, rounded to 3 decimals.
        if a.shape != b.shape:
            return False
        _a = a.applyfunc(lambda x: round(x, 3))
        _b = b.applyfunc(lambda x: round(x, 3))
        return _a.equals(_b)

    checks = (
        lambda: str(a) == str(b) or a == b,  # direct equal
        lambda: a.equals(b) or simplify(a - b) == 0,  # simplify equal
        lambda: (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)),  # equation equal
        lambda: numeric_equal(float(N(a)), float(N(b))),  # numeric value
        _matrix_equal,  # matrix
    )
    for check in checks:
        try:
            if check():
                return True
        except Exception:
            # Each check only applies to some object kinds (expressions,
            # equations, matrices); a failure just means "try the next one".
            continue
    return False
|
|
|
|
|
|
def process_results(answer, solution):
    """Grade a model answer against a reference solution.

    Both texts go through extract_answer(..., "math"). The model answer gets
    no last-number fallback (``use_last_number=False``); the solution does.

    Returns:
        ``(retval, (extracted_answer, extracted_solution))`` where retval is
        1 on a match and 0 otherwise (including when either extraction is
        empty or "None"/"none"). Any exception yields ``(0, ("None", "None"))``.
    """
    empty_markers = ("None", "none", "")
    try:
        extracted_answer = extract_answer(answer, "math", use_last_number=False)
        extracted_solution = extract_answer(solution, "math", use_last_number=True)

        if extracted_answer is None or extracted_answer.strip() in empty_markers:
            retval = 0
        elif extracted_solution is None or extracted_solution.strip() in empty_markers:
            retval = 0
        elif math_equal(extracted_answer, extracted_solution, timeout=False):
            retval = 1
        else:
            retval = 0

        return retval, (extracted_answer, extracted_solution)
    except Exception:
        # Grading is best-effort: any parsing/comparison failure is a mismatch.
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        return 0, ("None", "None")
|
|
|
|
|
|
def process_results_process(a, b, output_queue):
    """Subprocess entry point: put ``process_results(a, b)`` on ``output_queue``."""
    output_queue.put(process_results(a, b))
|
|
|
|
|
|
def verify_math_solution(answer: str, solution: str):
    """Score a model ``answer`` against the ground-truth ``solution`` (1 or 0).

    Grading runs in a child process via call_with_timeout(). A bool result
    from it means no grading result was obtained, which is scored as 0.
    """
    result = call_with_timeout(process_results_process, answer, solution)
    return 0 if isinstance(result, bool) else result[0]
|
|
|
|
|
|
def loadJson(dataDir):
    """Load samples from a ``.json`` file or a ``.jsonl`` (one object per line) file.

    Files are read as UTF-8, the encoding required for JSON, rather than the
    platform default. Blank lines in ``.jsonl`` files (e.g. trailing empty
    lines) are skipped instead of making json.loads fail. ``dataDir`` may be
    a str or a path-like object.
    """
    with open(dataDir, "r", encoding="utf-8") as f:
        if str(dataDir).endswith(".jsonl"):
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)
|
|
|
|
|
|
def parse_line(id2info, prompt_str, generated, query_id):
    """Return the label for ``generated``: truthy if any reference solution matches.

    Everything from "@idx:" onward in ``query_id`` is dropped before looking
    the query up in ``id2info``. Solutions are checked in order and checking
    stops at the first match. ``prompt_str`` is not used.
    """
    base_id = query_id.split("@idx:")[0]
    label = 0
    for solution in id2info[base_id]["solutions"]:
        if label:
            break
        label = verify_math_solution(generated, solution)
    return label
|
|
|
|
|
|
def parse_lines_in_parallel(
    id2info,
    generateds: List,
    query_ids: List,
    max_workers=22,
    check_xml_format=False,
) -> List:
    """Grade every generated answer against all solutions of its query.

    One verify_math_solution() job per (answer, solution) pair is run in a
    process pool. Returns one label per input, in input order: the first
    truthy job result for that query, otherwise 0. ``check_xml_format`` is
    accepted but not used.
    """
    assert len(generateds) == len(query_ids), (
        len(generateds),
        len(query_ids),
    )

    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        futures_per_query = [
            [
                pool.submit(verify_math_solution, gen, sol)
                for sol in id2info[qid.split("@idx:")[0]]["solutions"]
            ]
            for qid, gen in zip(query_ids, generateds)
        ]

    labels = []
    for futures in futures_per_query:
        label = 0
        for future in as_completed(futures):
            label = label or future.result()
        labels.append(label)
    return labels
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: grade 100 copies of one answer against a sample's solutions.
    demo_query_id = "fe11b471-1aa9-4867-958f-a0a811c85f92"
    n_copies = 100
    sample = {
        "answers": ["\\boxed{-\\frac{2}{3}}"],
        "solutions": [
            "1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
            "\\boxed{-\\frac{2}{3}}",
        ],
    }
    id2info = {demo_query_id: sample}

    labels = parse_lines_in_parallel(
        id2info,
        sample["answers"] * n_copies,
        [demo_query_id] * n_copies,
        max_workers=8,
        check_xml_format=True,
    )
    print(labels)
|