mirror of https://github.com/inclusionAI/AReaL
768 lines
22 KiB
Python
768 lines
22 KiB
Python
# Copyright 2025 Ant Group Inc.
|
|
|
|
import random
|
|
import re
|
|
from typing import Any, Dict, Iterable, List, TypeVar, Union
|
|
|
|
import regex
|
|
import sympy
|
|
from latex2sympy2 import latex2sympy
|
|
from word2number import w2n
|
|
|
|
# from utils import *
|
|
|
|
|
|
def _fix_fracs(string):
|
|
substrs = string.split("\\frac")
|
|
new_str = substrs[0]
|
|
if len(substrs) > 1:
|
|
substrs = substrs[1:]
|
|
for substr in substrs:
|
|
new_str += "\\frac"
|
|
if len(substr) > 0 and substr[0] == "{":
|
|
new_str += substr
|
|
else:
|
|
try:
|
|
assert len(substr) >= 2
|
|
except:
|
|
return string
|
|
a = substr[0]
|
|
b = substr[1]
|
|
if b != "{":
|
|
if len(substr) > 2:
|
|
post_substr = substr[2:]
|
|
new_str += "{" + a + "}{" + b + "}" + post_substr
|
|
else:
|
|
new_str += "{" + a + "}{" + b + "}"
|
|
else:
|
|
if len(substr) > 2:
|
|
post_substr = substr[2:]
|
|
new_str += "{" + a + "}" + b + post_substr
|
|
else:
|
|
new_str += "{" + a + "}" + b
|
|
string = new_str
|
|
return string
|
|
|
|
|
|
def _fix_a_slash_b(string):
|
|
if len(string.split("/")) != 2:
|
|
return string
|
|
a = string.split("/")[0]
|
|
b = string.split("/")[1]
|
|
try:
|
|
if "sqrt" not in a:
|
|
a = int(a)
|
|
if "sqrt" not in b:
|
|
b = int(b)
|
|
assert string == "{}/{}".format(a, b)
|
|
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
|
return new_string
|
|
except:
|
|
return string
|
|
|
|
|
|
def _fix_sqrt(string):
|
|
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
|
|
return _string
|
|
|
|
|
|
def convert_word_number(text: str) -> str:
    """Convert an English number phrase to digits via ``w2n.word_to_num``.

    Returns:
        ``str(w2n.word_to_num(text))`` on success, otherwise *text* unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Best-effort: any parse failure (e.g. text is not a number phrase)
        # keeps the input. Catching Exception instead of using a bare
        # ``except`` lets KeyboardInterrupt / SystemExit propagate.
        return text
|
|
|
|
|
|
# Unit and filler words (mainly from MathQA) that ``strip_string`` removes
# from answers when unit stripping is enabled; they are tried in list order.
# A few entries appear twice ("kmph", "sec") and "m sqaure" keeps its
# original spelling. An "s"-suffixed plural of every entry is appended after
# the base list.
unit_texts = [
    "east", "degree", "mph", "kmph", "ft", "m sqaure", " m east", "sq m",
    "deg", "mile", "q .", "monkey", "prime", "ratio", "profit of rs", "rd",
    "o", "gm", "p . m", "lb", "tile", "per", "dm", "lt", "gain", "ab", "way",
    "west", "a .", "b .", "c .", "d .", "e .", "f .", "g .", "h .", "t", "a",
    "h", "no change", "men", "soldier", "pie", "bc", "excess", "st", "inches",
    "noon", "percent", "by", "gal", "kmh", "c", "acre", "rise", "a . m", "th",
    "π r 2", "sq", "mark", "l", "toy", "coin", "sq . m", "gallon", "° f",
    "profit", "minw", "yr", "women", "feet", "am", "pm", "hr", "cu cm",
    "square", "v â € ™", "are", "rupee", "rounds", "cubic", "cc", "mtr", "s",
    "ohm", "number", "kmph", "day", "hour", "minute", "min", "second", "man",
    "woman", "sec", "cube", "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more",
    "sec", "unit", "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month", "km",
    "m", "cm", "mm", "apple", "liter", "loss", "yard", "pure", "year",
    "increase", "decrease", "d", "less", "Surface", "litre", "pi sq m", "s .",
    "metre", "meter", "inch",
]

unit_texts += [unit + "s" for unit in unit_texts]
|
|
|
|
|
|
def strip_string(string, skip_unit=False):
    """Normalize an answer string so that equivalent answers compare equal.

    Applies a fixed sequence of textual rewrites, in this order:
    - drops line breaks, a trailing ".", and ``\\!``;
    - maps ``array``/``bmatrix`` environments to ``pmatrix`` and
      ``tfrac``/``dfrac`` to ``frac``;
    - removes ``\\left``/``\\right`` and a trailing ``\\text{...}``;
    - unless *skip_unit* is set, removes the words listed in ``unit_texts``;
    - drops degree marks, dollar signs and percent signs;
    - converts number words to digits and strips "x=", "y\\in"-style
      prefixes;
    - normalizes infinity and trailing ".000" decimals;
    - removes all spaces;
    - braces shorthand ``\\sqrt``/``\\frac`` arguments and rewrites integer
      ``a/b`` as ``\\frac{a}{b}``.

    Args:
        string: The answer; non-``str`` values are converted with ``str()``.
        skip_unit: If True, do not remove the ``unit_texts`` words.

    Returns:
        The normalized string (possibly empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # matrix: normalize array / bmatrix environments to pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac; shorten inequality macros
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (
        string.replace("\\neq", "\\ne")
        .replace("\\leq", "\\le")
        .replace("\\geq", "\\ge")
    )

    # remove \left and \right, and unescape literal braces
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove a trailing \text{...} (e.g. a unit such as "miles") unless that
    # would leave nothing.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    if not skip_unit:
        # Remove unit words. Two passes over the list (presumably so a unit
        # exposed by an earlier removal is also caught).
        for _ in range(2):
            for unit_text in unit_texts:
                # The unit must be delimited by the string boundary or a
                # non-word character on both sides. re.escape makes entries
                # such as "rs ." or "p . m" match literally; unescaped, their
                # "." would match any character.
                _string = re.sub(
                    r"(^|\W)" + re.escape(unit_text) + r"($|\W)", r"\1\2", string
                )
                if _string != "":
                    string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs and \( \) delimiters
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit (unchanged if not a number phrase)
    string = convert_word_number(string)

    # replace "\\text{...}" with "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in [
        "x=", "y=", "z=",
        "x\\in", "y\\in", "z\\in",
        "x\\to", "y\\to", "z\\to",
    ]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage ("\\%" first so no stray backslash is left behind)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{."; a leading "."
    # gets its "0" further below.
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # NOTE(review): bracketed values such as "{5}", "(a)" or "[x]" are NOT
    # unwrapped here. Upstream had an unwrapping branch guarded by
    # ``string.isalnum()``, which is always False for a string that starts
    # with a bracket, so it never ran; it is omitted to keep results
    # unchanged. Revisit if unwrapping is actually wanted.

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Drop an explicit plus sign on infinity: "(0,+\infty)" -> "(0,\infty)".
    string = string.replace("+\\infty", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # NOTE(review): quote characters are not removed. Upstream called
    # string.replace for ' and " but discarded the results (no-ops); they are
    # omitted to keep results unchanged.

    # i, j: replace j by i when i is absent (presumably the imaginary unit)
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if not string:
        return string
    if string.startswith("."):
        string = "0" + string

    # get rid of e.g. "k=" or "q=" at the beginning (left side <= 2 chars)
    parts = string.split("=")
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with
    # \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in
    # case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
|
|
|
|
|
|
def extract_multi_choice_answer(pred_str):
    """Return the option letter (A-E) that follows "answer is" in a completion.

    Text from the first "Problem:" onward is ignored, and the exact phrase
    "choice is" is treated like "answer is". The search runs on the
    lower-cased text and accepts an optional parenthesis around the letter,
    e.g. "answer is (c)".

    Returns:
        The letter in upper case, or the string "placeholder" when no answer
        phrase is found.
    """
    # TODO: SFT models
    head = pred_str.split("Problem:", 1)[0]
    haystack = head.replace("choice is", "answer is").lower()
    found = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", haystack)
    if found is None:
        return "placeholder"
    return found.group("ans").upper()
|
|
|
|
|
|
# Phrases that introduce the final choice in few-shot style completions.
direct_answer_trigger_for_fewshot = ("choice is", "answer is")


def choice_answer_clean(pred: str):
    """Extract a multiple-choice letter (A-E) from a model completion.

    If any trigger phrase occurs more than once, the completion is assumed to
    contain extra in-context examples and only its first "\\n\\n"-separated
    chunk is kept. The text after the last trigger (or the whole text if
    there is none) is searched for standalone letters A-E: the first is
    returned after a trigger, the last otherwise. If no letter is found, the
    cleaned text itself is returned.
    """
    text = pred.strip("\n")

    # Determine if this is ICL, if so, use \n\n to split the first chunk.
    if any(text.count(t) > 1 for t in direct_answer_trigger_for_fewshot):
        text = text.split("\n\n")[0]

    # Split the trigger to find the answer.
    chunks = re.split("|".join(direct_answer_trigger_for_fewshot), text)
    answer_flag = len(chunks) > 1
    if answer_flag:
        text = chunks[-1]

    text = text.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

    letters = re.findall(r"\b(A|B|C|D|E)\b", text.upper())
    if not letters:
        choice = text.strip().strip(".")
    elif answer_flag:
        choice = letters[0]
    else:
        choice = letters[-1]

    # Remove the period at the end, again!
    return choice.rstrip(".").rstrip("/")
|
|
|
|
|
|
def find_box(pred_str: str):
    """Return the contents of the last ``boxed`` answer in *pred_str*.

    For ``boxed{...}`` the brace-balanced contents are returned, with nested
    braces kept; if the braces never close, everything after the opening
    brace is returned. For an unbraced ``boxed 5$`` form, the text up to the
    next "$" is returned, stripped. Returns "" if nothing follows the last
    "boxed".
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if not tail.startswith("{"):
        return tail.split("$")[0].strip()

    depth = 1
    kept = []
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                break
        kept.append(ch)
    return "".join(kept)
|
|
|
|
|
|
def clean_units(pred_str: str):
    """Clean the units in the number.

    ``\\pi`` is first normalized to "π", and π is then replaced by the
    literal 3.14. A π right after a digit gets an explicit multiplication
    ("3π" -> "3*3.14"). Finally "%" becomes "/100", and "$", "¥", "°C",
    " C" and "°" are removed, in that order.
    """
    # (pattern, replacement) pairs, applied in order after "\pi" -> "π".
    pi_rules = (
        # π not preceded by a digit or "}" -> 3.14
        (r"(?<![\d}])\\?π", "3.14"),
        # digit immediately before π -> explicit multiplication
        (r"(\d)(\\?π)", r"\1*3.14"),
        # "{π}" and "*π" forms
        (r"\{(\\?π)\}", "3.14"),
        (r"\*(\\?π)", "*3.14"),
    )
    text = pred_str.replace("\\pi", "π")
    for pattern, replacement in pi_rules:
        text = re.sub(pattern, replacement, text)
    # NOTE(review): a π directly after "}" (e.g. "\frac{1}{2}π") matches none
    # of the rules above and is left as-is — confirm whether that is intended.

    unit_rules = (
        ("%", "/100"),
        ("$", ""),
        ("¥", ""),
        ("°C", ""),
        (" C", ""),
        ("°", ""),
    )
    for old, new in unit_rules:
        text = text.replace(old, new)
    return text
|
|
|
|
|
|
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Normalize a TheoremQA-style prediction.

    Checks are case-insensitive substring tests, applied in order:
    - "yes" or "true" anywhere -> "True".
    - otherwise "no" or "false" anywhere -> "False".
    - otherwise, if an option marker "(a)".."(f)" occurs, *pred* is returned
      unchanged.
    - otherwise the last ``boxed`` content is taken (if "boxed" occurs).
      When *answer_flag* is true it is then reduced to a number. The text
      after the last "=" is unit-cleaned, parsed with ``latex2sympy`` and
      evaluated. If that raises, a leading number followed by a space and
      trailing text yields its first space-separated token; otherwise the
      last number in the text is used ("" if there is none).
    """
    lowered = pred.lower()
    # NOTE(review): plain substring tests — e.g. "not" and "another" contain
    # "no" — confirm this heuristic is intended.
    if any(option in lowered for option in ("yes", "true")):
        pred = "True"
    elif any(option in lowered for option in ("no", "false")):
        pred = "False"
    elif any(
        option in lowered for option in ("(a)", "(b)", "(c)", "(d)", "(e)", "(f)")
    ):
        pass
    else:
        # Some of the models somehow get used to boxed output from pre-training
        if "boxed" in pred:
            pred = find_box(pred)

        if answer_flag:
            # Extract the numbers out of the string
            pred = pred.split("=")[-1].strip()
            pred = clean_units(pred)
            try:
                tmp = str(latex2sympy(pred))
                # SECURITY(review): eval() executes text derived from model
                # output with this module's globals in scope. Do not run this
                # on untrusted completions outside a sandbox.
                pred = str(eval(tmp))
            except Exception:
                if re.match(r"-?[\d\.]+\s\D+$", pred) or re.match(
                    r"-?[\d\.]+\s[^\s]+$", pred
                ):
                    pred = pred.split(" ")[0]
                else:
                    # desperate search over the last number
                    preds = re.findall(r"-?\d*\.?\d+", pred)
                    pred = preds[-1] if preds else ""

    return pred
|
|
|
|
|
|
def extract_answer(pred_str, data_name, use_last_number=True):
    """Extract and normalize the final answer from a model completion.

    Multiple-choice datasets ("mmlu_stem", "sat_math", "aqua", "gaokao2023")
    are delegated to ``choice_answer_clean``. Otherwise the first matching
    rule below picks the raw answer:
    - the Minerva-style "final answer is $...$. I hope";
    - the last ``boxed`` content;
    - the text after "he answer is" or "final answer is";
    - the first paragraph after "答案是";
    - or, if *use_last_number*, the last number in the text.
    The result is cleaned of newlines and a leading ":" / trailing "." or
    "/" and passed through ``strip_string``. Unit removal is skipped for
    "carp_en" and "minerva_math".

    Returns:
        The normalized answer string ("" when nothing could be extracted).
    """
    # Remove the Cyrillic "ки" marker (U+043A U+0438) if present.
    pred_str = pred_str.replace("\u043a\u0438", "")
    if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
        # TODO check multiple choice
        return choice_answer_clean(pred_str)

    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        if pred_str.endswith("boxed"):
            # Nothing follows the last "boxed".
            return ""
        pred = find_box(pred_str)
    elif "he answer is" in pred_str:
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        # use the last number (commas removed first, e.g. "1,000")
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""

    # choice answer (sat_math / aqua already returned above; this covers
    # other "mmlu*" datasets)
    if data_name in ["sat_math", "aqua"] or "mmlu" in data_name:
        tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
        if tmp:
            pred = tmp[-1]
        else:
            pred = pred.strip().strip(".")

    # multiple line
    pred = re.sub(r"\n\s*", "", pred)
    if pred.startswith(":"):
        pred = pred[1:]
    if pred.endswith("."):
        pred = pred[:-1]
    if pred.endswith("/"):
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"])
    return pred
|
|
|
|
|
|
# Datasets whose ground-truth answers are not passed through ``strip_string``
# by ``parse_ground_truth``.
STRIP_EXCEPTIONS = ["carp_en", "minerva_math"]


def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)``: reference solution text and answer.

    If *example* has both "gt_cot" and "gt", those are returned without
    further post-processing. In that case the answer is re-extracted from
    "gt_cot" for "math", taken verbatim for ``STRIP_EXCEPTIONS`` datasets,
    and passed through ``strip_string`` otherwise.

    Otherwise fields are read using the schema selected by *data_name*.
    ``gt_cot`` is then ``str(...)``-converted and stripped, so it is the
    string "None" for datasets without a reference solution. The answer is
    normalized with ``strip_string``, except for ``STRIP_EXCEPTIONS``
    datasets, where only ``\\neq``/``\\leq``/``\\geq`` are shortened.

    Raises:
        NotImplementedError: if *data_name* is not a supported dataset.
    """
    if "gt_cot" in example and "gt" in example:
        if data_name in ["math"]:
            gt_ans = extract_answer(example["gt_cot"], data_name)
        elif data_name in STRIP_EXCEPTIONS:
            gt_ans = example["gt"]
        else:
            gt_ans = strip_string(example["gt"])
        return example["gt_cot"], gt_ans

    # parse ground truth
    if data_name in ["math", "minerva_math", "math_500"]:
        gt_cot = example["solution"]
        gt_ans = extract_answer(gt_cot, data_name)
    elif data_name == "gsm8k":
        gt_cot, gt_ans = example["answer"].split("####")
    elif data_name == "svamp":
        gt_cot, gt_ans = example["Equation"], example["Answer"]
    elif data_name == "asdiv":
        gt_cot = example["formula"]
        gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
    elif data_name == "mawps":
        gt_cot, gt_ans = None, example["target"]
    elif data_name == "tabmwp":
        gt_cot = example["solution"]
        gt_ans = example["answer"]
        if example["ans_type"] in ["integer_number", "decimal_number"]:
            if "/" in gt_ans:
                gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1])
            elif "," in gt_ans:
                gt_ans = float(gt_ans.replace(",", ""))
            elif "%" in gt_ans:
                gt_ans = float(gt_ans.split("%")[0]) / 100
            else:
                gt_ans = float(gt_ans)
    elif data_name == "carp_en":
        gt_cot, gt_ans = example["steps"], example["answer"]
    elif data_name == "mmlu_stem":
        abcd = "ABCD"
        gt_cot, gt_ans = None, abcd[example["answer"]]
    elif data_name == "sat_math":
        gt_cot, gt_ans = None, example["Answer"]
    elif data_name == "aqua":
        gt_cot, gt_ans = None, example["correct"]
    elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]:
        gt_cot, gt_ans = None, example["answer"].replace("$", "").strip()
    elif data_name == "gaokao_math_qa":
        gt_cot, gt_ans = None, example["label"]
    elif data_name in ["gaokao2024_mix", "cn_middle_school"]:
        if len(example["choice_answer"]) > 0:
            gt_cot, gt_ans = None, example["choice_answer"]
        else:
            gt_cot, gt_ans = None, example["answer"]
    elif data_name == "olympiadbench":
        gt_cot, gt_ans = None, example["final_answer"][0].strip("$")
    elif data_name in [
        "aime24",
        "amc23",
        "cmath",
        "gaokao2024_I",
        "gaokao2024_II",
        "imo2024",
    ]:
        gt_cot, gt_ans = None, example["answer"]
    elif data_name.startswith("train"):  # train_amc_aime
        gt_cot, gt_ans = None, example["final_answer"]
    else:
        raise NotImplementedError(f"`{data_name}`")

    # post process
    gt_cot = str(gt_cot).strip()
    if data_name not in STRIP_EXCEPTIONS:
        # "carp_en" is one of the STRIP_EXCEPTIONS, so unit removal is never
        # skipped on this path.
        gt_ans = strip_string(gt_ans, skip_unit=False)
    else:
        gt_ans = (
            gt_ans.replace("\\neq", "\\ne")
            .replace("\\leq", "\\le")
            .replace("\\geq", "\\ge")
        )
    return gt_cot, gt_ans
|
|
|
|
|
|
def parse_question(example, data_name):
    """Build the prompt question text for *example* from dataset *data_name*.

    Dataset-specific fields are combined into a single question string. For
    multiple-choice sets the labelled options are appended. If the ground
    truth (from ``parse_ground_truth``) is "true"/"false" or "yes"/"no"
    (case-insensitive), a " (True or False)" / " (Yes or No)" hint is
    appended. *example* is not modified.

    Returns:
        The stripped question string ("" if no known field was found).
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = (
            f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
        )
        question = f"Read the following table {title_str}and answer a question:\n"
        question += f'{example["table"]}\n{example["question"]}'
        if example["choices"]:
            question += (
                f' Please select from the following options: {example["choices"]}'
            )
    elif data_name == "carp_en":
        question = example["content"]
    elif data_name == "mmlu_stem":
        choices = example["choices"]
        assert len(choices) == 4
        # Build the labelled options as a new string so example["choices"] is
        # left untouched (repeated calls would otherwise re-label it).
        options = " ".join(
            f"({label}) {str(option).strip()}"
            for label, option in zip("ABCD", choices)
        )
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif data_name == "sat_math":
        options = example["options"].strip()
        assert "A" == options[0]
        options = "(" + options
        for ch in "BCD":
            # " B) " -> " (B) " (every occurrence; literal replacement).
            options = options.replace(f" {ch}) ", f" ({ch}) ")
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif "aqua" in data_name:
        options = example["options"]
        choice = "(" + "(".join(options)
        choice = choice.replace("(", " (").replace(")", ") ").strip()
        choice = "\nAnswer Choices: " + choice
        question = example["question"].strip() + choice
    elif data_name == "gaokao_math_qa":
        options = " ".join(
            f"({key}) {value}" for key, value in example["options"].items()
        )
        question = f"{example['question'].strip()}\n选项: {options}"
    elif data_name.startswith("train"):  # train_amc_aime:
        question = example["prompt"]
    else:
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break

    # Yes or No question
    _, gt_ans = parse_ground_truth(example, data_name)
    if isinstance(gt_ans, str):
        gt_lower = gt_ans.lower()
        if gt_lower in ["true", "false"]:
            question += " (True or False)"
        if gt_lower in ["yes", "no"]:
            question += " (Yes or No)"
    return question.strip()
|
|
|
|
|
|
# def run_execute(executor, result, prompt_type, data_name, execute=False):
|
|
# if not result or result == "error":
|
|
# return None, None
|
|
# report = None
|
|
|
|
# if "program_only" in prompt_type:
|
|
# prediction = extract_program_output(result)
|
|
# elif prompt_type in ["pot", "pal"] and execute:
|
|
# code = extract_program(result)
|
|
# prediction, report = executor.apply(code)
|
|
# else:
|
|
# prediction = extract_answer(result, data_name)
|
|
|
|
# # prediction = strip_string(prediction, skip_unit=data_name == "carp_en")
|
|
# prediction = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS)
|
|
# return prediction, report
|
|
|
|
|
|
def _test_extract_answer():
    """Ad-hoc smoke test: print a few extraction results for manual review."""
    # Raw string: in a normal literal the "\f" of "\frac" is a form feed and
    # "\c", "\[" are invalid escapes (SyntaxWarning on Python 3.12+).
    text = r"""
This is still not equal to $0$, so we must have made another mistake.

When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:

\[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\]

This is still not equal to $0$, so we must have made another mistake.

When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:

\[\frac{386}{64} 025
"""
    print(extract_answer(text, "math-oai", use_last_number=True))
    print(choice_answer_clean(r"\mathrm{(D)\}1,008,016"))

    # extract_answer returns the normalized answer string.
    print(extract_answer("The product of the inner terms is \\($ 15x $\\).", "math"))


if __name__ == "__main__":
    _test_extract_answer()
|