diff --git a/docs/tutorial/installation.md b/docs/tutorial/installation.md index 8fa5328..ce563b5 100644 --- a/docs/tutorial/installation.md +++ b/docs/tutorial/installation.md @@ -67,6 +67,10 @@ cd AReaL bash examples/env/scripts/setup-pip-deps.sh ``` +::::{note} +The SGLang patch is applied via `examples/env/scripts/setup-container-deps.sh` or `examples/env/scripts/setup-pip-deps.sh`. To confirm whether it has been applied, run `git status` in the `/sglang` directory (for Docker) or `AReaL/sglang` (for custom setups). +:::: + ## (Optional) Launch Ray Cluster for Distributed Training On the first node, start the Ray Head: diff --git a/realhf/base/constants.py b/realhf/base/constants.py index 73962e2..5a90a67 100644 --- a/realhf/base/constants.py +++ b/realhf/base/constants.py @@ -110,7 +110,7 @@ def get_save_path(args: "BaseExperimentConfig") -> str: def get_param_realloc_path(args: "BaseExperimentConfig"): - path = f"{args.cluster.fileroot}/.cache/{getpass.getuser()}/param_realloc" + path = f"{args.cluster.fileroot}/.cache/{getpass.getuser()}/param_realloc/{args.experiment_name}/{args.trial_name}" os.makedirs(path, exist_ok=True) return path diff --git a/realhf/impl/dataset/math_parser.py b/realhf/impl/dataset/math_parser.py index 027be8a..5189c5d 100644 --- a/realhf/impl/dataset/math_parser.py +++ b/realhf/impl/dataset/math_parser.py @@ -3,11 +3,11 @@ import json import multiprocessing import re -from concurrent.futures import ProcessPoolExecutor, as_completed from typing import List, Union import regex from latex2sympy2 import latex2sympy +from pebble import ProcessPool from sympy import N, simplify from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr @@ -785,23 +785,6 @@ def process_results(answer, solution): return 0, ("None", "None") -def process_results_process(a, b, output_queue): - result = process_results(a, b) - output_queue.put(result) - - -def verify_math_solution(answer: str, solution: str): - # answer is generated by the model, solution is the ground truth - tmp = call_with_timeout( - process_results_process, - answer, - solution, - ) - if isinstance(tmp, bool): - return 0 - return tmp[0] - - def loadJson(dataDir): with open(dataDir, "r") as f: if dataDir.endswith(".jsonl"): @@ -817,7 +800,7 @@ def parse_line(id2info, prompt_str, generated, query_id): label = 0 for sol in info["solutions"]: - label = label or verify_math_solution(generated, sol) + label = label or process_results(generated, sol) return label @@ -834,21 +817,26 @@ def parse_lines_in_parallel( ) all_jobs = [] - with ProcessPoolExecutor(max_workers=max_workers) as executor: + with ProcessPool(max_workers=max_workers) as executor: for qid, gen in zip(query_ids, generateds): info = id2info[qid.split("@idx:")[0]] jobs = [] for sol in info["solutions"]: - job = executor.submit(verify_math_solution, gen, sol) + job = executor.schedule(process_results, args=[gen, sol], timeout=15) jobs.append(job) all_jobs.append(jobs) labels = [] for jobs in all_jobs: label = 0 - for job in as_completed(jobs): - x = job.result() - label = label or x + for job in jobs: + try: + x = job.result() + except TimeoutError: + # print("[debug: timeout]") + logger.warning(f"Timeout occurred while justifying the math answer.") + x = (0, "timeout", "timeout") + label = label or x[0] labels.append(label) return labels diff --git a/realhf/system/model_function_call.py b/realhf/system/model_function_call.py index d927455..38f10cc 100644 --- a/realhf/system/model_function_call.py +++ b/realhf/system/model_function_call.py @@ -216,7 +216,7 @@ class ModelFunctionCall: for p in payloads.values(): p.post_hooks.append("save") save_dir = os.path.join( - constants.get_log_path(self.args), + constants.get_save_path(self.args), rpc.model_name.role, f"epoch{ctrl.step_info.epoch + 1}" f"epochstep{ctrl.step_info.epoch_step + 1}"