mirror of https://github.com/inclusionAI/AReaL
PullRequest: 26 Set functioncall disabled by default for opensource
Merge branch fix/math of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/26 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>
This commit is contained in:
commit
cdf13ff852
|
@ -56,6 +56,7 @@ TRIAL_NAME="${TRAIN_BATCH_SIZE}x${GROUP_SIZE}-n${NODES}"
|
|||
unset CLUSTER_SPEC_PATH
|
||||
CLUSTER_SPEC_PATH=/storage/ray/cluster_config_on_ray.json \
|
||||
REAL_CODE_METADATA_PATH=${REAL_CODE_METADATA_PATH} \
|
||||
FUNCTIONCALL_SERVICE_DOMAIN="" \
|
||||
REAL_GPU_MEMORY_KILL_THRESHOLD=1 \
|
||||
python3 -m realhf.apps.quickstart ppo-code \
|
||||
mode=$MODE \
|
||||
|
|
|
@ -77,7 +77,7 @@ python3 -m realhf.apps.quickstart ppo-math \
|
|||
actor.vllm.enforce_eager=False \
|
||||
actor.vllm.max_seq_len_to_capture=${MAX_SEQ_LEN_TO_CAPTURE} \
|
||||
actor.vllm.max_num_seqs=${MAX_NUM_SEQS} \
|
||||
actor.vllm.gpu_memory_utilization=1 \
|
||||
actor.vllm.gpu_memory_utilization=0.85 \
|
||||
actor.vllm.swap_space=64 \
|
||||
critic.type._class=$MODEL_FAMILY \
|
||||
critic.type.is_critic=True \
|
||||
|
|
|
@ -14,10 +14,9 @@ from functioncall.base import logging
|
|||
logger = logging.getLogger("Functioncall")
|
||||
|
||||
|
||||
FUNCTION_CALLING_SERVICE_DOMAIN = os.getenv(
|
||||
"FUNCTION_CALLING_SERVICE_DOMAIN",
|
||||
# "http://faas-api-service.hcsfaas-function-calling.svc.sa128.alipay.com:8080",
|
||||
"http://110.75.255.101:8080", # xvip
|
||||
FUNCTIONCALL_SERVICE_DOMAIN = os.getenv(
|
||||
"FUNCTIONCALL_SERVICE_DOMAIN",
|
||||
"",
|
||||
)
|
||||
|
||||
|
||||
|
@ -33,16 +32,16 @@ async def async_invoke_function(
|
|||
timeout: aiohttp.ClientTimeout,
|
||||
payload: Dict[str, Any] = None,
|
||||
max_retries: int = 3,
|
||||
initial_retry_interval: float = 0.5,
|
||||
initial_retry_interval: float = 0.1,
|
||||
max_retry_interval: float = 10.0,
|
||||
):
|
||||
if payload is None:
|
||||
payload = {}
|
||||
url = f"{FUNCTION_CALLING_SERVICE_DOMAIN}/hapis/faas.hcs.io/v1/functions/{function_name}/invoke"
|
||||
url = f"{FUNCTIONCALL_SERVICE_DOMAIN}/hapis/faas.hcs.io/v1/functions/{function_name}/invoke"
|
||||
params = {"invocationType": "RequestResponse"}
|
||||
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
while retries < max_retries:
|
||||
try:
|
||||
async with session.post(
|
||||
url,
|
||||
|
@ -64,12 +63,12 @@ async def async_invoke_function(
|
|||
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.warning(
|
||||
f"Request timeout after {timeout}s (attempt {retries + 1}/{max_retries}). "
|
||||
f"URL: {url}, Headers: {session.headers}"
|
||||
f"Request timeout after {timeout}s, URL: {url}, Headers: {session.headers}, payload: {payload}"
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Async invocation failed on attempt {retries + 1}:{str(e)}")
|
||||
logger.error(f"Async invocation failed on attempt {retries + 1}:{str(e)}, URL: {url}, Headers: {session.headers}")
|
||||
|
||||
retries += 1
|
||||
if retries > max_retries:
|
||||
|
|
|
@ -25,7 +25,7 @@ def handle(event, context):
|
|||
answers = event.get("answers", "")
|
||||
solutions = event.get("solutions", "")
|
||||
|
||||
print(f"answers:{answers}, solutions:{solutions}\n")
|
||||
#print(f"math payload:{event}\n")
|
||||
# answers and solutions are json lists, and call process_results then collect result into a list
|
||||
if isinstance(answers, str):
|
||||
answers = json.loads(answers)
|
||||
|
|
|
@ -22,8 +22,7 @@ def loadJson(dataDir):
|
|||
id2info = None
|
||||
|
||||
|
||||
def math_verify(generateds: List, query_ids: List, batch_size=20, timeout=60) -> List:
|
||||
start_time = time.time()
|
||||
def math_verify(generateds: List, query_ids: List, batch_size=10, timeout=5) -> List:
|
||||
global id2info
|
||||
if id2info is None:
|
||||
id2info = loadJson(
|
||||
|
@ -49,14 +48,17 @@ def math_verify(generateds: List, query_ids: List, batch_size=20, timeout=60) ->
|
|||
query_indices.append(idx)
|
||||
|
||||
# Process in batches
|
||||
start_time = time.time()
|
||||
batch_args_list = []
|
||||
for i in range(0, len(parameters), batch_size):
|
||||
current_batch = parameters[i : i + batch_size]
|
||||
answers, solutions, indices = zip(*parameters[i : i + batch_size])
|
||||
batch_args = {
|
||||
"answers": [g for g, _, _ in current_batch],
|
||||
"solutions": [s for _, s, _ in current_batch],
|
||||
"answers": list(answers),
|
||||
"solutions": list(solutions),
|
||||
"query_ids": [query_ids[i] for i in indices],
|
||||
}
|
||||
|
||||
#print(batch_args)
|
||||
batch_args_list.append(batch_args)
|
||||
|
||||
results_batch = batch_function_call(batch_args_list, "python_math", timeout)
|
||||
|
@ -65,15 +67,15 @@ def math_verify(generateds: List, query_ids: List, batch_size=20, timeout=60) ->
|
|||
# Map results back to original indices
|
||||
index = 0
|
||||
for batch_idx, results in enumerate(results_batch):
|
||||
query_index = query_indices[index]
|
||||
if not isinstance(results, list) or len(results) == 0:
|
||||
index += len(batch_args_list[batch_idx]["answers"])
|
||||
logger.warning(
|
||||
f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {index}."
|
||||
f"Invalid functioncall math results: {results}, batch index:{batch_idx}, query index: {query_index}, params: {batch_args_list[batch_idx]['answers']}."
|
||||
)
|
||||
continue
|
||||
|
||||
for result in results:
|
||||
query_index = query_indices[index]
|
||||
if (
|
||||
isinstance(result, list)
|
||||
and len(result) > 0
|
||||
|
@ -100,7 +102,7 @@ if __name__ == "__main__":
|
|||
"answer": "\\boxed{-\\frac{2}{3}}",
|
||||
}
|
||||
start_time = time.time()
|
||||
batch_size = 10000
|
||||
batch_size = 10
|
||||
result = math_verify(
|
||||
[sample["answer"]] * batch_size, [sample["query_id"] for _ in range(batch_size)]
|
||||
)
|
||||
|
|
|
@ -162,6 +162,7 @@ def main_start(args, recover_count: int = 0):
|
|||
REAL_SAVE_RECOVER_STATES="1" if save_recover_states else "0",
|
||||
REAL_MATH_METADATA_PATH=os.getenv("REAL_MATH_METADATA_PATH", ""),
|
||||
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
|
||||
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
||||
)
|
||||
for k, v in BASE_ENVIRONS.items():
|
||||
os.environ[k] = v
|
||||
|
|
|
@ -276,6 +276,14 @@ class PPOCODEConfig(CommonExperimentConfig):
|
|||
"Dataset json path REAL_CODE_METADATA_PATH does not exist."
|
||||
)
|
||||
|
||||
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
|
||||
if not (domain.startswith("http://") and ":" in domain):
|
||||
raise RuntimeError(
|
||||
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
|
||||
)
|
||||
|
||||
|
||||
|
||||
# interfaces
|
||||
actor_interface = ModelInterfaceAbstraction(
|
||||
"ppo_actor",
|
||||
|
|
|
@ -277,6 +277,13 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
raise RuntimeError(
|
||||
"Dataset json path REAL_MATH_METADATA_PATH does not exist."
|
||||
)
|
||||
|
||||
domain = os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "")
|
||||
if domain and (not (domain.startswith("http://") and ":" in domain)):
|
||||
raise RuntimeError(
|
||||
"function call address FUNCTIONCALL_SERVICE_DOMAIN is invalid."
|
||||
)
|
||||
|
||||
|
||||
# interfaces
|
||||
actor_interface = ModelInterfaceAbstraction(
|
||||
|
|
|
@ -33,7 +33,7 @@ from realhf.impl.model.nn.real_llm_api import ReaLModel
|
|||
|
||||
logger = logging.getLogger("Packed Reward Modeling Interface", "benchmark")
|
||||
|
||||
ENABLE_FUNCTION_CALL = os.getenv("ENABLE_FUNCTION_CALL", True)
|
||||
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
|
||||
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
|
||||
|
||||
|
||||
|
|
|
@ -620,6 +620,7 @@ class RayController:
|
|||
REAL_SAVE_RECOVER_STATES=os.environ.get("REAL_SAVE_RECOVER_STATES", ""),
|
||||
REAL_MATH_METADATA_PATH=os.environ.get("REAL_MATH_METADATA_PATH", ""),
|
||||
REAL_CODE_METADATA_PATH=os.getenv("REAL_CODE_METADATA_PATH", ""),
|
||||
FUNCTIONCALL_SERVICE_DOMAIN=os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""),
|
||||
)
|
||||
runtime_env = {
|
||||
"env_vars": env_vars,
|
||||
|
|
Loading…
Reference in New Issue