mirror of https://github.com/inclusionAI/AReaL
[Bug] Fix the dependency of a virtual environment for sympy==1.12 (#92)
* change to math local eval * . * update docker image tag
This commit is contained in:
parent
c7d6ccc18e
commit
b3f5392f44
|
@ -18,12 +18,6 @@ ENV NVTE_WITH_USERBUFFERS=1 NVTE_FRAMEWORK=pytorch MPI_HOME=/usr/local/mpi
|
|||
ENV PATH="${PATH}:/opt/hpcx/ompi/bin:/opt/hpcx/ucx/bin"
|
||||
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/opt/hpcx/ompi/lib:/opt/hpcx/ucx/lib/"
|
||||
|
||||
RUN git clone --depth=1 https://github.com/QwenLM/Qwen2.5-Math /qwen2_5-math && mv /qwen2_5-math/evaluation/latex2sympy /latex2sympy && rm -rf /qwen2_5-math \
|
||||
&& python3 -m venv /sympy && \
|
||||
/sympy/bin/pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && /sympy/bin/pip config set global.extra-index-url "" \
|
||||
&& /sympy/bin/pip install /latex2sympy && \
|
||||
/sympy/bin/pip install regex numpy tqdm datasets python_dateutil sympy==1.12 antlr4-python3-runtime==4.11.1 word2number Pebble timeout-decorator prettytable
|
||||
|
||||
RUN pip uninstall cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml \
|
||||
cugraph-pyg lightning_thunder opt_einsum nvfuser looseversion lightning_utilities -y
|
||||
RUN pip3 install -U uv nvidia-ml-py pipdeptree importlib_metadata packaging platformdirs typing_extensions wheel zipp
|
||||
|
|
|
@ -24,7 +24,7 @@ The following hardware configuration has been extensively tested:
|
|||
| Git LFS | Required for downloading models, datasets, and AReaL code. See [installation guide](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) |
|
||||
| Docker | 27.5.1 |
|
||||
| NVIDIA Container Toolkit | See [installation guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) |
|
||||
| AReaL Image | `ghcr.io/inclusionai/areal-runtime:v0.3.0` (includes runtime dependencies and Ray components) |
|
||||
| AReaL Image | `ghcr.io/inclusionai/areal-runtime:v0.3.0.post1` (includes runtime dependencies and Ray components) |
|
||||
|
||||
**Note**: This tutorial does not cover the installation of NVIDIA Drivers, CUDA, or shared storage mounting, as these depend on your specific node configuration and system version. Please complete these installations independently.
|
||||
|
||||
|
@ -37,11 +37,11 @@ The following hardware configuration has been extensively tested:
|
|||
We recommend using Docker with our provided image. The Dockerfile is available in the top-level directory of the AReaL repository.
|
||||
|
||||
```bash
|
||||
docker pull ghcr.io/inclusionai/areal-runtime:v0.3.0
|
||||
docker pull ghcr.io/inclusionai/areal-runtime:v0.3.0.post1
|
||||
docker run -it --name areal-node1 \
|
||||
--privileged --gpus all --network host \
|
||||
--shm-size 700g -v /path/to/mount:/path/to/mount \
|
||||
ghcr.io/inclusionai/areal-runtime:v0.3.0 \
|
||||
ghcr.io/inclusionai/areal-runtime:v0.3.0.post1 \
|
||||
/bin/bash
|
||||
git clone https://github.com/inclusionAI/AReaL
|
||||
cd AReaL
|
||||
|
|
|
@ -3,13 +3,12 @@ from concurrent.futures import TimeoutError
|
|||
from parser import *
|
||||
|
||||
import numpy as np
|
||||
from grader import *
|
||||
from pebble import ProcessPool
|
||||
from python_executor import PythonExecutor
|
||||
from tqdm import tqdm
|
||||
from utils import load_jsonl
|
||||
|
||||
from grader import *
|
||||
|
||||
|
||||
def evaluate(
|
||||
data_name,
|
||||
|
|
|
@ -2,9 +2,8 @@ from collections import Counter, defaultdict
|
|||
from parser import strip_string
|
||||
|
||||
import timeout_decorator
|
||||
from utils import load_jsonl
|
||||
|
||||
from grader import math_equal
|
||||
from utils import load_jsonl
|
||||
|
||||
|
||||
@timeout_decorator.timeout(5)
|
||||
|
|
|
@ -4,5 +4,8 @@ cd /sglang
|
|||
git apply $AREAL_PATH/patch/sglang/v0.4.6.post4.patch
|
||||
cd $AREAL_PATH
|
||||
|
||||
# Package used for calculating math reward
|
||||
pip install -e evaluation/latex2sympy
|
||||
|
||||
# Install AReaL
|
||||
pip install -e .
|
|
@ -7,13 +7,8 @@ pip install megatron-core==0.11.0 nvidia-ml-py
|
|||
pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# the sympy virtual env for reward computation
|
||||
pip install virtualenv
|
||||
rm -rf ./sympy
|
||||
python3 -m venv ./sympy
|
||||
# equivalent to install `./evaluation/latex2sympy` in the sympy virtual env
|
||||
./sympy/bin/pip install git+https://github.com/QwenLM/Qwen2.5-Math.git#subdirectory=evaluation/latex2sympy
|
||||
./sympy/bin/pip install regex numpy tqdm datasets python_dateutil sympy==1.12 antlr4-python3-runtime==4.11.1 word2number Pebble timeout-decorator prettytable
|
||||
# Package used for calculating math reward
|
||||
pip install -e evaluation/latex2sympy
|
||||
|
||||
# Install an editable sglang
|
||||
rm -rf ./sglang
|
||||
|
|
397
grader.py
397
grader.py
|
@ -1,397 +0,0 @@
|
|||
"""This logic is largely copied from the Hendrycks' MATH release
|
||||
(math_equivalence), and borrowed from:
|
||||
|
||||
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
|
||||
- https://github.com/openai/prm800k
|
||||
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
|
||||
- https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from math import isclose
|
||||
from typing import Union
|
||||
|
||||
import regex
|
||||
from latex2sympy2 import latex2sympy
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
# from .parser import choice_answer_clean, strip_string
|
||||
# from parser import choice_answer_clean
|
||||
|
||||
|
||||
def choice_answer_clean(pred: str):
|
||||
pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
|
||||
# Clean the answer based on the dataset
|
||||
tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
|
||||
if tmp:
|
||||
pred = tmp
|
||||
else:
|
||||
pred = [pred.strip().strip(".")]
|
||||
pred = pred[-1]
|
||||
# Remove the period at the end, again!
|
||||
pred = pred.rstrip(".").rstrip("/")
|
||||
return pred
|
||||
|
||||
|
||||
def parse_digits(num):
|
||||
num = regex.sub(",", "", str(num))
|
||||
try:
|
||||
return float(num)
|
||||
except:
|
||||
if num.endswith("%"):
|
||||
num = num[:-1]
|
||||
if num.endswith("\\"):
|
||||
num = num[:-1]
|
||||
try:
|
||||
return float(num) / 100
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def is_digit(num):
|
||||
# paired with parse_digits
|
||||
return parse_digits(num) is not None
|
||||
|
||||
|
||||
def str_to_pmatrix(input_str):
|
||||
input_str = input_str.strip()
|
||||
matrix_str = re.findall(r"\{.*,.*\}", input_str)
|
||||
pmatrix_list = []
|
||||
|
||||
for m in matrix_str:
|
||||
m = m.strip("{}")
|
||||
pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}"
|
||||
pmatrix_list.append(pmatrix)
|
||||
|
||||
return ", ".join(pmatrix_list)
|
||||
|
||||
|
||||
def math_equal(
|
||||
prediction: Union[bool, float, str],
|
||||
reference: Union[float, str],
|
||||
include_percentage: bool = True,
|
||||
is_close: bool = True,
|
||||
timeout: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Exact match of math if and only if:
|
||||
1. numerical equal: both can convert to float and are equal
|
||||
2. symbolic equal: both can convert to sympy expression and are equal
|
||||
"""
|
||||
# print("Judge:", prediction, reference)
|
||||
if prediction is None or reference is None:
|
||||
return False
|
||||
if str(prediction.strip().lower()) == str(reference.strip().lower()):
|
||||
return True
|
||||
if (
|
||||
reference in ["A", "B", "C", "D", "E"]
|
||||
and choice_answer_clean(prediction) == reference
|
||||
):
|
||||
return True
|
||||
|
||||
try: # 1. numerical equal
|
||||
if is_digit(prediction) and is_digit(reference):
|
||||
prediction = parse_digits(prediction)
|
||||
reference = parse_digits(reference)
|
||||
# number questions
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
for item in gt_result:
|
||||
try:
|
||||
if is_close:
|
||||
if numeric_equal(prediction, item):
|
||||
return True
|
||||
else:
|
||||
if item == prediction:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
if not prediction and prediction not in [0, False]:
|
||||
return False
|
||||
|
||||
# 2. symbolic equal
|
||||
reference = str(reference).strip()
|
||||
prediction = str(prediction).strip()
|
||||
|
||||
## pmatrix (amps)
|
||||
if "pmatrix" in prediction and not "pmatrix" in reference:
|
||||
reference = str_to_pmatrix(reference)
|
||||
|
||||
## deal with [], (), {}
|
||||
pred_str, ref_str = prediction, reference
|
||||
if (
|
||||
prediction.startswith("[")
|
||||
and prediction.endswith("]")
|
||||
and not reference.startswith("(")
|
||||
) or (
|
||||
prediction.startswith("(")
|
||||
and prediction.endswith(")")
|
||||
and not reference.startswith("[")
|
||||
):
|
||||
pred_str = pred_str.strip("[]()")
|
||||
ref_str = ref_str.strip("[]()")
|
||||
for s in ["{", "}", "(", ")"]:
|
||||
ref_str = ref_str.replace(s, "")
|
||||
pred_str = pred_str.replace(s, "")
|
||||
if pred_str.lower() == ref_str.lower():
|
||||
return True
|
||||
|
||||
## [a, b] vs. [c, d], return a==c and b==d
|
||||
if (
|
||||
regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
|
||||
and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
|
||||
):
|
||||
pred_parts = prediction[1:-1].split(",")
|
||||
ref_parts = reference[1:-1].split(",")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if all(
|
||||
[
|
||||
math_equal(
|
||||
pred_parts[i], ref_parts[i], include_percentage, is_close
|
||||
)
|
||||
for i in range(len(pred_parts))
|
||||
]
|
||||
):
|
||||
return True
|
||||
if (
|
||||
(
|
||||
prediction.startswith("\\begin{pmatrix}")
|
||||
or prediction.startswith("\\begin{bmatrix}")
|
||||
)
|
||||
and (
|
||||
prediction.endswith("\\end{pmatrix}")
|
||||
or prediction.endswith("\\end{bmatrix}")
|
||||
)
|
||||
and (
|
||||
reference.startswith("\\begin{pmatrix}")
|
||||
or reference.startswith("\\begin{bmatrix}")
|
||||
)
|
||||
and (
|
||||
reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
|
||||
)
|
||||
):
|
||||
pred_lines = [
|
||||
line.strip()
|
||||
for line in prediction[
|
||||
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
|
||||
].split("\\\\")
|
||||
if line.strip()
|
||||
]
|
||||
ref_lines = [
|
||||
line.strip()
|
||||
for line in reference[
|
||||
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
|
||||
].split("\\\\")
|
||||
if line.strip()
|
||||
]
|
||||
matched = True
|
||||
if len(pred_lines) == len(ref_lines):
|
||||
for pred_line, ref_line in zip(pred_lines, ref_lines):
|
||||
pred_parts = pred_line.split("&")
|
||||
ref_parts = ref_line.split("&")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if not all(
|
||||
[
|
||||
math_equal(
|
||||
pred_parts[i],
|
||||
ref_parts[i],
|
||||
include_percentage,
|
||||
is_close,
|
||||
)
|
||||
for i in range(len(pred_parts))
|
||||
]
|
||||
):
|
||||
matched = False
|
||||
break
|
||||
else:
|
||||
matched = False
|
||||
if not matched:
|
||||
break
|
||||
else:
|
||||
matched = False
|
||||
if matched:
|
||||
return True
|
||||
|
||||
if prediction.count("=") == 1 and reference.count("=") == 1:
|
||||
pred = prediction.split("=")
|
||||
pred = f"{pred[0].strip()} - ({pred[1].strip()})"
|
||||
ref = reference.split("=")
|
||||
ref = f"{ref[0].strip()} - ({ref[1].strip()})"
|
||||
if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
|
||||
return True
|
||||
elif (
|
||||
prediction.count("=") == 1
|
||||
and len(prediction.split("=")[0].strip()) <= 2
|
||||
and "=" not in reference
|
||||
):
|
||||
if math_equal(
|
||||
prediction.split("=")[1], reference, include_percentage, is_close
|
||||
):
|
||||
return True
|
||||
elif (
|
||||
reference.count("=") == 1
|
||||
and len(reference.split("=")[0].strip()) <= 2
|
||||
and "=" not in prediction
|
||||
):
|
||||
if math_equal(
|
||||
prediction, reference.split("=")[1], include_percentage, is_close
|
||||
):
|
||||
return True
|
||||
|
||||
# symbolic equal with sympy
|
||||
if timeout:
|
||||
if call_with_timeout(symbolic_equal_process, prediction, reference):
|
||||
return True
|
||||
else:
|
||||
if symbolic_equal(prediction, reference):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def math_equal_process(param):
|
||||
return math_equal(param[-2], param[-1])
|
||||
|
||||
|
||||
def numeric_equal(prediction: float, reference: float):
|
||||
# Note that relative tolerance has significant impact
|
||||
# on the result of the synthesized GSM-Hard dataset
|
||||
# if reference.is_integer():
|
||||
# return isclose(reference, round(prediction), abs_tol=1e-4)
|
||||
# else:
|
||||
# prediction = round(prediction, len(str(reference).split(".")[-1]))
|
||||
return isclose(reference, prediction, rel_tol=1e-4)
|
||||
|
||||
|
||||
def symbolic_equal(a, b):
|
||||
def _parse(s):
|
||||
for f in [parse_latex, parse_expr, latex2sympy]:
|
||||
try:
|
||||
return f(s.replace("\\\\", "\\"))
|
||||
except:
|
||||
try:
|
||||
return f(s)
|
||||
except:
|
||||
pass
|
||||
return s
|
||||
|
||||
a = _parse(a)
|
||||
b = _parse(b)
|
||||
|
||||
# direct equal
|
||||
try:
|
||||
if str(a) == str(b) or a == b:
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# simplify equal
|
||||
try:
|
||||
if a.equals(b) or simplify(a - b) == 0:
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# equation equal
|
||||
try:
|
||||
if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if numeric_equal(float(N(a)), float(N(b))):
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# matrix
|
||||
try:
|
||||
# if a and b are matrix
|
||||
if a.shape == b.shape:
|
||||
_a = a.applyfunc(lambda x: round(x, 3))
|
||||
_b = b.applyfunc(lambda x: round(x, 3))
|
||||
if _a.equals(_b):
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def symbolic_equal_process(a, b, output_queue):
|
||||
result = symbolic_equal(a, b)
|
||||
output_queue.put(result)
|
||||
|
||||
|
||||
def call_with_timeout(func, *args, timeout=3, **kwargs):
|
||||
output_queue = multiprocessing.Queue()
|
||||
process_args = args + (output_queue,)
|
||||
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
|
||||
process.start()
|
||||
process.join(timeout)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
process.join()
|
||||
return False
|
||||
|
||||
return output_queue.get()
|
||||
|
||||
|
||||
def _test_math_equal():
|
||||
# print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
|
||||
# print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
|
||||
# print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
|
||||
# print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True))
|
||||
# print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True))
|
||||
|
||||
# pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
|
||||
# gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'
|
||||
|
||||
# pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
|
||||
# gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'
|
||||
|
||||
# pred = '-34x-45y+20z-100=0'
|
||||
# gt = '34x+45y-20z+100=0'
|
||||
|
||||
# pred = '\\frac{100}{3}'
|
||||
# gt = '33.3'
|
||||
|
||||
# pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
|
||||
# gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'
|
||||
|
||||
# pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
|
||||
# gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'
|
||||
|
||||
# pred = '(+5)(b+2)'
|
||||
# gt = '(a+5)(b+2)'
|
||||
|
||||
# pred = '\\frac{1+\\sqrt{5}}{2}'
|
||||
# gt = '2'
|
||||
|
||||
# pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
|
||||
# pred = '1', gt = '1\\\\sqrt{19}'
|
||||
|
||||
# pred = "(0.6,2.6667]"
|
||||
# gt = "(\\frac{3}{5},\\frac{8}{3}]"
|
||||
|
||||
gt = "x+0.5+0.5"
|
||||
pred = "x+1"
|
||||
|
||||
print(math_equal(pred, gt, timeout=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_math_equal()
|
|
@ -1,69 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from parser import extract_answer
|
||||
|
||||
from grader import call_with_timeout, math_equal
|
||||
|
||||
|
||||
def process_results(answer, solution):
|
||||
|
||||
try:
|
||||
extracted_answer = extract_answer(answer, "math", use_last_number=False)
|
||||
extracted_solution = extract_answer(solution, "math", use_last_number=True)
|
||||
|
||||
# if extract_answer.strip() == "":
|
||||
# print (answer)
|
||||
# raise
|
||||
if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
|
||||
retval = 0
|
||||
elif extracted_solution is None or extracted_solution.strip() in [
|
||||
"None",
|
||||
"none",
|
||||
"",
|
||||
]:
|
||||
retval = 0
|
||||
elif math_equal(extracted_answer, extracted_solution, timeout=False):
|
||||
# elif call_with_timeout(math_equal, extracted_answer, extracted_solution):
|
||||
retval = 1
|
||||
else:
|
||||
retval = 0
|
||||
|
||||
return retval, (extracted_answer, extracted_solution)
|
||||
except:
|
||||
return 0, ("None", "None")
|
||||
|
||||
|
||||
def process_results_process(a, b, output_queue):
|
||||
result = process_results(a, b)
|
||||
output_queue.put(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tmp_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
all_input_data = []
|
||||
with open(f"/tmp/{args.tmp_id}-input.jsonl", "r") as temp_file:
|
||||
for line in temp_file.readlines():
|
||||
all_input_data.append(json.loads(line))
|
||||
|
||||
with open(f"/tmp/{args.tmp_id}-output.jsonl", "w", encoding="utf-8") as temp_file:
|
||||
for input_data in all_input_data:
|
||||
# r, (ans, sol) = process_results(
|
||||
# input_data["answer"], input_data["solution"]
|
||||
# )
|
||||
tmp = call_with_timeout(
|
||||
process_results_process, input_data["answer"], input_data["solution"]
|
||||
)
|
||||
if isinstance(tmp, bool):
|
||||
r, (ans, sol) = 0, ("None", "None")
|
||||
else:
|
||||
r, (ans, sol) = tmp
|
||||
|
||||
res = {"retval": r, "ans": ans, "sol": sol}
|
||||
temp_file.write(json.dumps(res) + "\n")
|
||||
|
||||
# print (process_results("answer is: \\boxed{2.0}", "the anser is: \\boxed{200\\%}"))
|
767
parser.py
767
parser.py
|
@ -1,767 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, List, TypeVar, Union
|
||||
|
||||
import regex
|
||||
import sympy
|
||||
from latex2sympy2 import latex2sympy
|
||||
from word2number import w2n
|
||||
|
||||
# from utils import *
|
||||
|
||||
|
||||
def _fix_fracs(string):
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) > 0 and substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string):
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
if "sqrt" not in a:
|
||||
a = int(a)
|
||||
if "sqrt" not in b:
|
||||
b = int(b)
|
||||
assert string == "{}/{}".format(a, b)
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string):
|
||||
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
||||
return _string
|
||||
|
||||
|
||||
def convert_word_number(text: str) -> str:
|
||||
try:
|
||||
text = str(w2n.word_to_num(text))
|
||||
except:
|
||||
pass
|
||||
return text
|
||||
|
||||
|
||||
# units mainly from MathQA
|
||||
unit_texts = [
|
||||
"east",
|
||||
"degree",
|
||||
"mph",
|
||||
"kmph",
|
||||
"ft",
|
||||
"m sqaure",
|
||||
" m east",
|
||||
"sq m",
|
||||
"deg",
|
||||
"mile",
|
||||
"q .",
|
||||
"monkey",
|
||||
"prime",
|
||||
"ratio",
|
||||
"profit of rs",
|
||||
"rd",
|
||||
"o",
|
||||
"gm",
|
||||
"p . m",
|
||||
"lb",
|
||||
"tile",
|
||||
"per",
|
||||
"dm",
|
||||
"lt",
|
||||
"gain",
|
||||
"ab",
|
||||
"way",
|
||||
"west",
|
||||
"a .",
|
||||
"b .",
|
||||
"c .",
|
||||
"d .",
|
||||
"e .",
|
||||
"f .",
|
||||
"g .",
|
||||
"h .",
|
||||
"t",
|
||||
"a",
|
||||
"h",
|
||||
"no change",
|
||||
"men",
|
||||
"soldier",
|
||||
"pie",
|
||||
"bc",
|
||||
"excess",
|
||||
"st",
|
||||
"inches",
|
||||
"noon",
|
||||
"percent",
|
||||
"by",
|
||||
"gal",
|
||||
"kmh",
|
||||
"c",
|
||||
"acre",
|
||||
"rise",
|
||||
"a . m",
|
||||
"th",
|
||||
"π r 2",
|
||||
"sq",
|
||||
"mark",
|
||||
"l",
|
||||
"toy",
|
||||
"coin",
|
||||
"sq . m",
|
||||
"gallon",
|
||||
"° f",
|
||||
"profit",
|
||||
"minw",
|
||||
"yr",
|
||||
"women",
|
||||
"feet",
|
||||
"am",
|
||||
"pm",
|
||||
"hr",
|
||||
"cu cm",
|
||||
"square",
|
||||
"v â € ™",
|
||||
"are",
|
||||
"rupee",
|
||||
"rounds",
|
||||
"cubic",
|
||||
"cc",
|
||||
"mtr",
|
||||
"s",
|
||||
"ohm",
|
||||
"number",
|
||||
"kmph",
|
||||
"day",
|
||||
"hour",
|
||||
"minute",
|
||||
"min",
|
||||
"second",
|
||||
"man",
|
||||
"woman",
|
||||
"sec",
|
||||
"cube",
|
||||
"mt",
|
||||
"sq inch",
|
||||
"mp",
|
||||
"∏ cm ³",
|
||||
"hectare",
|
||||
"more",
|
||||
"sec",
|
||||
"unit",
|
||||
"cu . m",
|
||||
"cm 2",
|
||||
"rs .",
|
||||
"rs",
|
||||
"kg",
|
||||
"g",
|
||||
"month",
|
||||
"km",
|
||||
"m",
|
||||
"cm",
|
||||
"mm",
|
||||
"apple",
|
||||
"liter",
|
||||
"loss",
|
||||
"yard",
|
||||
"pure",
|
||||
"year",
|
||||
"increase",
|
||||
"decrease",
|
||||
"d",
|
||||
"less",
|
||||
"Surface",
|
||||
"litre",
|
||||
"pi sq m",
|
||||
"s .",
|
||||
"metre",
|
||||
"meter",
|
||||
"inch",
|
||||
]
|
||||
|
||||
unit_texts.extend([t + "s" for t in unit_texts])
|
||||
|
||||
|
||||
def strip_string(string, skip_unit=False):
|
||||
string = str(string).strip()
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# right "."
|
||||
string = string.rstrip(".")
|
||||
|
||||
# remove inverse spaces
|
||||
# replace \\ with \
|
||||
string = string.replace("\\!", "")
|
||||
# string = string.replace("\\ ", "")
|
||||
# string = string.replace("\\\\", "\\")
|
||||
|
||||
# matrix
|
||||
string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
|
||||
string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
|
||||
string = string.replace("bmatrix", "pmatrix")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
string = (
|
||||
string.replace("\\neq", "\\ne")
|
||||
.replace("\\leq", "\\le")
|
||||
.replace("\\geq", "\\ge")
|
||||
)
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\right", "")
|
||||
string = string.replace("\\{", "{")
|
||||
string = string.replace("\\}", "}")
|
||||
|
||||
# Remove unit: miles, dollars if after is not none
|
||||
_string = re.sub(r"\\text{.*?}$", "", string).strip()
|
||||
if _string != "" and _string != string:
|
||||
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
||||
string = _string
|
||||
|
||||
if not skip_unit:
|
||||
# Remove unit: texts
|
||||
for _ in range(2):
|
||||
for unit_text in unit_texts:
|
||||
# use regex, the prefix should be either the start of the string or a non-alphanumeric character
|
||||
# the suffix should be either the end of the string or a non-alphanumeric character
|
||||
_string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
|
||||
if _string != "":
|
||||
string = _string
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
string = string.replace("$", "")
|
||||
string = string.replace("\\(", "").replace("\\)", "")
|
||||
|
||||
# convert word number to digit
|
||||
string = convert_word_number(string)
|
||||
|
||||
# replace "\\text{...}" to "..."
|
||||
string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
|
||||
for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
|
||||
string = string.replace(key, "")
|
||||
string = string.replace("\\emptyset", r"{}")
|
||||
string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace("\%", "")
|
||||
string = string.replace("%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
|
||||
# cdot
|
||||
# string = string.replace("\\cdot", "")
|
||||
if (
|
||||
string.startswith("{")
|
||||
and string.endswith("}")
|
||||
and string.isalnum()
|
||||
or string.startswith("(")
|
||||
and string.endswith(")")
|
||||
and string.isalnum()
|
||||
or string.startswith("[")
|
||||
and string.endswith("]")
|
||||
and string.isalnum()
|
||||
):
|
||||
string = string[1:-1]
|
||||
|
||||
# inf
|
||||
string = string.replace("infinity", "\\infty")
|
||||
if "\\infty" not in string:
|
||||
string = string.replace("inf", "\\infty")
|
||||
string = string.replace("+\\inity", "\\infty")
|
||||
|
||||
# and
|
||||
string = string.replace("and", "")
|
||||
string = string.replace("\\mathbf", "")
|
||||
|
||||
# use regex to remove \mbox{...}
|
||||
string = re.sub(r"\\mbox{.*?}", "", string)
|
||||
|
||||
# quote
|
||||
string.replace("'", "")
|
||||
string.replace('"', "")
|
||||
|
||||
# i, j
|
||||
if "j" in string and "i" not in string:
|
||||
string = string.replace("j", "i")
|
||||
|
||||
# replace a.000b where b is not number or b is end, with ab, use regex
|
||||
string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
|
||||
string = re.sub(r"(\d+)\.0*$", r"\1", string)
|
||||
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
if len(string.split("=")) == 2:
|
||||
if len(string.split("=")[0]) <= 2:
|
||||
string = string.split("=")[1]
|
||||
|
||||
string = _fix_sqrt(string)
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def extract_multi_choice_answer(pred_str):
|
||||
# TODO: SFT models
|
||||
if "Problem:" in pred_str:
|
||||
pred_str = pred_str.split("Problem:", 1)[0]
|
||||
pred_str = pred_str.replace("choice is", "answer is")
|
||||
patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower())
|
||||
if patt is not None:
|
||||
return patt.group("ans").upper()
|
||||
return "placeholder"
|
||||
|
||||
|
||||
direct_answer_trigger_for_fewshot = ("choice is", "answer is")
|
||||
|
||||
|
||||
def choice_answer_clean(pred: str):
|
||||
pred = pred.strip("\n")
|
||||
|
||||
# Determine if this is ICL, if so, use \n\n to split the first chunk.
|
||||
ICL = False
|
||||
for trigger in direct_answer_trigger_for_fewshot:
|
||||
if pred.count(trigger) > 1:
|
||||
ICL = True
|
||||
if ICL:
|
||||
pred = pred.split("\n\n")[0]
|
||||
|
||||
# Split the trigger to find the answer.
|
||||
preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred)
|
||||
if len(preds) > 1:
|
||||
answer_flag = True
|
||||
pred = preds[-1]
|
||||
else:
|
||||
answer_flag = False
|
||||
|
||||
pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
|
||||
|
||||
# Clean the answer based on the dataset
|
||||
tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
|
||||
if tmp:
|
||||
pred = tmp
|
||||
else:
|
||||
pred = [pred.strip().strip(".")]
|
||||
|
||||
if len(pred) == 0:
|
||||
pred = ""
|
||||
else:
|
||||
if answer_flag:
|
||||
# choose the first element in list ...
|
||||
pred = pred[0]
|
||||
else:
|
||||
# choose the last e
|
||||
pred = pred[-1]
|
||||
|
||||
# Remove the period at the end, again!
|
||||
pred = pred.rstrip(".").rstrip("/")
|
||||
|
||||
return pred
|
||||
|
||||
|
||||
def find_box(pred_str: str):
|
||||
ans = pred_str.split("boxed")[-1]
|
||||
if not ans:
|
||||
return ""
|
||||
if ans[0] == "{":
|
||||
stack = 1
|
||||
a = ""
|
||||
for c in ans[1:]:
|
||||
if c == "{":
|
||||
stack += 1
|
||||
a += c
|
||||
elif c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
break
|
||||
a += c
|
||||
else:
|
||||
a += c
|
||||
else:
|
||||
a = ans.split("$")[0].strip()
|
||||
return a
|
||||
|
||||
|
||||
def clean_units(pred_str: str):
|
||||
"""Clean the units in the number."""
|
||||
|
||||
def convert_pi_to_number(code_string):
|
||||
code_string = code_string.replace("\\pi", "π")
|
||||
# Replace \pi or π not preceded by a digit or } with 3.14
|
||||
code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string)
|
||||
# Replace instances where π is preceded by a digit but without a multiplication symbol, e.g., "3π" -> "3*3.14"
|
||||
code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string)
|
||||
# Handle cases where π is within braces or followed by a multiplication symbol
|
||||
# This replaces "{π}" with "3.14" directly and "3*π" with "3*3.14"
|
||||
code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string)
|
||||
code_string = re.sub(r"\*(\\?π)", "*3.14", code_string)
|
||||
return code_string
|
||||
|
||||
pred_str = convert_pi_to_number(pred_str)
|
||||
pred_str = pred_str.replace("%", "/100")
|
||||
pred_str = pred_str.replace("$", "")
|
||||
pred_str = pred_str.replace("¥", "")
|
||||
pred_str = pred_str.replace("°C", "")
|
||||
pred_str = pred_str.replace(" C", "")
|
||||
pred_str = pred_str.replace("°", "")
|
||||
return pred_str
|
||||
|
||||
|
||||
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
|
||||
if any([option in pred.lower() for option in ["yes", "true"]]):
|
||||
pred = "True"
|
||||
elif any([option in pred.lower() for option in ["no", "false"]]):
|
||||
pred = "False"
|
||||
elif any(
|
||||
[
|
||||
option in pred.lower()
|
||||
for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]
|
||||
]
|
||||
):
|
||||
pass
|
||||
else:
|
||||
# Some of the models somehow get used to boxed output from pre-training
|
||||
if "boxed" in pred:
|
||||
pred = find_box(pred)
|
||||
|
||||
if answer_flag:
|
||||
# Extract the numbers out of the string
|
||||
pred = pred.split("=")[-1].strip()
|
||||
pred = clean_units(pred)
|
||||
try:
|
||||
tmp = str(latex2sympy(pred))
|
||||
pred = str(eval(tmp))
|
||||
except Exception:
|
||||
if re.match(r"-?[\d\.]+\s\D+$", pred):
|
||||
pred = pred.split(" ")[0]
|
||||
elif re.match(r"-?[\d\.]+\s[^\s]+$", pred):
|
||||
pred = pred.split(" ")[0]
|
||||
else:
|
||||
# desparate search over the last number
|
||||
preds = re.findall(r"-?\d*\.?\d+", pred)
|
||||
if len(preds) >= 1:
|
||||
pred = preds[-1]
|
||||
else:
|
||||
pred = ""
|
||||
|
||||
return pred
|
||||
|
||||
|
||||
def extract_answer(pred_str, data_name, use_last_number=True):
|
||||
pred_str = pred_str.replace("\u043a\u0438", "")
|
||||
if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
|
||||
# TODO check multiple choice
|
||||
return choice_answer_clean(pred_str)
|
||||
|
||||
if "final answer is $" in pred_str and "$. I hope" in pred_str:
|
||||
# minerva_math
|
||||
tmp = pred_str.split("final answer is $", 1)[1]
|
||||
pred = tmp.split("$. I hope", 1)[0].strip()
|
||||
elif "boxed" in pred_str:
|
||||
ans = pred_str.split("boxed")[-1]
|
||||
if len(ans) == 0:
|
||||
return ""
|
||||
elif ans[0] == "{":
|
||||
stack = 1
|
||||
a = ""
|
||||
for c in ans[1:]:
|
||||
if c == "{":
|
||||
stack += 1
|
||||
a += c
|
||||
elif c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
break
|
||||
a += c
|
||||
else:
|
||||
a += c
|
||||
else:
|
||||
a = ans.split("$")[0].strip()
|
||||
pred = a
|
||||
elif "he answer is" in pred_str:
|
||||
pred = pred_str.split("he answer is")[-1].strip()
|
||||
elif "final answer is" in pred_str:
|
||||
pred = pred_str.split("final answer is")[-1].strip()
|
||||
elif "答案是" in pred_str:
|
||||
# Handle Chinese few-shot multiple choice problem answer extraction
|
||||
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
|
||||
else: # use the last number
|
||||
if use_last_number:
|
||||
pattern = "-?\d*\.?\d+"
|
||||
pred = re.findall(pattern, pred_str.replace(",", ""))
|
||||
if len(pred) >= 1:
|
||||
pred = pred[-1]
|
||||
else:
|
||||
pred = ""
|
||||
else:
|
||||
pred = ""
|
||||
|
||||
# choice answer
|
||||
if data_name in ["sat_math", "aqua"] or "mmlu" in data_name:
|
||||
tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
|
||||
if tmp:
|
||||
pred = tmp[-1]
|
||||
else:
|
||||
pred = pred.strip().strip(".")
|
||||
|
||||
# multiple line
|
||||
# pred = pred.split("\n")[0]
|
||||
pred = re.sub(r"\n\s*", "", pred)
|
||||
if pred != "" and pred[0] == ":":
|
||||
pred = pred[1:]
|
||||
if pred != "" and pred[-1] == ".":
|
||||
pred = pred[:-1]
|
||||
if pred != "" and pred[-1] == "/":
|
||||
pred = pred[:-1]
|
||||
pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"])
|
||||
return pred
|
||||
|
||||
|
||||
STRIP_EXCEPTIONS = ["carp_en", "minerva_math"]
|
||||
|
||||
|
||||
def parse_ground_truth(example: Dict[str, Any], data_name):
|
||||
if "gt_cot" in example and "gt" in example:
|
||||
if data_name in ["math"]:
|
||||
gt_ans = extract_answer(example["gt_cot"], data_name)
|
||||
elif data_name in STRIP_EXCEPTIONS:
|
||||
gt_ans = example["gt"]
|
||||
else:
|
||||
gt_ans = strip_string(example["gt"])
|
||||
return example["gt_cot"], gt_ans
|
||||
|
||||
# parse ground truth
|
||||
if data_name in ["math", "minerva_math", "math_500"]:
|
||||
gt_cot = example["solution"]
|
||||
gt_ans = extract_answer(gt_cot, data_name)
|
||||
elif data_name == "gsm8k":
|
||||
gt_cot, gt_ans = example["answer"].split("####")
|
||||
elif data_name == "svamp":
|
||||
gt_cot, gt_ans = example["Equation"], example["Answer"]
|
||||
elif data_name == "asdiv":
|
||||
gt_cot = example["formula"]
|
||||
gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
|
||||
elif data_name == "mawps":
|
||||
gt_cot, gt_ans = None, example["target"]
|
||||
elif data_name == "tabmwp":
|
||||
gt_cot = example["solution"]
|
||||
gt_ans = example["answer"]
|
||||
if example["ans_type"] in ["integer_number", "decimal_number"]:
|
||||
if "/" in gt_ans:
|
||||
gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1])
|
||||
elif "," in gt_ans:
|
||||
gt_ans = float(gt_ans.replace(",", ""))
|
||||
elif "%" in gt_ans:
|
||||
gt_ans = float(gt_ans.split("%")[0]) / 100
|
||||
else:
|
||||
gt_ans = float(gt_ans)
|
||||
elif data_name == "carp_en":
|
||||
gt_cot, gt_ans = example["steps"], example["answer"]
|
||||
elif data_name == "mmlu_stem":
|
||||
abcd = "ABCD"
|
||||
gt_cot, gt_ans = None, abcd[example["answer"]]
|
||||
elif data_name == "sat_math":
|
||||
gt_cot, gt_ans = None, example["Answer"]
|
||||
elif data_name == "aqua":
|
||||
gt_cot, gt_ans = None, example["correct"]
|
||||
elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]:
|
||||
gt_cot, gt_ans = None, example["answer"].replace("$", "").strip()
|
||||
elif data_name == "gaokao_math_qa":
|
||||
gt_cot, gt_ans = None, example["label"]
|
||||
elif data_name in ["gaokao2024_mix", "cn_middle_school"]:
|
||||
if len(example["choice_answer"]) > 0:
|
||||
gt_cot, gt_ans = None, example["choice_answer"]
|
||||
else:
|
||||
gt_cot, gt_ans = None, example["answer"]
|
||||
elif data_name == "olympiadbench":
|
||||
gt_cot, gt_ans = None, example["final_answer"][0].strip("$")
|
||||
elif data_name in [
|
||||
"aime24",
|
||||
"amc23",
|
||||
"cmath",
|
||||
"gaokao2024_I",
|
||||
"gaokao2024_II",
|
||||
"imo2024",
|
||||
]:
|
||||
gt_cot, gt_ans = None, example["answer"]
|
||||
elif data_name.startswith("train"): # train_amc_aime
|
||||
gt_cot, gt_ans = None, example["final_answer"]
|
||||
else:
|
||||
raise NotImplementedError(f"`{data_name}`")
|
||||
# post process
|
||||
gt_cot = str(gt_cot).strip()
|
||||
if data_name not in STRIP_EXCEPTIONS:
|
||||
gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en")
|
||||
else:
|
||||
gt_ans = (
|
||||
gt_ans.replace("\\neq", "\\ne")
|
||||
.replace("\\leq", "\\le")
|
||||
.replace("\\geq", "\\ge")
|
||||
)
|
||||
return gt_cot, gt_ans
|
||||
|
||||
|
||||
def parse_question(example, data_name):
|
||||
question = ""
|
||||
if data_name == "asdiv":
|
||||
question = f"{example['body'].strip()} {example['question'].strip()}"
|
||||
elif data_name == "svamp":
|
||||
body = example["Body"].strip()
|
||||
if not body.endswith("."):
|
||||
body = body + "."
|
||||
question = f'{body} {example["Question"].strip()}'
|
||||
elif data_name == "tabmwp":
|
||||
title_str = (
|
||||
f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
|
||||
)
|
||||
question = f"Read the following table {title_str}and answer a question:\n"
|
||||
question += f'{example["table"]}\n{example["question"]}'
|
||||
if example["choices"]:
|
||||
question += (
|
||||
f' Please select from the following options: {example["choices"]}'
|
||||
)
|
||||
elif data_name == "carp_en":
|
||||
question = example["content"]
|
||||
elif data_name == "mmlu_stem":
|
||||
options = example["choices"]
|
||||
assert len(options) == 4
|
||||
for i, (label, option) in enumerate(zip("ABCD", options)):
|
||||
options[i] = f"({label}) {str(option).strip()}"
|
||||
options = " ".join(options)
|
||||
# question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
|
||||
question = f"{example['question'].strip()}\nAnswer Choices: {options}"
|
||||
elif data_name == "sat_math":
|
||||
options = example["options"].strip()
|
||||
assert "A" == options[0]
|
||||
options = "(" + options
|
||||
for ch in "BCD":
|
||||
if f" {ch}) " in options:
|
||||
options = regex.sub(f" {ch}\) ", f" ({ch}) ", options)
|
||||
# question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
|
||||
question = f"{example['question'].strip()}\nAnswer Choices: {options}"
|
||||
elif "aqua" in data_name:
|
||||
options = example["options"]
|
||||
choice = "(" + "(".join(options)
|
||||
choice = choice.replace("(", " (").replace(")", ") ").strip()
|
||||
choice = "\nAnswer Choices: " + choice
|
||||
question = example["question"].strip() + choice
|
||||
elif data_name == "gaokao_math_qa":
|
||||
options_dict = example["options"]
|
||||
options = []
|
||||
for key in options_dict:
|
||||
options.append(f"({key}) {options_dict[key]}")
|
||||
options = " ".join(options)
|
||||
question = f"{example['question'].strip()}\n选项: {options}"
|
||||
elif data_name.startswith("train"): # train_amc_aime:
|
||||
question = example["prompt"]
|
||||
else:
|
||||
for key in ["question", "problem", "Question", "input"]:
|
||||
if key in example:
|
||||
question = example[key]
|
||||
break
|
||||
# assert question != ""
|
||||
# Yes or No question
|
||||
_, gt_ans = parse_ground_truth(example, data_name)
|
||||
if isinstance(gt_ans, str):
|
||||
gt_lower = gt_ans.lower()
|
||||
if gt_lower in ["true", "false"]:
|
||||
question += " (True or False)"
|
||||
if gt_lower in ["yes", "no"]:
|
||||
question += " (Yes or No)"
|
||||
return question.strip()
|
||||
|
||||
|
||||
# def run_execute(executor, result, prompt_type, data_name, execute=False):
|
||||
# if not result or result == "error":
|
||||
# return None, None
|
||||
# report = None
|
||||
|
||||
# if "program_only" in prompt_type:
|
||||
# prediction = extract_program_output(result)
|
||||
# elif prompt_type in ["pot", "pal"] and execute:
|
||||
# code = extract_program(result)
|
||||
# prediction, report = executor.apply(code)
|
||||
# else:
|
||||
# prediction = extract_answer(result, data_name)
|
||||
|
||||
# # prediction = strip_string(prediction, skip_unit=data_name == "carp_en")
|
||||
# prediction = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS)
|
||||
# return prediction, report
|
||||
|
||||
|
||||
def _test_extract_answer():
|
||||
text = """
|
||||
This is still not equal to $0$, so we must have made another mistake.
|
||||
|
||||
When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:
|
||||
|
||||
\[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\]
|
||||
|
||||
This is still not equal to $0$, so we must have made another mistake.
|
||||
|
||||
When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:
|
||||
|
||||
\[\frac{386}{64} 025
|
||||
"""
|
||||
print(extract_answer(text, "math-oai", use_last_number=True))
|
||||
print(choice_answer_clean("\mathrm{(D)\}1,008,016"))
|
||||
# should output a dict
|
||||
|
||||
print(extract_answer("The product of the inner terms is \\($ 15x $\\).", "math"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_extract_answer()
|
|
@ -50,8 +50,7 @@ dependencies = [
|
|||
"orjson>=3.10.16",
|
||||
"pydantic",
|
||||
"PyYAML",
|
||||
"omegaconf",
|
||||
"hydra-core",
|
||||
"hydra-core==1.4.0.dev1",
|
||||
"packaging",
|
||||
"tabulate",
|
||||
|
||||
|
@ -99,6 +98,12 @@ dependencies = [
|
|||
"distro-info>=1.0",
|
||||
"python-debian>=0.1.49",
|
||||
"func_timeout",
|
||||
"regex",
|
||||
"python_dateutil",
|
||||
"word2number",
|
||||
"Pebble",
|
||||
"timeout-decorator",
|
||||
"prettytable",
|
||||
|
||||
# Development tools (consider moving to optional dependencies)
|
||||
"pytest",
|
||||
|
|
|
@ -1,39 +1,806 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from typing import *
|
||||
import multiprocessing
|
||||
import re
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import List, Union
|
||||
|
||||
import regex
|
||||
from latex2sympy2 import latex2sympy
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
from word2number import w2n
|
||||
|
||||
from realhf.base import logging
|
||||
|
||||
logger = logging.getLogger("math parser")
|
||||
|
||||
# units mainly from MathQA
|
||||
unit_texts = [
|
||||
"east",
|
||||
"degree",
|
||||
"mph",
|
||||
"kmph",
|
||||
"ft",
|
||||
"m sqaure",
|
||||
" m east",
|
||||
"sq m",
|
||||
"deg",
|
||||
"mile",
|
||||
"q .",
|
||||
"monkey",
|
||||
"prime",
|
||||
"ratio",
|
||||
"profit of rs",
|
||||
"rd",
|
||||
"o",
|
||||
"gm",
|
||||
"p . m",
|
||||
"lb",
|
||||
"tile",
|
||||
"per",
|
||||
"dm",
|
||||
"lt",
|
||||
"gain",
|
||||
"ab",
|
||||
"way",
|
||||
"west",
|
||||
"a .",
|
||||
"b .",
|
||||
"c .",
|
||||
"d .",
|
||||
"e .",
|
||||
"f .",
|
||||
"g .",
|
||||
"h .",
|
||||
"t",
|
||||
"a",
|
||||
"h",
|
||||
"no change",
|
||||
"men",
|
||||
"soldier",
|
||||
"pie",
|
||||
"bc",
|
||||
"excess",
|
||||
"st",
|
||||
"inches",
|
||||
"noon",
|
||||
"percent",
|
||||
"by",
|
||||
"gal",
|
||||
"kmh",
|
||||
"c",
|
||||
"acre",
|
||||
"rise",
|
||||
"a . m",
|
||||
"th",
|
||||
"π r 2",
|
||||
"sq",
|
||||
"mark",
|
||||
"l",
|
||||
"toy",
|
||||
"coin",
|
||||
"sq . m",
|
||||
"gallon",
|
||||
"° f",
|
||||
"profit",
|
||||
"minw",
|
||||
"yr",
|
||||
"women",
|
||||
"feet",
|
||||
"am",
|
||||
"pm",
|
||||
"hr",
|
||||
"cu cm",
|
||||
"square",
|
||||
"v â € ™",
|
||||
"are",
|
||||
"rupee",
|
||||
"rounds",
|
||||
"cubic",
|
||||
"cc",
|
||||
"mtr",
|
||||
"s",
|
||||
"ohm",
|
||||
"number",
|
||||
"kmph",
|
||||
"day",
|
||||
"hour",
|
||||
"minute",
|
||||
"min",
|
||||
"second",
|
||||
"man",
|
||||
"woman",
|
||||
"sec",
|
||||
"cube",
|
||||
"mt",
|
||||
"sq inch",
|
||||
"mp",
|
||||
"∏ cm ³",
|
||||
"hectare",
|
||||
"more",
|
||||
"sec",
|
||||
"unit",
|
||||
"cu . m",
|
||||
"cm 2",
|
||||
"rs .",
|
||||
"rs",
|
||||
"kg",
|
||||
"g",
|
||||
"month",
|
||||
"km",
|
||||
"m",
|
||||
"cm",
|
||||
"mm",
|
||||
"apple",
|
||||
"liter",
|
||||
"loss",
|
||||
"yard",
|
||||
"pure",
|
||||
"year",
|
||||
"increase",
|
||||
"decrease",
|
||||
"d",
|
||||
"less",
|
||||
"Surface",
|
||||
"litre",
|
||||
"pi sq m",
|
||||
"s .",
|
||||
"metre",
|
||||
"meter",
|
||||
"inch",
|
||||
]
|
||||
|
||||
def get_box(s):
|
||||
pos = -1
|
||||
cnt = 0
|
||||
for i in range(len(s)):
|
||||
if s[i] == "{":
|
||||
cnt += 1
|
||||
if cnt == 1:
|
||||
pos = i + 1
|
||||
if s[i] == "}":
|
||||
cnt -= 1
|
||||
if cnt == 0:
|
||||
return s[pos:i]
|
||||
unit_texts.extend([t + "s" for t in unit_texts])
|
||||
|
||||
|
||||
def _fix_fracs(string):
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) > 0 and substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string):
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
if "sqrt" not in a:
|
||||
a = int(a)
|
||||
if "sqrt" not in b:
|
||||
b = int(b)
|
||||
assert string == "{}/{}".format(a, b)
|
||||
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||
return new_string
|
||||
except:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string):
|
||||
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
||||
return _string
|
||||
|
||||
|
||||
def convert_word_number(text: str) -> str:
|
||||
try:
|
||||
text = str(w2n.word_to_num(text))
|
||||
except:
|
||||
pass
|
||||
return text
|
||||
|
||||
|
||||
def strip_string(string, skip_unit=False):
|
||||
string = str(string).strip()
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# right "."
|
||||
string = string.rstrip(".")
|
||||
|
||||
# remove inverse spaces
|
||||
# replace \\ with \
|
||||
string = string.replace("\\!", "")
|
||||
# string = string.replace("\\ ", "")
|
||||
# string = string.replace("\\\\", "\\")
|
||||
|
||||
# matrix
|
||||
string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
|
||||
string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
|
||||
string = string.replace("bmatrix", "pmatrix")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
string = (
|
||||
string.replace("\\neq", "\\ne")
|
||||
.replace("\\leq", "\\le")
|
||||
.replace("\\geq", "\\ge")
|
||||
)
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\right", "")
|
||||
string = string.replace("\\{", "{")
|
||||
string = string.replace("\\}", "}")
|
||||
|
||||
# Remove unit: miles, dollars if after is not none
|
||||
_string = re.sub(r"\\text{.*?}$", "", string).strip()
|
||||
if _string != "" and _string != string:
|
||||
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
||||
string = _string
|
||||
|
||||
if not skip_unit:
|
||||
# Remove unit: texts
|
||||
for _ in range(2):
|
||||
for unit_text in unit_texts:
|
||||
# use regex, the prefix should be either the start of the string or a non-alphanumeric character
|
||||
# the suffix should be either the end of the string or a non-alphanumeric character
|
||||
_string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
|
||||
if _string != "":
|
||||
string = _string
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
string = string.replace("$", "")
|
||||
string = string.replace("\\(", "").replace("\\)", "")
|
||||
|
||||
# convert word number to digit
|
||||
string = convert_word_number(string)
|
||||
|
||||
# replace "\\text{...}" to "..."
|
||||
string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
|
||||
for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
|
||||
string = string.replace(key, "")
|
||||
string = string.replace("\\emptyset", r"{}")
|
||||
string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace("\%", "")
|
||||
string = string.replace("%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
|
||||
# cdot
|
||||
# string = string.replace("\\cdot", "")
|
||||
if (
|
||||
string.startswith("{")
|
||||
and string.endswith("}")
|
||||
and string.isalnum()
|
||||
or string.startswith("(")
|
||||
and string.endswith(")")
|
||||
and string.isalnum()
|
||||
or string.startswith("[")
|
||||
and string.endswith("]")
|
||||
and string.isalnum()
|
||||
):
|
||||
string = string[1:-1]
|
||||
|
||||
# inf
|
||||
string = string.replace("infinity", "\\infty")
|
||||
if "\\infty" not in string:
|
||||
string = string.replace("inf", "\\infty")
|
||||
string = string.replace("+\\inity", "\\infty")
|
||||
|
||||
# and
|
||||
string = string.replace("and", "")
|
||||
string = string.replace("\\mathbf", "")
|
||||
|
||||
# use regex to remove \mbox{...}
|
||||
string = re.sub(r"\\mbox{.*?}", "", string)
|
||||
|
||||
# quote
|
||||
string.replace("'", "")
|
||||
string.replace('"', "")
|
||||
|
||||
# i, j
|
||||
if "j" in string and "i" not in string:
|
||||
string = string.replace("j", "i")
|
||||
|
||||
# replace a.000b where b is not number or b is end, with ab, use regex
|
||||
string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
|
||||
string = re.sub(r"(\d+)\.0*$", r"\1", string)
|
||||
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
if len(string.split("=")) == 2:
|
||||
if len(string.split("=")[0]) <= 2:
|
||||
string = string.split("=")[1]
|
||||
|
||||
string = _fix_sqrt(string)
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def extract_answer(pred_str, data_name, use_last_number=True):
|
||||
pred_str = pred_str.replace("\u043a\u0438", "")
|
||||
if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
|
||||
# TODO check multiple choice
|
||||
return choice_answer_clean(pred_str)
|
||||
|
||||
if "final answer is $" in pred_str and "$. I hope" in pred_str:
|
||||
# minerva_math
|
||||
tmp = pred_str.split("final answer is $", 1)[1]
|
||||
pred = tmp.split("$. I hope", 1)[0].strip()
|
||||
elif "boxed" in pred_str:
|
||||
ans = pred_str.split("boxed")[-1]
|
||||
if len(ans) == 0:
|
||||
return ""
|
||||
elif ans[0] == "{":
|
||||
stack = 1
|
||||
a = ""
|
||||
for c in ans[1:]:
|
||||
if c == "{":
|
||||
stack += 1
|
||||
a += c
|
||||
elif c == "}":
|
||||
stack -= 1
|
||||
if stack == 0:
|
||||
break
|
||||
a += c
|
||||
else:
|
||||
a += c
|
||||
else:
|
||||
a = ans.split("$")[0].strip()
|
||||
pred = a
|
||||
elif "he answer is" in pred_str:
|
||||
pred = pred_str.split("he answer is")[-1].strip()
|
||||
elif "final answer is" in pred_str:
|
||||
pred = pred_str.split("final answer is")[-1].strip()
|
||||
elif "答案是" in pred_str:
|
||||
# Handle Chinese few-shot multiple choice problem answer extraction
|
||||
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
|
||||
else: # use the last number
|
||||
if use_last_number:
|
||||
pattern = "-?\d*\.?\d+"
|
||||
pred = re.findall(pattern, pred_str.replace(",", ""))
|
||||
if len(pred) >= 1:
|
||||
pred = pred[-1]
|
||||
else:
|
||||
pred = ""
|
||||
else:
|
||||
pred = ""
|
||||
|
||||
# choice answer
|
||||
if data_name in ["sat_math", "aqua"] or "mmlu" in data_name:
|
||||
tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
|
||||
if tmp:
|
||||
pred = tmp[-1]
|
||||
else:
|
||||
pred = pred.strip().strip(".")
|
||||
|
||||
# multiple line
|
||||
# pred = pred.split("\n")[0]
|
||||
pred = re.sub(r"\n\s*", "", pred)
|
||||
if pred != "" and pred[0] == ":":
|
||||
pred = pred[1:]
|
||||
if pred != "" and pred[-1] == ".":
|
||||
pred = pred[:-1]
|
||||
if pred != "" and pred[-1] == "/":
|
||||
pred = pred[:-1]
|
||||
pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"])
|
||||
return pred
|
||||
|
||||
|
||||
def str_to_pmatrix(input_str):
|
||||
input_str = input_str.strip()
|
||||
matrix_str = re.findall(r"\{.*,.*\}", input_str)
|
||||
pmatrix_list = []
|
||||
|
||||
for m in matrix_str:
|
||||
m = m.strip("{}")
|
||||
pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}"
|
||||
pmatrix_list.append(pmatrix)
|
||||
|
||||
return ", ".join(pmatrix_list)
|
||||
|
||||
|
||||
def parse_digits(num):
|
||||
num = regex.sub(",", "", str(num))
|
||||
try:
|
||||
return float(num)
|
||||
except:
|
||||
if num.endswith("%"):
|
||||
num = num[:-1]
|
||||
if num.endswith("\\"):
|
||||
num = num[:-1]
|
||||
try:
|
||||
return float(num) / 100
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_answer(answer):
|
||||
pos = answer.find("\\boxed{")
|
||||
if pos == -1:
|
||||
return []
|
||||
return [get_box(answer[pos:])] + get_answer(answer[pos + 1 :])
|
||||
def is_digit(num):
|
||||
# paired with parse_digits
|
||||
return parse_digits(num) is not None
|
||||
|
||||
|
||||
def choice_answer_clean(pred: str):
|
||||
pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
|
||||
# Clean the answer based on the dataset
|
||||
tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
|
||||
if tmp:
|
||||
pred = tmp
|
||||
else:
|
||||
pred = [pred.strip().strip(".")]
|
||||
pred = pred[-1]
|
||||
# Remove the period at the end, again!
|
||||
pred = pred.rstrip(".").rstrip("/")
|
||||
return pred
|
||||
|
||||
|
||||
def numeric_equal(prediction: float, reference: float):
|
||||
from math import isclose
|
||||
|
||||
# Note that relative tolerance has significant impact
|
||||
# on the result of the synthesized GSM-Hard dataset
|
||||
# if reference.is_integer():
|
||||
# return isclose(reference, round(prediction), abs_tol=1e-4)
|
||||
# else:
|
||||
# prediction = round(prediction, len(str(reference).split(".")[-1]))
|
||||
return isclose(reference, prediction, rel_tol=1e-4)
|
||||
|
||||
|
||||
def symbolic_equal_process(a, b, output_queue):
|
||||
result = symbolic_equal(a, b)
|
||||
output_queue.put(result)
|
||||
|
||||
|
||||
def math_equal(
|
||||
prediction: Union[bool, float, str],
|
||||
reference: Union[float, str],
|
||||
include_percentage: bool = True,
|
||||
is_close: bool = True,
|
||||
timeout: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Exact match of math if and only if:
|
||||
1. numerical equal: both can convert to float and are equal
|
||||
2. symbolic equal: both can convert to sympy expression and are equal
|
||||
"""
|
||||
# print("Judge:", prediction, reference)
|
||||
if prediction is None or reference is None:
|
||||
return False
|
||||
if str(prediction.strip().lower()) == str(reference.strip().lower()):
|
||||
return True
|
||||
if (
|
||||
reference in ["A", "B", "C", "D", "E"]
|
||||
and choice_answer_clean(prediction) == reference
|
||||
):
|
||||
return True
|
||||
|
||||
try: # 1. numerical equal
|
||||
if is_digit(prediction) and is_digit(reference):
|
||||
prediction = parse_digits(prediction)
|
||||
reference = parse_digits(reference)
|
||||
# number questions
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
for item in gt_result:
|
||||
try:
|
||||
if is_close:
|
||||
if numeric_equal(prediction, item):
|
||||
return True
|
||||
else:
|
||||
if item == prediction:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
if not prediction and prediction not in [0, False]:
|
||||
return False
|
||||
|
||||
# 2. symbolic equal
|
||||
reference = str(reference).strip()
|
||||
prediction = str(prediction).strip()
|
||||
|
||||
## pmatrix (amps)
|
||||
if "pmatrix" in prediction and not "pmatrix" in reference:
|
||||
reference = str_to_pmatrix(reference)
|
||||
|
||||
## deal with [], (), {}
|
||||
pred_str, ref_str = prediction, reference
|
||||
if (
|
||||
prediction.startswith("[")
|
||||
and prediction.endswith("]")
|
||||
and not reference.startswith("(")
|
||||
) or (
|
||||
prediction.startswith("(")
|
||||
and prediction.endswith(")")
|
||||
and not reference.startswith("[")
|
||||
):
|
||||
pred_str = pred_str.strip("[]()")
|
||||
ref_str = ref_str.strip("[]()")
|
||||
for s in ["{", "}", "(", ")"]:
|
||||
ref_str = ref_str.replace(s, "")
|
||||
pred_str = pred_str.replace(s, "")
|
||||
if pred_str.lower() == ref_str.lower():
|
||||
return True
|
||||
|
||||
## [a, b] vs. [c, d], return a==c and b==d
|
||||
if (
|
||||
regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
|
||||
and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
|
||||
):
|
||||
pred_parts = prediction[1:-1].split(",")
|
||||
ref_parts = reference[1:-1].split(",")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if all(
|
||||
[
|
||||
math_equal(
|
||||
pred_parts[i], ref_parts[i], include_percentage, is_close
|
||||
)
|
||||
for i in range(len(pred_parts))
|
||||
]
|
||||
):
|
||||
return True
|
||||
if (
|
||||
(
|
||||
prediction.startswith("\\begin{pmatrix}")
|
||||
or prediction.startswith("\\begin{bmatrix}")
|
||||
)
|
||||
and (
|
||||
prediction.endswith("\\end{pmatrix}")
|
||||
or prediction.endswith("\\end{bmatrix}")
|
||||
)
|
||||
and (
|
||||
reference.startswith("\\begin{pmatrix}")
|
||||
or reference.startswith("\\begin{bmatrix}")
|
||||
)
|
||||
and (
|
||||
reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
|
||||
)
|
||||
):
|
||||
pred_lines = [
|
||||
line.strip()
|
||||
for line in prediction[
|
||||
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
|
||||
].split("\\\\")
|
||||
if line.strip()
|
||||
]
|
||||
ref_lines = [
|
||||
line.strip()
|
||||
for line in reference[
|
||||
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
|
||||
].split("\\\\")
|
||||
if line.strip()
|
||||
]
|
||||
matched = True
|
||||
if len(pred_lines) == len(ref_lines):
|
||||
for pred_line, ref_line in zip(pred_lines, ref_lines):
|
||||
pred_parts = pred_line.split("&")
|
||||
ref_parts = ref_line.split("&")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if not all(
|
||||
[
|
||||
math_equal(
|
||||
pred_parts[i],
|
||||
ref_parts[i],
|
||||
include_percentage,
|
||||
is_close,
|
||||
)
|
||||
for i in range(len(pred_parts))
|
||||
]
|
||||
):
|
||||
matched = False
|
||||
break
|
||||
else:
|
||||
matched = False
|
||||
if not matched:
|
||||
break
|
||||
else:
|
||||
matched = False
|
||||
if matched:
|
||||
return True
|
||||
|
||||
if prediction.count("=") == 1 and reference.count("=") == 1:
|
||||
pred = prediction.split("=")
|
||||
pred = f"{pred[0].strip()} - ({pred[1].strip()})"
|
||||
ref = reference.split("=")
|
||||
ref = f"{ref[0].strip()} - ({ref[1].strip()})"
|
||||
if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
|
||||
return True
|
||||
elif (
|
||||
prediction.count("=") == 1
|
||||
and len(prediction.split("=")[0].strip()) <= 2
|
||||
and "=" not in reference
|
||||
):
|
||||
if math_equal(
|
||||
prediction.split("=")[1], reference, include_percentage, is_close
|
||||
):
|
||||
return True
|
||||
elif (
|
||||
reference.count("=") == 1
|
||||
and len(reference.split("=")[0].strip()) <= 2
|
||||
and "=" not in prediction
|
||||
):
|
||||
if math_equal(
|
||||
prediction, reference.split("=")[1], include_percentage, is_close
|
||||
):
|
||||
return True
|
||||
|
||||
# symbolic equal with sympy
|
||||
if timeout:
|
||||
if call_with_timeout(symbolic_equal_process, prediction, reference):
|
||||
return True
|
||||
else:
|
||||
if symbolic_equal(prediction, reference):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def call_with_timeout(func, *args, timeout=3, **kwargs):
|
||||
output_queue = multiprocessing.Queue()
|
||||
process_args = args + (output_queue,)
|
||||
process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
|
||||
process.start()
|
||||
process.join(timeout)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
process.join()
|
||||
return False
|
||||
|
||||
return output_queue.get()
|
||||
|
||||
|
||||
def math_equal_process(param):
|
||||
return math_equal(param[-2], param[-1])
|
||||
|
||||
|
||||
def symbolic_equal(a, b):
|
||||
def _parse(s):
|
||||
for f in [parse_latex, parse_expr, latex2sympy]:
|
||||
try:
|
||||
return f(s.replace("\\\\", "\\"))
|
||||
except:
|
||||
try:
|
||||
return f(s)
|
||||
except:
|
||||
pass
|
||||
return s
|
||||
|
||||
a = _parse(a)
|
||||
b = _parse(b)
|
||||
|
||||
# direct equal
|
||||
try:
|
||||
if str(a) == str(b) or a == b:
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# simplify equal
|
||||
try:
|
||||
if a.equals(b) or simplify(a - b) == 0:
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# equation equal
|
||||
try:
|
||||
if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if numeric_equal(float(N(a)), float(N(b))):
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# matrix
|
||||
try:
|
||||
# if a and b are matrix
|
||||
if a.shape == b.shape:
|
||||
_a = a.applyfunc(lambda x: round(x, 3))
|
||||
_b = b.applyfunc(lambda x: round(x, 3))
|
||||
if _a.equals(_b):
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def process_results(answer, solution):
|
||||
|
||||
try:
|
||||
extracted_answer = extract_answer(answer, "math", use_last_number=False)
|
||||
extracted_solution = extract_answer(solution, "math", use_last_number=True)
|
||||
|
||||
# if extract_answer.strip() == "":
|
||||
# print (answer)
|
||||
# raise
|
||||
if extracted_answer is None or extracted_answer.strip() in ["None", "none", ""]:
|
||||
retval = 0
|
||||
elif extracted_solution is None or extracted_solution.strip() in [
|
||||
"None",
|
||||
"none",
|
||||
"",
|
||||
]:
|
||||
retval = 0
|
||||
elif math_equal(extracted_answer, extracted_solution, timeout=False):
|
||||
# elif call_with_timeout(math_equal, extracted_answer, extracted_solution):
|
||||
retval = 1
|
||||
else:
|
||||
retval = 0
|
||||
|
||||
return retval, (extracted_answer, extracted_solution)
|
||||
except:
|
||||
return 0, ("None", "None")
|
||||
|
||||
|
||||
def process_results_process(a, b, output_queue):
|
||||
result = process_results(a, b)
|
||||
output_queue.put(result)
|
||||
|
||||
|
||||
def verify_math_solution(answer: str, solution: str):
|
||||
# answer is generated by the model, solution is the ground truth
|
||||
tmp = call_with_timeout(
|
||||
process_results_process,
|
||||
answer,
|
||||
solution,
|
||||
)
|
||||
if isinstance(tmp, bool):
|
||||
return 0
|
||||
return tmp[0]
|
||||
|
||||
|
||||
def loadJson(dataDir):
|
||||
|
@ -49,56 +816,10 @@ def loadJson(dataDir):
|
|||
def parse_line(id2info, prompt_str, generated, query_id):
|
||||
info = id2info[query_id.split("@idx:")[0]]
|
||||
|
||||
tmp_id = str(uuid.uuid4())
|
||||
with open(f"/tmp/{tmp_id}-input.jsonl", "w", encoding="utf-8") as f:
|
||||
for cur_solution in info["solutions"]:
|
||||
f.write(json.dumps({"answer": generated, "solution": cur_solution}) + "\n")
|
||||
|
||||
venv_python = "/sympy/bin/python3"
|
||||
if not os.path.exists(venv_python):
|
||||
venv_python = "sympy/bin/python3"
|
||||
if not os.path.exists(venv_python):
|
||||
venv_python = sys.executable
|
||||
# logger.info(f"math verify working dir: `{os.getcwd()}`")
|
||||
pro = subprocess.Popen(
|
||||
" ".join(
|
||||
[
|
||||
venv_python,
|
||||
"math_verify_utils_qwen.py",
|
||||
"--tmp_id",
|
||||
tmp_id,
|
||||
]
|
||||
),
|
||||
shell=True,
|
||||
preexec_fn=os.setsid,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=sys.stdout,
|
||||
)
|
||||
pro.wait()
|
||||
try:
|
||||
os.killpg(os.getpgid(pro.pid), signal.SIGTERM)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
label = 0
|
||||
try:
|
||||
with open(f"/tmp/{tmp_id}-output.jsonl", "r") as f:
|
||||
for line in f.readlines():
|
||||
output_data = json.loads(line)
|
||||
label = output_data["retval"] or label
|
||||
except FileNotFoundError as e:
|
||||
# The subprocess may fail to parse the input (maybe due to reaching the maximum recursion length)
|
||||
# We just return 0 for the reward.
|
||||
logger.warning(
|
||||
f"Failed to parse: query_id `{query_id}`, prompt `{prompt_str}`, seq `{generated}`. Set 0 reward."
|
||||
)
|
||||
label = 0
|
||||
finally:
|
||||
if os.path.exists(f"/tmp/{tmp_id}-input.jsonl"):
|
||||
os.remove(f"/tmp/{tmp_id}-input.jsonl")
|
||||
if os.path.exists(f"/tmp/{tmp_id}-output.jsonl"):
|
||||
os.remove(f"/tmp/{tmp_id}-output.jsonl")
|
||||
return label
|
||||
for sol in info["solutions"]:
|
||||
label = label or verify_math_solution(generated, sol)
|
||||
return label
|
||||
|
||||
|
||||
def parse_lines_in_parallel(
|
||||
|
@ -112,83 +833,24 @@ def parse_lines_in_parallel(
|
|||
len(generateds),
|
||||
len(query_ids),
|
||||
)
|
||||
bs = len(query_ids)
|
||||
mbs = (bs + max_workers - 1) // max_workers
|
||||
|
||||
tmp_ids = []
|
||||
all_query_indices = []
|
||||
for i in range(max_workers):
|
||||
tmp_id = str(uuid.uuid4())
|
||||
query_indices = []
|
||||
s = slice(i * mbs, (i + 1) * mbs)
|
||||
offset = i * mbs
|
||||
with open(f"/tmp/{tmp_id}-input.jsonl", "w", encoding="utf-8") as f:
|
||||
for idx, (query_id, generated) in enumerate(
|
||||
zip(query_ids[s], generateds[s])
|
||||
):
|
||||
info = id2info[query_id.split("@idx:")[0]]
|
||||
for cur_solution in info["solutions"]:
|
||||
f.write(
|
||||
json.dumps({"answer": generated, "solution": cur_solution})
|
||||
+ "\n"
|
||||
)
|
||||
query_indices.append(idx + offset)
|
||||
tmp_ids.append(tmp_id)
|
||||
all_query_indices.append(query_indices)
|
||||
all_jobs = []
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
for qid, gen in zip(query_ids, generateds):
|
||||
info = id2info[qid.split("@idx:")[0]]
|
||||
jobs = []
|
||||
for sol in info["solutions"]:
|
||||
job = executor.submit(verify_math_solution, gen, sol)
|
||||
jobs.append(job)
|
||||
all_jobs.append(jobs)
|
||||
|
||||
venv_python = "/sympy/bin/python3"
|
||||
if not os.path.exists(venv_python):
|
||||
venv_python = "sympy/bin/python3"
|
||||
if not os.path.exists(venv_python):
|
||||
venv_python = sys.executable
|
||||
# logger.info(f"math verify working dir: `{os.getcwd()}`")
|
||||
procs = []
|
||||
for tmp_id in tmp_ids:
|
||||
pro = subprocess.Popen(
|
||||
" ".join(
|
||||
[
|
||||
venv_python,
|
||||
"math_verify_utils_qwen.py",
|
||||
"--tmp_id",
|
||||
tmp_id,
|
||||
# "--check_xml_format",
|
||||
# "True" if check_xml_format else "False",
|
||||
]
|
||||
),
|
||||
shell=True,
|
||||
preexec_fn=os.setsid,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=sys.stdout,
|
||||
)
|
||||
procs.append(pro)
|
||||
for pro in procs:
|
||||
try:
|
||||
pro.wait()
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
os.killpg(os.getpgid(pro.pid), signal.SIGTERM)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
labels = [0 for _ in query_ids]
|
||||
for i, (tmp_id, query_indices) in enumerate(zip(tmp_ids, all_query_indices)):
|
||||
try:
|
||||
with open(f"/tmp/{tmp_id}-output.jsonl", "r") as f:
|
||||
for _ansidx, line in enumerate(f.readlines()):
|
||||
output_data = json.loads(line)
|
||||
labels[query_indices[_ansidx]] = (
|
||||
output_data["retval"] or labels[query_indices[_ansidx]]
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
# The subprocess may fail to parse the input (maybe due to reaching the maximum recursion length)
|
||||
# We just return 0 for the reward.
|
||||
logger.warning(f"Failed to parse generated answers. Set 0 reward.")
|
||||
finally:
|
||||
if os.path.exists(f"/tmp/{tmp_id}-input.jsonl"):
|
||||
os.remove(f"/tmp/{tmp_id}-input.jsonl")
|
||||
if os.path.exists(f"/tmp/{tmp_id}-output.jsonl"):
|
||||
os.remove(f"/tmp/{tmp_id}-output.jsonl")
|
||||
labels = []
|
||||
for jobs in all_jobs:
|
||||
label = 0
|
||||
for job in as_completed(jobs):
|
||||
x = job.result()
|
||||
label = label or x
|
||||
labels.append(label)
|
||||
return labels
|
||||
|
||||
|
||||
|
|
|
@ -19,10 +19,9 @@ blosc
|
|||
colorama
|
||||
colorlog
|
||||
einops
|
||||
hydra-core
|
||||
hydra-core==1.4.0.dev1
|
||||
matplotlib
|
||||
numba
|
||||
omegaconf
|
||||
packaging
|
||||
pandas
|
||||
pybind11>=2.10.0
|
||||
|
@ -64,3 +63,9 @@ jupyter-book
|
|||
uvloop>=0.21.0
|
||||
uvicorn>=0.34.2
|
||||
fastapi>=0.115.12
|
||||
regex
|
||||
python_dateutil
|
||||
word2number
|
||||
Pebble
|
||||
timeout-decorator
|
||||
prettytable
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,27 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from realhf.impl.dataset.math_parser import verify_math_solution
|
||||
|
||||
|
||||
def test_verify_math_solution():
|
||||
# The generated file is too large. Only upload sampled cases to git.
|
||||
path = Path("/storage/testing/dataset/math_generated.jsonl")
|
||||
line_numbers = np.random.choice(int(1e4), 10)
|
||||
if not os.path.exists(path):
|
||||
path = Path(__file__).parent / "math_answers_sample_cases.jsonl"
|
||||
line_numbers = list(range(10))
|
||||
with open(path, "r") as f:
|
||||
for i, line in enumerate(f):
|
||||
if i not in line_numbers:
|
||||
continue
|
||||
line = json.loads(line)
|
||||
for ans, r in zip(line["generateds"], line["rewards"]):
|
||||
label = 0
|
||||
for sol in line["solutions"]:
|
||||
label = label or verify_math_solution(ans, sol)
|
||||
assert (label - 0.5) * 10 == r
|
|
@ -6,6 +6,7 @@ import signal
|
|||
import sys
|
||||
import threading
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
import psutil
|
||||
|
@ -128,9 +129,11 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
|
|||
REAL_RECOVER_RUN="0",
|
||||
REAL_SAVE_RECOVER_STATES="1",
|
||||
)
|
||||
git_path = Path(__file__).parent.parent / ".git"
|
||||
runtime_env = {
|
||||
"env_vars": env_vars,
|
||||
"working_dir": os.getcwd(),
|
||||
"excludes": [str(git_path)],
|
||||
}
|
||||
logger.info(f"Ray workers runtime env: {runtime_env}")
|
||||
ray_log_path = exp_cfg.ray_temp_path
|
||||
|
|
Loading…
Reference in New Issue