# mirror of https://github.com/inclusionAI/AReaL
import argparse
import itertools
import json
import os
import random
import time
from datetime import datetime
from parser import *

import numpy as np
import ray
import torch
from data_loader import load_data
from evaluate import evaluate
from model_utils import generate_completions, load_hf_lm_and_tokenizer
from python_executor import PythonExecutor
from rm_maj_eval import group_pred
from tqdm import tqdm
from trajectory import *
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import construct_prompt, load_jsonl, save_jsonl, set_seed
from vllm import LLM, SamplingParams


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_names", default="gsm8k,math", type=str)
    parser.add_argument("--data_dir", default="./data", type=str)
    parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
    parser.add_argument("--output_dir", default="./output", type=str)
    parser.add_argument("--prompt_type", default="tool-integrated", type=str)
    parser.add_argument("--split", default="test", type=str)
    parser.add_argument("--num_test_sample", default=-1, type=int)  # -1 for full data
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--start", default=0, type=int)
    parser.add_argument("--end", default=-1, type=int)
    parser.add_argument("--temperature", default=0, type=float)
    parser.add_argument("--n_sampling", default=1, type=int)
    parser.add_argument("--top_p", default=1, type=float)
    parser.add_argument("--top_k", default=-1, type=int)
    parser.add_argument("--max_tokens_per_call", default=4096, type=int)
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--use_vllm", action="store_true")
    parser.add_argument("--save_outputs", action="store_true")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--use_safetensors", action="store_true")
    parser.add_argument("--num_shots", type=int, default=0)
    parser.add_argument(
        "--apply_chat_template",
        action="store_true",
        help="Apply chat template to prompt.",
    )
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    parser.add_argument(
        "--adapt_few_shot",
        action="store_true",
        help="Few shot for multiple-choice questions, zero shot for others.",
    )
    args = parser.parse_args()
    args.top_p = (
        1 if args.temperature == 0 else args.top_p
    )  # top_p must be 1 when using greedy sampling (vllm)
    args.top_k = -1 if args.temperature == 0 else args.top_k

    available_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    args.data_parallel_size = len(available_gpus) // args.tensor_parallel_size
    return args
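# Note on parallelism sizing (derived from parse_args above): data_parallel_size is
# the number of GPUs listed in CUDA_VISIBLE_DEVICES divided by --tensor_parallel_size,
# so the script assumes that environment variable is set. Illustrative example: with
# CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" and --tensor_parallel_size 2, each vLLM
# replica spans two GPUs and data_parallel_size = 8 // 2 = 4.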
def eval_maj_k_metrics(data_list, k=8):
    print(f"evaluating maj@{k}")

    count, right_count = 0, 0
    for sample in data_list:
        assert len(sample["score"]) >= k, sample
        groups, majority_pred = group_pred(
            sample["pred"][:k], strip=False, use_symbol=False
        )
        idx = groups[majority_pred][0]
        right_count += sample["score"][idx]
        count += 1

    task_acc = right_count / count * 100
    print(f"maj@{k}: {task_acc:.1f}")
    return task_acc
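# maj@k takes the first k predictions per sample, groups them into answer classes via
# group_pred (which appears to return a mapping from each canonical answer to the
# indices that produced it, plus the majority answer), and scores the sample with the
# score of the majority class's first occurrence. Illustrative example:
# pred = ["4", "5", "4"], score = [1, 0, 1] -> majority answer "4" -> counted correct.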
def pass_at_k(data_list, k=8):
    print(f"evaluating pass@{k}")

    count, right_count = 0, 0
    for sample in data_list:
        assert len(sample["score"]) >= k, sample
        # assert sum(sample['score']) in [0, k]
        for x in sample["score"][:k]:
            if x:
                right_count += 1
                break
        count += 1

    task_acc = right_count / count * 100
    print(f"pass@{k}: {task_acc:.1f}")
    return task_acc


def pass_at_1(data_list, k=1):
    assert k == 1
    print(f"evaluating pass@{k}")

    count, right_count = 0, 0
    for sample in data_list:
        assert len(sample["score"]) >= k, sample
        for x in sample["score"]:
            right_count += x
            count += 1

    print("pass_at_1", right_count, count)
    task_acc = right_count / count * 100
    print(f"pass@{k}: {task_acc:.1f}")
    return task_acc
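# pass_at_k above counts a sample as solved if any of its first k generations scored
# correct, while pass_at_1 averages correctness over every generation of every sample
# (i.e. mean per-generation accuracy), not just over the first generation.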
def pass_at_k_v2(data_list, k=8):

    def cur_pass_k(n, c, k):
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    # count, right_count = 0, 0
    pass_at_ks = []
    for sample in data_list:
        assert len(sample["score"]) >= k, sample
        correct = sum(sample["score"])
        pass_at_ks.append(cur_pass_k(len(sample["score"]), correct, k))

    return np.mean(pass_at_ks) * 100
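# cur_pass_k above is the standard unbiased pass@k estimator,
#     pass@k = 1 - C(n - c, k) / C(n, k),
# evaluated in the numerically stable product form
#     pass@k = 1 - prod_{i = n-c+1}^{n} (1 - k / i),
# where n is the number of generations per sample and c the number that scored correct.
# Worked example (illustrative): n = 8, c = 2, k = 4 gives
#     1 - (1 - 4/7) * (1 - 4/8) = 1 - 15/70 ≈ 0.786,
# the probability that a uniformly chosen size-4 subset contains a correct generation.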
def generate_in_parallel(requests, model_args, sampling_params, data_parallel_size):
    @ray.remote
    def run_inference_one_model(
        model_args: dict, sampling_params, requests, cuda_visible_devices
    ):
        os.environ["VLLM_LOGGING_LEVEL"] = "INFO"
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            [str(x) for x in cuda_visible_devices]
        )
        # print("OS.ENVIRON", json.dumps({x: os.environ[x] for x in sorted(dict(os.environ))}))
        llm = LLM(**model_args)
        return llm.generate(requests, sampling_params=sampling_params)

    # print("OUT_OS_ENVIRON", json.dumps({x: os.environ[x] for x in sorted(dict(os.environ))}))
    all_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    requests = [list(x) for x in distribute(data_parallel_size, requests)]
    inputs = (
        (model_args, sampling_params, req, cuda_visible_devices)
        for req, cuda_visible_devices in zip(
            requests, np.array_split(all_cuda_visible_devices, data_parallel_size)
        )
    )
    object_refs = [run_inference_one_model.remote(*x) for x in inputs]
    results = ray.get(object_refs)
    ray.shutdown()
    return undistribute(results)
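# Sketch of the data-parallel flow above (assuming CUDA_VISIBLE_DEVICES lists all
# usable GPUs): prompts are split round-robin by distribute(), each shard is handed to
# a Ray task pinned to its own GPU slice via CUDA_VISIBLE_DEVICES and running an
# independent vLLM engine built from the model_args dict, and undistribute()
# re-interleaves the per-shard result lists back into the original request order.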
# from more_itertools import distribute
from itertools import islice, tee


def distribute(n, iterable):
    """Distribute the items from *iterable* among *n* smaller iterables.

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]

    If the length of *iterable* is smaller than *n*, then the last returned
    iterables will be empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]

    This function uses :func:`itertools.tee` and may require significant
    storage.

    If you need the order items in the smaller iterables to match the
    original iterable, see :func:`divide`.
    """
    if n < 1:
        raise ValueError("n must be at least 1")

    children = tee(iterable, n)
    return [islice(it, index, None, n) for index, it in enumerate(children)]


def undistribute(iterable):
    """
    Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute .

    Re-interleaves results that have been split using more_itertools.distribute:
    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]
    >>> undistribute([group_1, group_2])
    [1, 2, 3, 4, 5, 6]

    Handles non-uniform component lengths:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]
    >>> undistribute(children)
    [1, 2, 3, 4, 5, 6, 7]

    Also handles when some iterables are empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]
    >>> undistribute(children)
    [1, 2, 3]

    """

    return [
        x
        for x in itertools.chain.from_iterable(
            itertools.zip_longest(*[list(x) for x in iterable])
        )
        if x is not None
    ]
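# Caveat for undistribute above: it pads with zip_longest and then drops the None
# fill values, so it assumes no real item is None. That holds in this script because
# the shards contain vLLM RequestOutput objects, never None.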
def prepare_data(data_name, args):
    examples = load_data(data_name, args.split, args.data_dir)

    # sample `num_test_sample` from dataset
    if args.num_test_sample > 0:
        # examples = random.sample(examples, min(args.num_test_sample, len(examples)))
        examples = examples[: args.num_test_sample]

    # shuffle (currently disabled: passing --shuffle raises immediately)
    if args.shuffle:
        raise RuntimeError
        random.seed(datetime.now().timestamp())
        random.shuffle(examples)

    # select start and end
    examples = examples[args.start : len(examples) if args.end == -1 else args.end]

    # get out_file name
    dt_string = datetime.now().strftime("%m-%d_%H-%M")
    model_name = "/".join(args.model_name_or_path.split("/")[-2:])
    out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature:.1f}_topp{args.top_p:.2f}_topk{args.top_k}"
    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        output_dir = f"outputs/{output_dir}"

    eval_dir = f"math_eval_{args.max_tokens_per_call}"

    out_file = f"{output_dir}/{eval_dir}/{data_name}/{out_file_prefix}_s{args.start}_e{args.end}_n{args.n_sampling}.jsonl"
    os.makedirs(f"{output_dir}/{eval_dir}/{data_name}", exist_ok=True)

    # load all processed samples
    processed_samples = []
    if not args.overwrite:
        processed_files = [
            f
            for f in os.listdir(f"{output_dir}/{eval_dir}/{data_name}/")
            # if f.endswith(".jsonl") and f.startswith(out_file_prefix)
            if f == os.path.basename(out_file)
        ]
        for f in processed_files:
            processed_samples.extend(
                list(load_jsonl(f"{output_dir}/{eval_dir}/{data_name}/{f}"))
            )

    # deduplicate
    processed_samples = {sample["idx"]: sample for sample in processed_samples}
    processed_idxs = list(processed_samples.keys())
    processed_samples = list(processed_samples.values())
    examples = [example for example in examples if example["idx"] not in processed_idxs]
    return examples, processed_samples, out_file
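# Resume behaviour (as implemented above): unless --overwrite is set, rows already
# present in out_file (matched by exact file name) are reloaded, deduplicated by
# "idx", and those idxs are removed from the pending examples, so re-running the same
# configuration picks up where it left off. With the default arguments the path
# resolves to something like
#   ./output/math_eval_4096/gsm8k/test_tool-integrated_-1_seed0_t0.0_topp1.00_topk-1_s0_e-1_n1.jsonl
# (illustrative; the exact name depends on the flags used).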
def setup(args):
    # load model
    available_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    if args.use_vllm:
        # breakpoint()
        if args.data_parallel_size <= 1:
            llm = LLM(
                model=args.model_name_or_path,
                tensor_parallel_size=args.tensor_parallel_size,
                # distributed_executor_backend="ray",
                trust_remote_code=True,
                enforce_eager=True,
                # dtype="float16",
                disable_custom_all_reduce=True,
                disable_sliding_window=True,
                max_model_len=32768,
                enable_chunked_prefill=False,
                swap_space=32,
            )
        else:
            print(
                f"TP = {args.tensor_parallel_size}\n",
                f"DP = {args.data_parallel_size}",
            )
            llm = dict(
                model=args.model_name_or_path,
                tensor_parallel_size=args.tensor_parallel_size,
                # distributed_executor_backend="ray",
                trust_remote_code=True,
                enforce_eager=True,
                # dtype="float16",
                disable_custom_all_reduce=True,
                disable_sliding_window=True,
                max_model_len=32768,
                enable_chunked_prefill=False,
                swap_space=32,
            )
        tokenizer = None
        if args.apply_chat_template:
            tokenizer = AutoTokenizer.from_pretrained(
                args.model_name_or_path, trust_remote_code=True
            )
    else:
        llm, tokenizer = load_hf_lm_and_tokenizer(
            model_name_or_path=args.model_name_or_path,
            load_in_half=True,
            use_fast_tokenizer=True,
            use_safetensors=args.use_safetensors,
        )

    # infer & eval
    data_list = args.data_names.split(",")
    results = []
    for data_name in data_list:
        results.append(main(llm, tokenizer, data_name, args))

    # add "avg" result to data_list and results
    data_list.append("avg")
    results.append(
        {
            "acc": sum([result["acc"] for result in results]) / len(results),
        }
    )

    # print all results
    pad = max([len(data_name) for data_name in data_list])
    print("\t".join(data_name.ljust(pad, " ") for data_name in data_list))
    print("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for result in results]))
def is_multi_choice(answer):
    for c in answer:
        if c not in ["A", "B", "C", "D", "E"]:
            return False
    return True
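# is_multi_choice returns True when every character of `answer` is one of A-E, so
# multi-letter answers such as "AC" qualify; note that an empty string also returns
# True vacuously.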
def main(llm, tokenizer, data_name, args):
    examples, processed_samples, out_file = prepare_data(data_name, args)
    print("=" * 50)
    print("data:", data_name, ", remaining samples:", len(examples))
    if len(examples) > 0:
        print(examples[0])

    # init python executor
    if "pal" in args.prompt_type:
        executor = PythonExecutor(get_answer_expr="solution()")
    else:
        executor = PythonExecutor(get_answer_from_stdout=True)

    samples = []
    for example in tqdm(examples, total=len(examples)):
        idx = example["idx"]

        # parse question and answer
        example["question"] = parse_question(example, data_name)
        if example["question"] == "":
            continue
        gt_cot, gt_ans = parse_ground_truth(example, data_name)
        example["gt_ans"] = gt_ans
        full_prompt = construct_prompt(example, data_name, args)

        if idx == args.start:
            print(full_prompt)

        sample = {
            "idx": idx,
            "question": example["question"],
            "gt_cot": gt_cot,
            "gt": gt_ans,
            "prompt": full_prompt,
        }

        # add remaining fields
        for key in [
            "level",
            "type",
            "unit",
            "solution_type",
            "choices",
            "solution",
            "ques_type",
            "ans_type",
            "answer_type",
            "dataset",
            "subfield",
            "filed",
            "theorem",
            "answer",
        ]:
            if key in example:
                sample[key] = example[key]
        samples.append(sample)

    # repeat n times
    input_prompts = [
        sample["prompt"] for sample in samples for _ in range(args.n_sampling)
    ]
    if args.apply_chat_template:
        input_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt.strip()}],
                tokenize=False,
                add_generation_prompt=True,
            )
            for prompt in input_prompts
        ]
    remain_prompts = input_prompts
    remain_prompts = [(i, prompt) for i, prompt in enumerate(remain_prompts)]
    end_prompts = []

    # max_func_call = 1 if args.prompt_type in ["cot", "pal"] else 4
    max_func_call = 1 if args.prompt_type in ["cot", "pal"] or args.use_vllm else 4

    stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]

    if args.prompt_type in ["cot"]:
        stop_words.append("\n\nQuestion:")
    if args.prompt_type in ["pal", "tool-integrated", "jiuzhang_tora"]:
        stop_words.extend(["\n\n---", "```output"])
    elif args.prompt_type in ["wizard_zs", "platypus_fs"]:
        stop_words.extend(["Instruction", "Response"])
    elif "jiuzhang" in args.prompt_type:
        stop_words.append("\n\n## Question")
    elif "numina" in args.prompt_type:
        stop_words.append("\n### Problem")
    elif "pure" in args.prompt_type:
        stop_words.append("\n\n\n")
    # start inference
    # measure time use
    start_time = time.time()
    for epoch in range(max_func_call):
        print("-" * 20, "Epoch", epoch)
        current_prompts = remain_prompts
        if len(current_prompts) == 0:
            break

        # get all outputs
        prompts = [item[1] for item in current_prompts]
        if args.use_vllm:
            sampling_params = SamplingParams(
                temperature=args.temperature,
                seed=args.seed,
                top_p=args.top_p,
                top_k=args.top_k,
                max_tokens=args.max_tokens_per_call,
                n=args.n_sampling,
                stop=stop_words,
                stop_token_ids=(
                    [151645, 151643]
                    if "qwen2" in args.model_name_or_path.lower()
                    else None
                ),
            )
            if args.data_parallel_size <= 1:
                outputs = llm.generate(prompts[:: args.n_sampling], sampling_params)
            else:
                outputs = generate_in_parallel(
                    prompts[:: args.n_sampling],
                    llm,
                    sampling_params,
                    args.data_parallel_size,
                )

            outputs = sorted(
                outputs, key=lambda x: int(x.request_id)
            )  # sort outputs by request_id
            outputs = [x.text for output in outputs for x in output.outputs]
        else:
            outputs = generate_completions(
                model=llm,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=args.max_tokens_per_call,
                batch_size=16,
                stop_id_sequences=stop_words,
            )

        assert len(outputs) == len(current_prompts)

        # process all outputs
        remain_prompts = []
        remain_codes = []
        for (i, query), output in zip(current_prompts, outputs):
            output = output.rstrip()
            query += output
            if args.prompt_type == "pal":
                remain_prompts.append((i, query))
                if "```python" in output:
                    output = extract_program(query)
                remain_codes.append(output)
            elif args.prompt_type == "cot":
                end_prompts.append((i, query))
            elif "boxed" not in output and output.endswith("```"):
                program = extract_program(query)
                remain_prompts.append((i, query))
                remain_codes.append(program)
            else:
                end_prompts.append((i, query))

        # execute the remain prompts
        remain_results = executor.batch_apply(remain_codes)
        for k in range(len(remain_prompts)):
            i, query = remain_prompts[k]
            res, report = remain_results[k]
            exec_result = res if res else report
            if "pal" in args.prompt_type:
                exec_result = "\\boxed{" + exec_result + "}"
            exec_result = f"\n```output\n{exec_result}\n```\n"
            query += exec_result
            # not end
            if epoch == max_func_call - 1:
                query += "\nReach max function call limit."
            remain_prompts[k] = (i, query)

    # unsolved samples
    print("Unsolved samples:", len(remain_prompts))
    end_prompts.extend(remain_prompts)
    # sort by idx
    end_prompts = sorted(end_prompts, key=lambda x: x[0])

    # remove input_prompt from end_prompt
    codes = []
    assert len(input_prompts) == len(end_prompts)
    for i in range(len(input_prompts)):
        _, end_prompt = end_prompts[i]
        code = end_prompt.split(input_prompts[i])[-1].strip()
        for stop_word in stop_words:
            if stop_word in code:
                code = code.split(stop_word)[0].strip()
        codes.append(code)

    # extract preds
    results = [
        run_execute(executor, code, args.prompt_type, data_name) for code in codes
    ]
    time_use = time.time() - start_time
    print("time_use", time_use)
    # put results back to examples
    all_samples = []
    for i, sample in enumerate(samples):
        code = codes[i * args.n_sampling : (i + 1) * args.n_sampling]
        result = results[i * args.n_sampling : (i + 1) * args.n_sampling]
        preds = [item[0] for item in result]
        reports = [item[1] for item in result]
        for j in range(len(preds)):
            if sample["gt"] in ["A", "B", "C", "D", "E"] and preds[j] not in [
                "A",
                "B",
                "C",
                "D",
                "E",
            ]:
                preds[j] = choice_answer_clean(code[j])
            elif is_multi_choice(sample["gt"]) and not is_multi_choice(preds[j]):
                # remove any non-choice char
                preds[j] = "".join(
                    [c for c in preds[j] if c in ["A", "B", "C", "D", "E"]]
                )

        sample.pop("prompt")
        sample.update({"code": code, "pred": preds, "report": reports})
        all_samples.append(sample)

    # add processed samples
    all_samples.extend(processed_samples)
    all_samples, result_json = evaluate(
        samples=all_samples,
        data_name=data_name,
        prompt_type=args.prompt_type,
        execute=True,
    )

    if args.n_sampling > 1:
        result_json[f"maj@{args.n_sampling}"] = eval_maj_k_metrics(
            all_samples, k=args.n_sampling
        )
        result_json[f"pass@{args.n_sampling}"] = pass_at_k_v2(
            all_samples, k=args.n_sampling
        )
        if args.n_sampling > 16:
            result_json["pass@16"] = pass_at_k_v2(all_samples, k=16)
        if args.n_sampling > 8:
            result_json["pass@8"] = pass_at_k_v2(all_samples, k=8)
    result_json["pass@1"] = pass_at_k_v2(all_samples, k=1)
    result_json["acc"] = result_json["pass@1"]

    # save outputs
    if len(processed_samples) < len(all_samples) and args.save_outputs:
        save_jsonl(all_samples, out_file)

    result_json["time_use_in_second"] = time_use
    result_json["time_use_in_minite"] = (
        f"{int(time_use // 60)}:{int(time_use % 60):02d}"
    )

    with open(
        out_file.replace(".jsonl", f"_{args.prompt_type}_metrics.json"), "w"
    ) as f:
        json.dump(result_json, f, indent=4)
    return result_json
if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)
    setup(args)