mirror of https://github.com/inclusionAI/AReaL
import json
import os
import random
from collections import defaultdict
from datetime import datetime

from functioncall.base.call import Language, batch_function_call, get_runtime_name
from functioncall.base.utils import construct_uid, load_jsonl, logger

# Per-testcase execution timeout (s), number of testcases per sub-request, and the
# overall timeout (s) for one batch_function_call round trip.
SINGLE_CASE_EXEC_TIMEOUT = 6
TEST_CASE_BATCH_SIZE = 1
FUNCTIONCALL_TIMEOUT = 100

def round_up_memory(memory):
    # Round a memory limit (MB) up to the next multiple of 256. Non-positive or
    # over-limit values fall back to 0, meaning "use the runtime default".
    if memory <= 0:
        return 0
    rounded = ((memory + 255) // 256) * 256
    return 0 if rounded > 1024 else rounded
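
# Illustrative behaviour only (not from the original file): memory limits round up
# to the next 256 MB step, and anything above 1024 MB falls back to 0:
#   round_up_memory(100)  -> 256
#   round_up_memory(512)  -> 512
#   round_up_memory(2048) -> 0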


def construct_testcases(
    inputs: list, outputs: list, index: tuple, remote: bool = False, is_ut: bool = False
) -> list:
    # Build the testcase payload for the slice inputs[index[0]:index[1]].
    # Unit-test style problems (is_ut) carry their checks inside the code itself,
    # so no explicit testcases are attached.
    result = []
    if is_ut:
        return result

    for i in range(*index):
        input_, output_ = inputs[i].strip(), outputs[i].strip()
        if not remote:
            result.append({"input": input_, "expectedOutput": output_})
            continue

        # Remote testcases are stored as files; resolve them against the OSS base
        # path unless they are already absolute URLs.
        oss_basepath = os.getenv("REAL_OSS_TESTCASE_PATH", "")
        if not oss_basepath:
            raise FileNotFoundError(
                "REAL_OSS_TESTCASE_PATH not set. Cannot use FAAS code reward."
            )
        input_url = (
            input_ if input_.startswith("http") else os.path.join(oss_basepath, input_)
        )
        output_url = (
            output_
            if output_.startswith("http")
            else os.path.join(oss_basepath, output_)
        )

        result.append({"input": input_url, "expectedOutput": output_url})
    return result
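
# Illustrative only (not from the original file): for a local (non-remote) problem,
#   construct_testcases(["1 2\n"], ["3\n"], (0, 1))
# returns [{"input": "1 2", "expectedOutput": "3"}]; with remote=True the same entries
# are resolved to URLs under REAL_OSS_TESTCASE_PATH instead.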


def load_problems_with_testcase_batch(
    id2info, query_ids, generateds, timeout_for_testcase, test_case_batch_size
):
    problem_list = []
    for idx, query_id in enumerate(query_ids):
        problem = id2info[query_id]
        # parse one problem
        language = problem.get("language", "PYTHON").upper()
        timeout = min(
            100, max(0.1, float(problem.get("timeout", timeout_for_testcase)) * 1.5)
        )  # clamp to [0.1, 100] s
        memory = round_up_memory(problem.get("memory", 0))
        input_output = json.loads(problem["input_output"])
        fn_name = input_output.get("fn_name", "")
        remote = input_output.get("remote", False)
        inputs = input_output.get("inputs", [])
        outputs = input_output.get("outputs", [])

        assert len(inputs) == len(
            outputs
        ), f"Inputs({len(inputs)}) and outputs({len(outputs)}) mismatch for {query_id}"

        assert (
            language in Language.__members__
        ), f"{language} is not a valid Language name"

        # A problem without explicit inputs is treated as a unit-test style problem.
        is_ut = len(inputs) == 0

        # isFastFail means the function call returns immediately as soon as any testcase fails.
        isFastFail = True
        # create batches for testcases
        case_size = 1 if is_ut else len(inputs)
        test_case_batch_size = min(max(1, test_case_batch_size), case_size)

        for batch_idx in range(0, case_size, test_case_batch_size):
            end_idx = min(case_size, batch_idx + test_case_batch_size)
            testcases = construct_testcases(
                inputs, outputs, (batch_idx, end_idx), remote, is_ut
            )

            sub_problem = {
                "uid": construct_uid(query_id, batch_idx, end_idx),
                "language": language,
                "runtime": get_runtime_name("", language),
                "code": generateds[idx],
                "entryFunction": fn_name,
                "isFastFail": isFastFail,
                "isRemote": remote,
                "testcases": testcases,
                "timeout": timeout,
                "memory": memory,
                "query_index": idx,
            }
            problem_list.append(sub_problem)

    return problem_list
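
# Illustrative only (not from the original file): a problem with 5 testcases and
# test_case_batch_size=2 is split into 3 sub-problems covering case ranges
# [0, 2), [2, 4), [4, 5). All of them carry the same generated code and the same
# query_index, so code_verify can AND their pass/fail results back together.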


def code_verify(
    id2info,
    generateds,
    query_ids,
    timeout=FUNCTIONCALL_TIMEOUT,
    timeout_for_testcase=SINGLE_CASE_EXEC_TIMEOUT,
    test_case_batch_size=TEST_CASE_BATCH_SIZE,
):
    assert len(generateds) == len(query_ids), (
        len(generateds),
        len(query_ids),
    )

    payload_list = load_problems_with_testcase_batch(
        id2info,
        query_ids,
        generateds,
        timeout_for_testcase,
        test_case_batch_size,
    )

    logger.info(
        f"code_verify start, request count: {len(payload_list)}, query size: {len(query_ids)}, query_id_0: {query_ids[0]}"
    )
    rsp_list = batch_function_call(payload_list, "code", timeout)

    # Start every query at 1 (pass) and AND in each sub-problem's result, so a query
    # only passes if all of its testcase batches pass. If the whole batch call
    # returned nothing, mark every query as failed.
    results = [1] * len(query_ids) if len(rsp_list) else [0] * len(query_ids)
    for idx, rsp in enumerate(rsp_list):
        query_index = payload_list[idx]["query_index"]
        query_id = query_ids[query_index]

        value = 0
        if rsp and rsp.get("success", False):
            value = 1
        else:
            logger.debug(
                f'Functioncall code verify not passed, uid: {rsp.get("uid") if rsp else None}, query id: {query_id}, results: {rsp}'
            )

        results[query_index] = results[query_index] and value

    logger.info(
        f"code_verify finished, request count: {len(payload_list)}, query count: {len(query_ids)}, result count: {len(results)}"
    )
    return results
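
# Minimal usage sketch (illustrative only; the dataset fields shown are assumptions
# inferred from load_problems_with_testcase_batch, not values taken from the repo):
#
#   id2info = {
#       "q1": {
#           "query_id": "q1",
#           "language": "PYTHON",
#           "timeout": 2,
#           "memory": 256,
#           "input_output": json.dumps(
#               {"fn_name": "", "remote": False, "inputs": ["1 2\n"], "outputs": ["3\n"]}
#           ),
#       }
#   }
#   code = "a, b = map(int, input().split())\nprint(a + b)"
#   rewards = code_verify(id2info, [code], ["q1"])  # -> list of 0/1 flags per query_id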


if __name__ == "__main__":
    data_list = load_jsonl(
        "/storage/openpsi/data/functioncall/test/test_success_dataset.jsonl"
    )
    id2info = defaultdict(dict)
    for item in data_list:
        id2info[item["query_id"]] = item

    def create_test_params(count=10):
        query_ids = []
        generateds = []
        cnt = 0

        for d in data_list:
            if cnt >= count:
                break
            if d["query_id"] not in id2info:
                continue
            query_ids.append(d["query_id"])
            # Assumes one solution per item; with multiple solutions, generateds and
            # query_ids diverge and the length assertion in code_verify fails.
            generateds.extend(d["solutions"])
            cnt += 1

        return generateds, query_ids

    generateds, query_ids = create_test_params(100)
    scale = 1
    print(f"generateds: {len(generateds)}, query_ids: {query_ids}")
    result = code_verify(id2info, generateds * scale, query_ids * scale)
    print(result)