# AReaL/realhf/impl/dataset/math_parser.py

# 870 lines
# 24 KiB
# Python

# Copyright 2025 Ant Group Inc.
import json
import multiprocessing
import re
from typing import List, Union
import regex
from latex2sympy2 import latex2sympy
from pebble import ProcessExpired, ProcessPool
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from word2number import w2n
from realhf.base import logging
logger = logging.getLogger("math parser")
# Unit words (mainly taken from MathQA) that strip_string removes from answers.
# The order is significant because strip_string applies the entries one after
# another; duplicated entries are intentionally left in place.
unit_texts = [
    "east", "degree", "mph", "kmph", "ft", "m sqaure", " m east", "sq m",
    "deg", "mile", "q .", "monkey", "prime", "ratio", "profit of rs", "rd",
    "o", "gm", "p . m", "lb", "tile", "per", "dm", "lt",
    "gain", "ab", "way", "west", "a .", "b .", "c .", "d .",
    "e .", "f .", "g .", "h .", "t", "a", "h", "no change",
    "men", "soldier", "pie", "bc", "excess", "st", "inches", "noon",
    "percent", "by", "gal", "kmh", "c", "acre", "rise", "a . m",
    "th", "π r 2", "sq", "mark", "l", "toy", "coin", "sq . m",
    "gallon", "° f", "profit", "minw", "yr", "women", "feet", "am",
    "pm", "hr", "cu cm", "square", "v â € ™", "are", "rupee", "rounds",
    "cubic", "cc", "mtr", "s", "ohm", "number", "kmph", "day",
    "hour", "minute", "min", "second", "man", "woman", "sec", "cube",
    "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more", "sec", "unit",
    "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month", "km",
    "m", "cm", "mm", "apple", "liter", "loss", "yard", "pure",
    "year", "increase", "decrease", "d", "less", "Surface", "litre", "pi sq m",
    "s .", "metre", "meter", "inch",
]
# Append a naive plural ("<unit>s") for every base entry, after all base entries.
unit_texts = unit_texts + [f"{unit}s" for unit in unit_texts]
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
if "sqrt" not in a:
a = int(a)
if "sqrt" not in b:
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except:
return string
def _fix_sqrt(string):
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
return _string
def convert_word_number(text: str) -> str:
    """Convert a spelled-out number to its digit string when possible.

    Args:
        text: Candidate answer text.

    Returns:
        ``str(w2n.word_to_num(text))`` if the conversion succeeds, otherwise
        ``text`` unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Best effort: most answers are not number words. Catch Exception
        # rather than using a bare except so KeyboardInterrupt/SystemExit
        # still propagate.
        return text
def strip_string(string, skip_unit=False):
    r"""Normalise an extracted answer so that equivalent forms compare equal.

    The steps run in a fixed order: whitespace and trailing-dot cleanup, LaTeX
    simplification (array/bmatrix -> pmatrix, tfrac/dfrac -> frac, \left and
    \right removed), unit removal, currency/percent/degree removal, number-word
    conversion, infinity normalisation, and finally the sqrt/frac fix-ups.

    Args:
        string: The raw answer; it is converted with ``str()`` first.
        skip_unit: When True, the ``unit_texts`` removal pass is skipped.

    Returns:
        The normalised answer string (possibly empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")
    # trailing "."
    string = string.rstrip(".")
    # remove the LaTeX negative thin space "\!"
    string = string.replace("\\!", "")
    # matrix environments are all mapped to pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (
        string.replace("\\neq", "\\ne")
        .replace("\\leq", "\\le")
        .replace("\\geq", "\\ge")
    )
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")
    # Drop a trailing \text{...} (typically a unit) unless nothing would remain.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string
    if not skip_unit:
        # Two passes, because removing one unit can expose another one.
        for _ in range(2):
            for unit_text in unit_texts:
                # The unit must be delimited on both sides by the string
                # boundary or a non-word character. unit_text is used as a
                # regex fragment, not a literal.
                _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
                if _string != "":
                    string = _string
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs and inline-math delimiters
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")
    # convert word number to digit
    string = convert_word_number(string)
    # replace "\\text{...}" with "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
    # remove percentage (escaped first, then bare)
    string = string.replace("\\%", "")
    string = string.replace("%", "")
    # " ." -> " 0." and "{." -> "{0."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # Unwrap a single alphanumeric token wrapped in one bracket pair, e.g.
    # "(5)" -> "5". The test is on the interior: a string that still carries
    # its brackets can never satisfy isalnum().
    if (
        len(string) >= 2
        and (string[0], string[-1]) in (("{", "}"), ("(", ")"), ("[", "]"))
        and string[1:-1].isalnum()
    ):
        string = string[1:-1]
    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\infty", "\\infty")
    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")
    # remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)
    # quotes (str.replace returns a new string, so the result must be assigned)
    string = string.replace("'", "")
    string = string.replace('"', "")
    # i, j: treat a lone "j" as the imaginary unit "i"
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")
    # "a.000b" -> "ab" where b is a non-digit or the end of the string
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string
    # Drop a short left-hand side such as "k=" or "q=".
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]
    string = _fix_sqrt(string)
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}; also \frac1{72}
    # (but not \frac{72}1).
    string = _fix_fracs(string)
    # X/Y --> \frac{X}{Y} in simple cases, in case the model output is X/Y.
    string = _fix_a_slash_b(string)
    return string
def extract_answer(pred_str, data_name, use_last_number=True):
    r"""Pull the final answer out of a free-form model response.

    Extraction rules are tried in this order:
    ``final answer is $...$. I hope`` (minerva style), the last ``boxed``
    occurrence, ``he answer is``, ``final answer is``, ``答案是``, and finally
    the last number in the text (only when ``use_last_number`` is true).

    Args:
        pred_str: Full response text.
        data_name: Dataset name; selects the multiple-choice handling and
            whether units are stripped.
        use_last_number: Fall back to the last number when no marker is found.

    Returns:
        The extracted answer passed through ``strip_string``, or the cleaned
        choice letter for the multiple-choice datasets. May be ``""``.
    """
    pred_str = pred_str.replace("\u043a\u0438", "")
    if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
        # TODO check multiple choice
        return choice_answer_clean(pred_str)
    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        ans = pred_str.split("boxed")[-1]
        if len(ans) == 0:
            return ""
        if ans[0] == "{":
            # Collect everything up to the brace that closes the opening one.
            depth = 1
            chars = []
            for ch in ans[1:]:
                if ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        break
                chars.append(ch)
            pred = "".join(chars)
        else:
            pred = ans.split("$")[0].strip()
    elif "he answer is" in pred_str:
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        # Raw string: "\d" / "\." are invalid escapes in a normal literal.
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""
    # choice answer
    if data_name in ["sat_math", "aqua"] or "mmlu" in data_name:
        tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
        if tmp:
            pred = tmp[-1]
        else:
            pred = pred.strip().strip(".")
    # Join multi-line answers: drop each newline and the whitespace after it.
    pred = re.sub(r"\n\s*", "", pred)
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"])
    return pred
def str_to_pmatrix(input_str):
    r"""Turn ``{a,b,...}`` groups into ``\begin{pmatrix}...\end{pmatrix}`` text.

    Every match of ``\{.*,.*\}`` in the stripped input has its outer braces
    removed and each comma replaced by a single backslash. The converted
    groups are joined with ``", "``; an input without such a group gives "".
    """
    opener, closer = r"\begin{pmatrix}", r"\end{pmatrix}"
    groups = re.findall(r"\{.*,.*\}", input_str.strip())
    return ", ".join(
        opener + group.strip("{}").replace(",", "\\") + closer for group in groups
    )
def parse_digits(num):
    r"""Parse a number, accepting thousands separators and a percent suffix.

    Args:
        num: Any value; it is converted with ``str()`` and commas are removed.

    Returns:
        The value as a float. A trailing ``%`` (optionally written ``\%``)
        divides the number by 100. Returns ``None`` when the text is not
        numeric.
    """
    # A literal replacement needs no regex engine.
    num = str(num).replace(",", "")
    try:
        return float(num)
    except ValueError:
        pass
    if num.endswith("%"):
        num = num[:-1]
        if num.endswith("\\"):
            # "50\%" -> "50"
            num = num[:-1]
        try:
            return float(num) / 100
        except ValueError:
            pass
    return None
def is_digit(num):
    """Return True when ``parse_digits`` can turn ``num`` into a float."""
    parsed = parse_digits(num)
    return parsed is not None
def choice_answer_clean(pred: str):
    """Reduce a response to its multiple-choice letter when one is present.

    The last standalone ``A``-``E`` (case-insensitive) wins. Without any such
    letter the trimmed text itself is returned. Trailing ``.`` and ``/`` are
    removed in both cases.
    """
    trimmed = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    letters = re.findall(r"\b(A|B|C|D|E)\b", trimmed.upper())
    answer = letters[-1] if letters else trimmed.strip().strip(".")
    # Remove a trailing period or slash once more.
    return answer.rstrip(".").rstrip("/")
def numeric_equal(prediction: float, reference: float):
    """Return True when the two numbers agree within a relative 1e-4.

    Note that the relative tolerance has a significant impact on the result of
    the synthesized GSM-Hard dataset.
    """
    import math

    return math.isclose(reference, prediction, rel_tol=1e-4)
def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and put the outcome on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))
def _split_matrix_rows(matrix_str):
    """Return the non-empty, stripped rows of a pmatrix/bmatrix string."""
    # "\\begin{pmatrix}" and "\\begin{bmatrix}" have the same length, as do
    # the two "\\end{...}" markers, so one slice serves both environments.
    body = matrix_str[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")]
    return [row.strip() for row in body.split("\\\\") if row.strip()]


def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Decide whether ``prediction`` and ``reference`` are the same answer.

    Checks run from cheapest to most expensive and return True on the first
    success: case-insensitive text equality, multiple-choice letter, numeric
    equality, bracket-insensitive text equality, element-wise comparison of
    tuples/intervals and matrices, equations, and finally symbolic equality.

    Args:
        prediction: The candidate answer.
        reference: The ground-truth answer.
        include_percentage: Also accept ``reference / 100`` and
            ``reference * 100`` in the numeric check.
        is_close: Use ``numeric_equal`` (relative tolerance) rather than exact
            float equality.
        timeout: Run every symbolic comparison, including the nested ones,
            through ``call_with_timeout``.

    Returns:
        True when any check succeeds, otherwise False.
    """

    def _symbolically_equal(lhs, rhs):
        if timeout:
            return call_with_timeout(symbolic_equal_process, lhs, rhs)
        return symbolic_equal(lhs, rhs)

    def _elements_equal(pred_items, ref_items):
        # Element-wise recursion; the caller has already matched the lengths.
        return all(
            math_equal(p, r, include_percentage, is_close, timeout)
            for p, r in zip(pred_items, ref_items)
        )

    if prediction is None or reference is None:
        return False
    # Convert with str() before strip(): the signature allows float/bool.
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if (
        reference in ["A", "B", "C", "D", "E"]
        and choice_answer_clean(str(prediction)) == reference
    ):
        return True
    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            # Both sides are numbers and differ: no other check applies.
            return False
    except Exception:
        pass
    # An empty prediction can never match (0 and False are not "empty").
    if not prediction and prediction not in [0, False]:
        return False
    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()
    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)
    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True
    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and _elements_equal(
            pred_parts, ref_parts
        ):
            return True
    ## matrix vs. matrix: compare row by row, cell by cell
    matrix_begin = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_end = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begin)
        and prediction.endswith(matrix_end)
        and reference.startswith(matrix_begin)
        and reference.endswith(matrix_end)
    ):
        pred_lines = _split_matrix_rows(prediction)
        ref_lines = _split_matrix_rows(reference)
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) != len(ref_parts) or not _elements_equal(
                    pred_parts, ref_parts
                ):
                    matched = False
                    break
        if matched:
            return True
    ## equations
    if prediction.count("=") == 1 and reference.count("=") == 1:
        # Compare "lhs - (rhs)" of both equations, up to an overall sign.
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if _symbolically_equal(pred, ref) or _symbolically_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # "x = 5" vs "5": compare only the right-hand side.
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close, timeout
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close, timeout
        ):
            return True
    # symbolic equal with sympy
    if _symbolically_equal(prediction, reference):
        return True
    return False
def call_with_timeout(func, *args, timeout=3, **kwargs):
    """Run ``func`` in a child process and give up after ``timeout`` seconds.

    ``func`` is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to put exactly one small result on ``output_queue``.

    Args:
        func: Callable executed in the child process.
        *args: Positional arguments placed before the queue.
        timeout: Seconds to wait for the child to finish.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value the child put on the queue. ``False`` when the child had to
        be terminated after ``timeout`` seconds, or when it exited without
        putting a result (for example because it raised).
    """
    from queue import Empty

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        process.join()
        return False
    try:
        # The child has exited, so a result (if any) is already in the pipe.
        # A bounded wait keeps a crashed child from blocking this call forever.
        return output_queue.get(timeout=0.2)
    except Empty:
        return False
def math_equal_process(param):
    """Compare the last two items of ``param`` with ``math_equal``."""
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
def symbolic_equal(a, b):
    """Return True when ``a`` and ``b`` denote the same expression.

    Both inputs are parsed with the first of ``parse_latex``, ``parse_expr``
    and ``latex2sympy`` that succeeds (an unparsable input is kept as-is).
    The parsed values are then compared by: direct equality, simplification,
    equation form, numeric value, and rounded matrix entries. A check that
    raises is treated as "not equal" and the next one is tried.
    """

    def _parse(expr):
        for parser in (parse_latex, parse_expr, latex2sympy):
            # First with doubled backslashes collapsed, then verbatim.
            for collapse in (True, False):
                try:
                    return parser(expr.replace("\\\\", "\\") if collapse else expr)
                except BaseException:
                    continue
        return expr

    a = _parse(a)
    b = _parse(b)

    def _direct():
        return str(a) == str(b) or a == b

    def _simplified():
        return a.equals(b) or simplify(a - b) == 0

    def _as_equations():
        return (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs))

    def _numeric():
        return numeric_equal(float(N(a)), float(N(b)))

    def _as_matrices():
        if a.shape != b.shape:
            return False
        rounded_a = a.applyfunc(lambda x: round(x, 3))
        rounded_b = b.applyfunc(lambda x: round(x, 3))
        return rounded_a.equals(rounded_b)

    for check in (_direct, _simplified, _as_equations, _numeric, _as_matrices):
        try:
            if check():
                return True
        except BaseException:
            # Catch-all kept from the original bare except: the operands may
            # be plain strings or arbitrary parser output.
            pass
    return False
def process_results(answer, solution):
    """Grade one generated answer against one reference solution.

    Args:
        answer: Generated text; its answer is extracted without the
            last-number fallback.
        solution: Reference solution text; the last-number fallback is allowed.

    Returns:
        ``(score, (extracted_answer, extracted_solution))`` where ``score`` is
        1 for a match and 0 otherwise. Any error during extraction or
        comparison yields ``(0, ("None", "None"))``.
    """
    missing = ("None", "none", "")
    try:
        extracted_answer = extract_answer(answer, "math", use_last_number=False)
        extracted_solution = extract_answer(solution, "math", use_last_number=True)
        if extracted_answer is None or extracted_answer.strip() in missing:
            retval = 0
        elif extracted_solution is None or extracted_solution.strip() in missing:
            retval = 0
        elif math_equal(extracted_answer, extracted_solution, timeout=False):
            retval = 1
        else:
            retval = 0
        return retval, (extracted_answer, extracted_solution)
    except Exception:
        # Grading is best effort: malformed text counts as a wrong answer.
        # KeyboardInterrupt/SystemExit are deliberately not swallowed.
        return 0, ("None", "None")
def loadJson(dataDir):
    """Load a ``.json`` or ``.jsonl`` file.

    Args:
        dataDir: Path of the file. A name ending in ``.jsonl`` is read as one
            JSON document per line; anything else as a single JSON document.

    Returns:
        A list of parsed lines for ``.jsonl``, otherwise the parsed document.
    """
    with open(dataDir, "r", encoding="utf-8") as f:
        if dataDir.endswith(".jsonl"):
            # Skip blank lines (e.g. a trailing newline at the end of the file).
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)
def parse_line(id2info, prompt_str, generated, query_id):
    """Grade one generated answer against every reference solution of a query.

    Args:
        id2info: Mapping from query id to a dict with a ``"solutions"`` list.
        prompt_str: Unused; kept so existing callers keep working.
        generated: The generated answer text.
        query_id: Query id; anything from ``"@idx:"`` onwards is ignored.

    Returns:
        1 if the answer matches at least one solution, otherwise 0.
    """
    info = id2info[query_id.split("@idx:")[0]]
    label = 0
    for sol in info["solutions"]:
        # process_results returns (score, details). Only the score may be
        # OR-ed in: the tuple itself is always truthy.
        label = label or process_results(generated, sol)[0]
    return label
def parse_lines_in_parallel(
    id2info,
    generateds: List,
    query_ids: List,
    max_workers=22,
    check_xml_format=False,
) -> List:
    """Grade many generated answers in a process pool.

    Each (answer, solution) pair is scheduled as its own ``process_results``
    task with a 15 second timeout.

    Args:
        id2info: Mapping from query id to a dict with a ``"solutions"`` list.
        generateds: Generated answers, aligned with ``query_ids``.
        query_ids: Query ids; anything from ``"@idx:"`` onwards is ignored.
        max_workers: Number of worker processes.
        check_xml_format: Accepted but not used by this function.

    Returns:
        One label per generated answer: 1 if it matches any solution, else 0.
        A task that times out or fails counts as 0.
    """
    # Before Python 3.11, concurrent.futures.TimeoutError is a different class
    # from the builtin TimeoutError, so both are caught.
    from concurrent.futures import TimeoutError as FuturesTimeoutError

    assert len(generateds) == len(query_ids), (
        len(generateds),
        len(query_ids),
    )
    all_jobs = []
    with ProcessPool(max_workers=max_workers) as executor:
        for qid, gen in zip(query_ids, generateds):
            info = id2info[qid.split("@idx:")[0]]
            jobs = [
                executor.schedule(process_results, args=[gen, sol], timeout=15)
                for sol in info["solutions"]
            ]
            all_jobs.append(jobs)
        labels = []
        for jobs in all_jobs:
            label = 0
            for job in jobs:
                try:
                    score = job.result()[0]
                except (TimeoutError, FuturesTimeoutError):
                    logger.warning("Timeout occurred while justifying the math answer.")
                    score = 0
                except ProcessExpired as e:
                    logger.warning(f"Process terminated abnormally: {e}")
                    score = 0
                except Exception as e:
                    logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
                    score = 0
                label = label or score
            labels.append(label)
    return labels
if __name__ == "__main__":
    # Smoke test: grade the same boxed answer 100 times against two reference
    # solutions and print the resulting labels.
    sample = {
        "answers": ["\\boxed{-\\frac{2}{3}}"],
        "solutions": [
            "1. **Apply the operation $\\otimes$ to the innermost parentheses first:**\n \\[\n (1 \\otimes 2) \\otimes 3 = \\left(\\frac{1^2}{2}\\right) \\otimes 3 = \\frac{1}{2} \\otimes 3\n \\]\n \\[\n 1 \\otimes (2 \\otimes 3) = 1 \\otimes \\left(\\frac{2^2}{3}\\right) = 1 \\otimes \\frac{4}{3}\n \\]\n\n2. **Calculate each part using the definition of $\\otimes$:**\n \\[\n \\frac{1}{2} \\otimes 3 = \\frac{\\left(\\frac{1}{2}\\right)^2}{3} = \\frac{\\frac{1}{4}}{3} = \\frac{1}{12}\n \\]\n \\[\n 1 \\otimes \\frac{4}{3} = \\frac{1^2}{\\frac{4}{3}} = \\frac{1}{\\frac{4}{3}} = \\frac{3}{4}\n \\]\n\n3. **Subtract the two results:**\n \\[\n \\left(\\frac{1}{12}\\right) - \\left(\\frac{3}{4}\\right) = \\frac{1}{12} - \\frac{9}{12} = -\\frac{8}{12} = -\\frac{2}{3}\n \\]\n\n4. **Conclude with the final answer:**\n \\[\n \\boxed{A}\n \\]",
            "\\boxed{-\\frac{2}{3}}",
        ],
    }
    id2info = {"fe11b471-1aa9-4867-958f-a0a811c85f92": sample}
    repeated_answers = sample["answers"] * 100
    repeated_ids = ["fe11b471-1aa9-4867-958f-a0a811c85f92"] * 100
    demo_labels = parse_lines_in_parallel(
        id2info,
        repeated_answers,
        repeated_ids,
        max_workers=8,
        check_xml_format=True,
    )
    print(demo_labels)