PullRequest: 26 Disable function call by default for open source

Merge branch fix/math of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/26

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>
This commit is contained in:
君末 2025-03-11 19:39:31 +08:00
commit cdf13ff852
10 changed files with 40 additions and 21 deletions

View File

@ -56,6 +56,7 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
unset CLUSTER_SPEC_PATH
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
REAL_CODE_METADATA_PATH=${REAL_CODE_METADATA_PATH} \
FUNCTIONCALL_SERVICE_DOMAIN="" \
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
python3 -m realhf.apps.quickstart ppo-code \
mode=$MODE \

View File

@ -77,7 +77,7 @@ python3 -m realhf.apps.quickstart ppo-math \
actor.vllm.enforce_eager=False \
actor.vllm.max_seq_len_to_capture=${MAX_SEQ_LEN_TO_CAPTURE} \
actor.vllm.max_num_seqs=${MAX_NUM_SEQS} \
actor.vllm.gpu_memory_utilization=1 \
actor.vllm.gpu_memory_utilization=0.85 \
actor.vllm.swap_space=64 \
critic.type._class=$MODEL_FAMILY \
critic.type.is_critic=True \

View File

@ -14,10 +14,9 @@ from functioncall.base import logging
logger = logging.getLogger("Functioncall")
FUNCTION_CALLING_SERVICE_DOMAIN = os.getenv(
"FUNCTION_CALLING_SERVICE_DOMAIN",
# "http://faas-api-service.hcsfaas-function-calling.svc.sa128.alipay.com:8080",
"http://110.75.255.101:8080", # xvip
FUNCTIONCALL_SERVICE_DOMAIN = os.getenv(
"FUNCTIONCALL_SERVICE_DOMAIN",
"",
)
@ -33,16 +32,16 @@ async def async_invoke_function(
timeout: aiohttp.ClientTimeout,
payload: Dict[str, Any] = None,
max_retries: int = 3,
initial_retry_interval: float = 0.5,
initial_retry_interval: float = 0.1,
max_retry_interval: float = 10.0,
):
if payload is None:
payload = {}
url = f"{FUNCTION_CALLING_SERVICE_DOMAIN}/hapis/faas.hcs.io/v1/functions/{function_name}/invoke"
url = f"{FUNCTIONCALL_SERVICE_DOMAIN}/hapis/faas.hcs.io/v1/functions/{function_name}/invoke"
params = {"invocationType": "RequestResponse"}
retries = 0
while retries <= max_retries:
while retries < max_retries:
try:
async with session.post(
url,
@ -64,12 +63,12 @@ async def async_invoke_function(
except asyncio.TimeoutError as e:
logger.warning(
f"Request timeout after {timeout}s (attempt {retries + 1}/{max_retries}). "
f"URL: {url}, Headers: {session.headers}"
f"Request timeout after {timeout}s, URL: {url}, Headers: {session.headers}, payload: {payload}"
)
break
except Exception as e:
logger.error(f"Async invocation failed on attempt {retries + 1}:{str(e)}")
logger.error(f"Async invocation failed on attempt {retries + 1}:{str(e)}, URL: {url}, Headers: {session.headers}")
retries += 1
if retries > max_retries:

View File

@ -25,7 +25,7 @@ def handle(event, context):
answers = event.get("answers", "")
solutions = event.get("solutions", "")
print(f"answers:{answers}, solutions:{solutions}\n")
#print(f"math payload:{event}\n")
# answers and solutions are JSON lists: call process_results, then collect the results into a list
if isinstance(answers, str):
answers = json.loads(answers)

View File

@ -22,8 +22,7 @@ def loadJson(dataDir):
id2info = None
def math_verify(generateds: List, query_ids: List, batch_size=20, timeout=60) -> List:
start_time = time.time()
def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) -> List:
global id2info
if id2info is None:
id2info = loadJson(
@ -49,14 +48,17 @@ def math_verify(generateds: List, query_ids: List, batch_size=20, timeout=60) ->
query_indices.append(idx)
# Process in batches
start_time = time.time()
batch_args_list = []
for i in range(0, len(parameters), batch_size):
current_batch = parameters[i : i + batch_size]
answers, solutions, indices = zip(*parameters[i : i + batch_size])
batch_args = {
"answers": [g for g, _, _ in current_batch],
"solutions": [s for _, s, _ in current_batch],
"answers": list(answers),
"solutions": list(solutions),
"query_ids": [query_ids[i] for i in indices],
}
#print(batch_args)
batch_args_list.append(batch_args)
results_batch = batch_function_call(batch_args_list, "python_math", timeout)
@ -65,15 +67,15 @@ def math_verify(generateds: List, query_ids: List, batch_size=20, timeout=60) ->
# Map results back to original indices
index = 0
for batch_idx, results in enumerate(results_batch):
query_index = query_indices[index]
if not isinstance(results, list) or len(results) == 0:
index += len(batch_args_list[batch_idx]["answers"])
logger.warning(
f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {index}."
f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {query_index}, params: {batch_args_list[batch_idx]['answers']}."
)
continue
for result in results:
query_index = query_indices[index]
if (
isinstance(result, list)
and len(result) > 0
@ -100,7 +102,7 @@ if __name__ == "__main__":
"answer": "\\boxed{-\\frac{2}{3}}",
}
start_time = time.time()
batch_size = 10000
batch_size = 10
result = math_verify(
[sample["answer"]] * batch_size, [sample["query_id"] for _ in range(batch_size)]
)

View File

@ -162,6 +162,7 @@ def main_start(args, recover_count: int = 0):
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
)
for k, v in BASE_ENVIRONS.items():
os.environ[k] = v

View File

@ -276,6 +276,14 @@ class PPOCODEConfig(CommonExperimentConfig):
"Dataset json path REAL_CODE_METADATA_PATH does not exist."
)
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
if not (domain.startswith("http://") and ":" in domain):
raise RuntimeError(
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
)
# interfaces
actor_interface = ModelInterfaceAbstraction(
"ppo_actor",

View File

@ -277,6 +277,13 @@ class PPOMATHConfig(CommonExperimentConfig):
raise RuntimeError(
"Dataset json path REAL_MATH_METADATA_PATH does not exist."
)
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
if domain and (not (domain.startswith("http://") and ":" in domain)):
raise RuntimeError(
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
)
# interfaces
actor_interface = ModelInterfaceAbstraction(

View File

@ -33,7 +33,7 @@ from realhf.impl.model.nn.real_llm_api import ReaLModel
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
ENABLE_FUNCTION_CALL = os.getenv("ENABLE_FUNCTION_CALL", True)
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel

View File

@ -620,6 +620,7 @@ class RayController:
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
)
runtime_env = {
"env_vars": env_vars,