# mirror of https://github.com/inclusionAI/AReaL
# (150 lines, 4.2 KiB, Python)
import concurrent.futures
import json
import os
import signal
import subprocess
import sys
import time
import traceback
import uuid
from collections import defaultdict
from contextlib import redirect_stdout
from io import StringIO
from typing import Dict, List

from functioncall.base.utils import load_jsonl, logger
from realhf.base import logging
|
|
SINGLE_CASE_EXEC_TIMEOUT = 6
|
|
|
|
logger = logging.getLogger("function call")
|
|
|
|
|
|
def capture_stdout(code):
|
|
original_stdout = sys.stdout
|
|
fake_stdout = StringIO()
|
|
|
|
try:
|
|
sys.stdout = fake_stdout
|
|
exec(code, {"__builtins__": __builtins__})
|
|
except Exception as e:
|
|
return f"error: {str(e)}, traceback: {traceback.format_exc()}"
|
|
finally:
|
|
sys.stdout = original_stdout
|
|
return fake_stdout.getvalue()
|
|
|
|
|
|
def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
|
|
|
|
tmp_id = str(uuid.uuid4())
|
|
input_data = {
|
|
"sample": problem,
|
|
"test": generation,
|
|
"debug": debug,
|
|
"timeout": timeout,
|
|
}
|
|
with open(f"/tmp/{tmp_id}-input.json", "w") as temp_file:
|
|
json.dump(input_data, temp_file)
|
|
start_time = time.time()
|
|
|
|
venv_python = "python3"
|
|
pro = subprocess.Popen(
|
|
" ".join(
|
|
[
|
|
venv_python,
|
|
"functioncall/code/function/testing_util.py",
|
|
"--tmp_id",
|
|
tmp_id,
|
|
]
|
|
),
|
|
shell=True,
|
|
preexec_fn=os.setsid,
|
|
stdout=subprocess.DEVNULL,
|
|
stderr=subprocess.DEVNULL,
|
|
)
|
|
try:
|
|
pro.wait(200)
|
|
except Exception as e:
|
|
pass
|
|
try:
|
|
os.killpg(os.getpgid(pro.pid), signal.SIGTERM)
|
|
except ProcessLookupError:
|
|
pass
|
|
|
|
result = {"result": [False], "info": {}}
|
|
try:
|
|
with open(f"/tmp/{tmp_id}-output.json", "r") as f:
|
|
result = json.load(f)
|
|
except FileNotFoundError as e:
|
|
logger.warning(
|
|
f"{problem['query_id']}: Failed to parse generated answers. FileNotFoundError. Set 0 reward."
|
|
)
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"{problem['query_id']}: Failed to parse generated answers. {e}. Set 0 reward."
|
|
)
|
|
finally:
|
|
if os.path.exists(f"/tmp/{tmp_id}-input.json"):
|
|
os.remove(f"/tmp/{tmp_id}-input.json")
|
|
if os.path.exists(f"/tmp/{tmp_id}-output.json"):
|
|
os.remove(f"/tmp/{tmp_id}-output.json")
|
|
|
|
execution_time = time.time() - start_time
|
|
logger.info(
|
|
f'[call_verify] query_id: {problem["query_id"]}, start_time: {str(start_time)}, Time elapsed: {execution_time * 1000:.0f} ms'
|
|
)
|
|
return result["result"], result["info"]
|
|
|
|
|
|
def code_verify(id2info, generateds, query_ids, debug=False):
|
|
assert len(generateds) == len(query_ids)
|
|
problems = [id2info[qid] for qid in query_ids]
|
|
|
|
final_results = []
|
|
|
|
infer_args = []
|
|
for query_id, generated, problem in zip(query_ids, generateds, problems):
|
|
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))
|
|
|
|
run_results = []
|
|
num_process = max(1, os.cpu_count() // 8)
|
|
with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
|
|
run_results = executor.map(call_verify, *zip(*infer_args))
|
|
|
|
for run_result in run_results:
|
|
curr_res, metadata = run_result
|
|
if any(x != True for x in curr_res):
|
|
final_results.append(0)
|
|
else:
|
|
final_results.append(1)
|
|
|
|
return final_results
|
|
|
|
|
|
if __name__ == "__main__":
|
|
data_list = load_jsonl("functioncall/test/test_success_dataset.jsonl")
|
|
id2info = defaultdict(dict)
|
|
for item in data_list:
|
|
id2info[item["query_id"]] = item
|
|
|
|
def create_test_params(count=10):
|
|
query_ids = []
|
|
generateds = []
|
|
cnt = 0
|
|
|
|
for d in data_list:
|
|
if cnt >= count:
|
|
break
|
|
if not d["solutions"] or d["query_id"] not in id2info:
|
|
continue
|
|
query_ids.append(d["query_id"])
|
|
generateds.extend(d["solutions"])
|
|
cnt += 1
|
|
|
|
return generateds, query_ids
|
|
|
|
generateds, query_ids = create_test_params(100)
|
|
scale = 1
|
|
print(f"generateds:, query_ids:{query_ids}")
|
|
result = code_verify(id2info, generateds * scale, query_ids * scale)
|
|
print(result)
|