# Mirror of https://github.com/inclusionAI/AReaL
"""This logic is largely copied from the Hendrycks' MATH release
(math_equivalence), and borrowed from:

- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
- https://github.com/openai/prm800k
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
- https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
"""
|
|
|
|
import multiprocessing
|
|
import re
|
|
from collections import defaultdict
|
|
from math import isclose
|
|
from typing import Union
|
|
|
|
import regex
|
|
from latex2sympy2 import latex2sympy
|
|
from sympy.parsing.latex import parse_latex
|
|
from sympy.parsing.sympy_parser import parse_expr
|
|
|
|
from sympy import N, simplify
|
|
|
|
# from .parser import choice_answer_clean, strip_string
|
|
# from parser import choice_answer_clean
|
|
|
|
|
|
def choice_answer_clean(pred: str):
    """Reduce a free-form prediction to a multiple-choice style answer.

    Surrounding newlines, trailing "." and "/", outer spaces and leading ":"
    are trimmed first. If the upper-cased text contains any standalone
    letter A-E, the last such letter is returned. Otherwise the trimmed text
    itself is returned, with outer whitespace and periods removed.
    """
    trimmed = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    # Standalone option letters; the search is case-insensitive via upper().
    letters = re.findall(r"\b(A|B|C|D|E)\b", trimmed.upper())
    candidates = letters if letters else [trimmed.strip().strip(".")]
    answer = candidates[-1]
    # A trailing period or slash may still be present here; drop it again.
    return answer.rstrip(".").rstrip("/")
|
|
|
|
|
|
def parse_digits(num):
    """Convert a number-like value to ``float``, or return ``None``.

    ``num`` is stringified and its thousands separators (",") are removed.
    If the result parses as a float it is returned directly. Otherwise a
    trailing "%" (optionally LaTeX-escaped as "\\%") is stripped and the
    remaining number is returned divided by 100.

    Returns:
        The parsed float, or ``None`` when ``num`` is not a plain number or
        a percentage.
    """
    # A literal comma needs no regular expression.
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        # Not a plain number; fall through to the percentage form.
        pass

    if text.endswith("%"):
        text = text[:-1]
        # LaTeX writes the percent sign as "\%"; drop the escaping backslash.
        if text.endswith("\\"):
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None
|
|
|
|
|
|
def is_digit(num):
    """Return True when ``parse_digits`` can turn ``num`` into a float."""
    # Kept in lockstep with parse_digits: anything it parses counts as a digit.
    parsed = parse_digits(num)
    return parsed is not None
|
|
|
|
|
|
def str_to_pmatrix(input_str):
    """Rewrite brace-delimited, comma-separated groups as LaTeX pmatrix text.

    Every ``{...,...}`` match in the stripped input has its outer braces
    removed and each comma replaced by a single backslash, then is wrapped
    in ``\\begin{pmatrix}`` / ``\\end{pmatrix}``. The converted groups are
    joined with ", ". Input without such a group yields an empty string.
    """
    # NOTE(review): commas become ONE backslash here, whereas math_equal
    # splits matrix rows on a double backslash - confirm this is intended.
    groups = re.findall(r"\{.*,.*\}", input_str.strip())
    return ", ".join(
        r"\begin{pmatrix}" + group.strip("{}").replace(",", "\\") + r"\end{pmatrix}"
        for group in groups
    )
|
|
|
|
|
|
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Decide whether ``prediction`` matches ``reference`` as a math answer.

    The checks run in order and the first success returns True:

    1. case-insensitive, whitespace-stripped string equality;
    2. multiple-choice match when ``reference`` is one of "A".."E";
    3. numerical equality when both sides parse as numbers;
    4. structural comparisons: bracket/brace-insensitive string match,
       element-wise tuple/interval match, element-wise pmatrix/bmatrix
       match, and single-"=" equation handling;
    5. symbolic equality through ``symbolic_equal``.

    Args:
        prediction: Candidate answer; may be a string, number or bool.
        reference: Ground-truth answer; may be a string or number.
        include_percentage: When True, the numeric check also accepts the
            reference divided or multiplied by 100.
        is_close: When True, numbers are compared with ``numeric_equal``;
            otherwise with ``==``.
        timeout: When True, the final symbolic check runs in a subprocess
            through ``call_with_timeout``. Recursive element-wise calls and
            the equation check do not use the timeout.

    Returns:
        True if any check succeeds, otherwise False. ``None`` on either side
        is never equal.
    """
    if prediction is None or reference is None:
        return False

    # Stringify before strip()/lower(): both arguments may be numbers or
    # booleans, which do not have string methods.
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if (
        reference in ["A", "B", "C", "D", "E"]
        and choice_answer_clean(str(prediction)) == reference
    ):
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            # Both sides are numbers and none of the candidates matched.
            return False
    except Exception:
        # Best effort: fall through to the symbolic checks below.
        pass

    # An empty prediction can never match; 0 and False are legitimate values.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            # Generator form stops at the first mismatching element.
            if all(
                math_equal(pred_part, ref_part, include_percentage, is_close)
                for pred_part, ref_part in zip(pred_parts, ref_parts)
            ):
                return True

    matrix_open = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_close = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_open)
        and prediction.endswith(matrix_close)
        and reference.startswith(matrix_open)
        and reference.endswith(matrix_close)
    ):
        # The pmatrix and bmatrix delimiters have the same length, so one
        # slice removes either environment wrapper.
        pred_lines = [
            line.strip()
            for line in prediction[
                len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
            ].split("\\\\")
            if line.strip()
        ]
        ref_lines = [
            line.strip()
            for line in reference[
                len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
            ].split("\\\\")
            if line.strip()
        ]
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) != len(ref_parts) or not all(
                    math_equal(pred_part, ref_part, include_percentage, is_close)
                    for pred_part, ref_part in zip(pred_parts, ref_parts)
                ):
                    matched = False
                    break
        if matched:
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        # Compare the equations as "lhs - (rhs)", also accepting a sign flip.
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # Prediction of the form "x = value": compare only the value.
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        # Reference of the form "x = value": compare only the value.
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False
|
|
|
|
|
|
def math_equal_process(param):
    """Compare the last two entries of ``param`` with ``math_equal``.

    The second-to-last entry is passed as the prediction and the last entry
    as the reference; all other ``math_equal`` options keep their defaults.
    """
    prediction = param[-2]
    reference = param[-1]
    return math_equal(prediction, reference)
|
|
|
|
|
|
def numeric_equal(prediction: float, reference: float):
    """Return True when both floats agree within a relative tolerance of 1e-4."""
    # Note that relative tolerance has significant impact
    # on the result of the synthesized GSM-Hard dataset.
    # Disabled rounding-based variant, kept for reference:
    # if reference.is_integer():
    #     return isclose(reference, round(prediction), abs_tol=1e-4)
    # else:
    #     prediction = round(prediction, len(str(reference).split(".")[-1]))
    relative_tolerance = 1e-4
    return isclose(reference, prediction, rel_tol=relative_tolerance)
|
|
|
|
|
|
def symbolic_equal(a, b):
    """Best-effort symbolic comparison of two answer strings.

    Each input is parsed with the first of ``parse_latex``, ``parse_expr``
    and ``latex2sympy`` that succeeds; an input no parser accepts is kept as
    the raw string. The parsed values are then compared by, in order:
    direct/string equality, ``equals`` or a zero simplified difference,
    equation form (``abs(lhs - rhs)``), numeric evaluation through
    ``numeric_equal``, and element-wise matrix equality after rounding to
    3 decimals.

    A failing step is skipped rather than raised. Only ``Exception`` is
    caught, so KeyboardInterrupt and SystemExit still propagate.

    Returns:
        True as soon as one comparison succeeds, otherwise False.
    """

    def _parse(s):
        # Try each parser, first with doubled backslashes collapsed to one,
        # then on the untouched string.
        for f in [parse_latex, parse_expr, latex2sympy]:
            try:
                return f(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return f(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric evaluation of both sides
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix
    try:
        # if a and b are matrix
        if a.shape == b.shape:
            _a = a.applyfunc(lambda x: round(x, 3))
            _b = b.applyfunc(lambda x: round(x, 3))
            if _a.equals(_b):
                return True
    except Exception:
        pass

    return False
|
|
|
|
|
|
def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and put the verdict on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))
|
|
|
|
|
|
def call_with_timeout(func, *args, timeout=3, **kwargs):
    """Run ``func`` in a subprocess and return the value it puts on a queue.

    ``func`` is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to put exactly one result on ``output_queue``.

    Args:
        func: Callable executed in the child process.
        *args: Positional arguments passed before the queue.
        timeout: Seconds to wait for the child before terminating it.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value put on the queue, or False when the child times out,
        exits with a non-zero code, or exits without putting a result.
    """
    import queue  # local import: only needed for queue.Empty

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    try:
        process.join(timeout)

        if process.is_alive():
            # Still running after the deadline: stop it and report failure.
            process.terminate()
            process.join()
            return False

        # A crashed child never put a result; an unbounded get() would
        # block forever in that case.
        if process.exitcode != 0:
            return False
        try:
            return output_queue.get(timeout=1)
        except queue.Empty:
            return False
    finally:
        # Release this process's end of the queue promptly.
        output_queue.close()
|
|
|
|
|
|
def _test_math_equal():
    """Manual smoke check: print the ``math_equal`` verdict for one pair.

    The commented-out pairs below are alternative inputs that can be swapped
    in by hand.
    """
    # print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
    # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
    # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
    # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True))
    # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True))

    # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
    # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'

    # pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
    # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'

    # pred = '-34x-45y+20z-100=0'
    # gt = '34x+45y-20z+100=0'

    # pred = '\\frac{100}{3}'
    # gt = '33.3'

    # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
    # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'

    # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
    # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'

    # pred = '(+5)(b+2)'
    # gt = '(a+5)(b+2)'

    # pred = '\\frac{1+\\sqrt{5}}{2}'
    # gt = '2'

    # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
    # pred = '1', gt = '1\\\\sqrt{19}'

    # pred = "(0.6,2.6667]"
    # gt = "(\\frac{3}{5},\\frac{8}{3}]"

    # Active pair: ground truth first, then the candidate answer.
    reference, candidate = "x+0.5+0.5", "x+1"
    verdict = math_equal(candidate, reference, timeout=True)
    print(verdict)
|
|
|
|
|
|
# Script entry point: running this module directly performs the manual
# smoke check defined above.
if __name__ == "__main__":
    _test_math_equal()
|