mirror of https://github.com/inclusionAI/AReaL
sglang server wrapper
This commit is contained in:
parent
4a26f28adf
commit
805437463f
|
@ -643,19 +643,8 @@ class GRPOConfig(BaseExperimentConfig):
|
|||
ref: PPOActorConfig = field(default_factory=PPOActorConfig)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArgParseResult:
|
||||
config: BaseExperimentConfig
|
||||
config_file: Path
|
||||
additional_args: Optional[argparse.Namespace] = None
|
||||
overrides: Optional[List[str]] = None
|
||||
|
||||
|
||||
def parse_cli_args(
|
||||
argv: List[str], parser: Optional[argparse.ArgumentParser] = None
|
||||
) -> ArgParseResult:
|
||||
if parser is None:
|
||||
parser = argparse.ArgumentParser()
|
||||
def parse_cli_args(argv: List[str]):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config", help="The path of the main configuration file", required=True
|
||||
)
|
||||
|
@ -671,9 +660,7 @@ def parse_cli_args(
|
|||
config_name=str(relpath.name).rstrip(".yaml"),
|
||||
overrides=overrides,
|
||||
)
|
||||
return ArgParseResult(
|
||||
config=cfg, config_file=config_file, additional_args=args, overrides=overrides
|
||||
)
|
||||
return cfg, config_file
|
||||
|
||||
|
||||
def to_structured_cfg(cfg, config_cls):
|
||||
|
@ -684,11 +671,8 @@ def to_structured_cfg(cfg, config_cls):
|
|||
return cfg
|
||||
|
||||
|
||||
def load_expr_config(
|
||||
argv: List[str], config_cls, parser: Optional[argparse.ArgumentParser] = None
|
||||
) -> ArgParseResult:
|
||||
r = parse_cli_args(argv, parser=parser)
|
||||
cfg = r.config
|
||||
def load_expr_config(argv: List[str], config_cls):
|
||||
cfg, config_file = parse_cli_args(argv)
|
||||
cfg = to_structured_cfg(cfg, config_cls=config_cls)
|
||||
cfg = OmegaConf.to_object(cfg)
|
||||
assert isinstance(cfg, BaseExperimentConfig)
|
||||
|
@ -698,6 +682,4 @@ def load_expr_config(
|
|||
|
||||
constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
|
||||
name_resolve.reconfigure(cfg.cluster.name_resolve)
|
||||
return ArgParseResult(
|
||||
config=cfg, config_file=r.config_file, additional_args=r.additional_args
|
||||
)
|
||||
return cfg, config_file
|
||||
|
|
|
@ -54,6 +54,8 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
if not self.addresses:
|
||||
raise RuntimeError("No configured SGLang servers.")
|
||||
for addr in self.addresses:
|
||||
# FIXME
|
||||
print(f"waiting for server address {addr}")
|
||||
self._wait_for_server(addr)
|
||||
|
||||
self.server_idx = 0
|
||||
|
@ -88,7 +90,8 @@ class RemoteSGLangEngine(InferenceEngine):
|
|||
timeout=30,
|
||||
)
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.RequestException:
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Check {base_url}/metrics failed, reason: {e}")
|
||||
return False
|
||||
|
||||
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
|
||||
|
|
|
@ -243,8 +243,7 @@ class LocalLauncher:
|
|||
|
||||
|
||||
def main_local():
|
||||
r = parse_cli_args(sys.argv[2:])
|
||||
cfg = r.config
|
||||
cfg, _ = parse_cli_args(sys.argv[2:])
|
||||
alloc_mode = AllocationMode.from_str(cfg.allocation_mode)
|
||||
|
||||
launcher = LocalLauncher(cfg.experiment_name, cfg.trial_name, cfg.cluster.fileroot)
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
import requests
|
||||
|
||||
from arealite.api.cli_args import (
|
||||
NameResolveConfig,
|
||||
SGLangConfig,
|
||||
parse_cli_args,
|
||||
to_structured_cfg,
|
||||
)
|
||||
from arealite.api.io_struct import AllocationMode, AllocationType
|
||||
from arealite.utils.network import find_free_ports, gethostip
|
||||
from realhf.base import logging, name_resolve, names, pkg_version
|
||||
|
||||
logger = logging.getLogger("SGLangServer Wrapper")
|
||||
|
||||
|
||||
def execute_shell_command(command: str) -> subprocess.Popen:
|
||||
"""
|
||||
Execute a shell command and return its process handle.
|
||||
"""
|
||||
# Replace newline continuations and split the command string.
|
||||
command = command.replace("\\\n", " ").replace("\\", " ")
|
||||
parts = command.split()
|
||||
return subprocess.Popen(
|
||||
parts,
|
||||
text=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
|
||||
|
||||
def apply_sglang_patch():
|
||||
p = Path(os.path.dirname(__file__))
|
||||
patch_path = str(
|
||||
p.parent.parent
|
||||
/ "patch"
|
||||
/ "sglang"
|
||||
/ f"v{pkg_version.get_version('sglang')}.patch"
|
||||
)
|
||||
|
||||
target_path = ""
|
||||
sglang_meta = subprocess.check_output(
|
||||
"python3 -m pip show sglang", shell=True
|
||||
).decode("ascii")
|
||||
for line in sglang_meta.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("Editable project location: "):
|
||||
target_path = str(Path(line.split(": ")[1]).parent)
|
||||
|
||||
if target_path:
|
||||
proc = subprocess.Popen(
|
||||
["git", "apply", patch_path],
|
||||
cwd=target_path,
|
||||
stderr=sys.stdout,
|
||||
stdout=sys.stdout,
|
||||
)
|
||||
proc.wait()
|
||||
logger.info(f"Applied SGLang patch at {target_path}")
|
||||
|
||||
|
||||
def launch_server_cmd(command: str):
|
||||
"""
|
||||
Launch the server using the given command.
|
||||
If no port is specified, a free port is reserved.
|
||||
"""
|
||||
if not ray.is_initialized():
|
||||
apply_sglang_patch()
|
||||
process = execute_shell_command(command)
|
||||
return process
|
||||
|
||||
|
||||
def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
|
||||
"""Wait for the server to be ready by polling the /v1/models endpoint.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the server
|
||||
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||
"""
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{base_url}/v1/models",
|
||||
headers={"Authorization": "Bearer None"},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
time.sleep(5)
|
||||
break
|
||||
|
||||
if timeout and time.time() - start_time > timeout:
|
||||
raise TimeoutError("Server did not become ready within timeout period")
|
||||
except requests.exceptions.RequestException:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
class SGLangServerWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
trial_name: str,
|
||||
name_resolve_config: NameResolveConfig,
|
||||
sglang_config: SGLangConfig,
|
||||
tp_size: int,
|
||||
):
|
||||
self.experiment_name = experiment_name
|
||||
self.trial_name = trial_name
|
||||
self.name_resolve_config = name_resolve_config
|
||||
self.config = sglang_config
|
||||
self.tp_size = tp_size
|
||||
self.server_process = None
|
||||
|
||||
def run(self):
|
||||
name_resolve.reconfigure(self.name_resolve_config)
|
||||
|
||||
server_port, dist_init_port = find_free_ports(2, (10000, 50000))
|
||||
dist_init_addr = f"localhost:{dist_init_port}"
|
||||
host_ip = gethostip()
|
||||
|
||||
cmd = SGLangConfig.build_cmd(
|
||||
self.config, tp_size, 0, host_ip, server_port, dist_init_addr=dist_init_addr
|
||||
)
|
||||
self.server_process = launch_server_cmd(cmd)
|
||||
wait_for_server(f"http://{host_ip}:{server_port}")
|
||||
|
||||
name = names.gen_servers(self.experiment_name, self.trial_name)
|
||||
name_resolve.add_subentry(name, f"{host_ip}:{server_port}")
|
||||
|
||||
logger.info(f"SGLang server launched at: http://{host_ip}:{server_port}")
|
||||
return_code = self.server_process.wait()
|
||||
logger.info(
|
||||
f"SGLang server at http://{host_ip}:{server_port} exits, returncode={return_code}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config, _ = parse_cli_args(sys.argv[2:])
|
||||
config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
|
||||
config.cluster.name_resolve = to_structured_cfg(
|
||||
config.cluster.name_resolve, NameResolveConfig
|
||||
)
|
||||
|
||||
allocation_mode = config.allocation_mode
|
||||
allocation_mode = AllocationMode.from_str(allocation_mode)
|
||||
assert allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG
|
||||
tp_size = allocation_mode.gen_tp_size
|
||||
|
||||
sglang_server = SGLangServerWrapper(config.sglang, tp_size)
|
||||
sglang_server.run()
|
|
@ -160,7 +160,8 @@ echo head_node_ip=$head_node_ip
|
|||
wait
|
||||
"""
|
||||
|
||||
SRUN_CMD_TEMPLATE = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
|
||||
SRUN_CMD_TEMPLATE = """srun --overlap --mpi=pmi2 -K -l --chdir $PWD --nodelist=${{nodes_array[{node_id}]}} \\
|
||||
--nodes={nodes} --ntasks={ntasks} --gres=gpu:{n_gpus_per_node} --cpus-per-task={cpus_per_task} \\
|
||||
--mem-per-cpu={mem_per_cpu}M {apptainer_name} exec {apptainer_options} --bind {container_mounts} \\
|
||||
{container_env_strings} \\
|
||||
{container_image} \\
|
||||
|
@ -336,6 +337,7 @@ class SlurmLauncher:
|
|||
# resolve CUDA_VISIBLE_DEVICES for each task
|
||||
gpu_id_start = (i % ntasks_per_node) * n_gpus_per_task
|
||||
gpu_id_end = ((i % ntasks_per_node) + 1) * n_gpus_per_task
|
||||
node_id = i // ntasks_per_node
|
||||
_env_vars = {
|
||||
**env_vars,
|
||||
"CUDA_VISIBLE_DEVICES": ",".join(
|
||||
|
@ -353,6 +355,7 @@ class SlurmLauncher:
|
|||
srun_cmd = SRUN_CMD_TEMPLATE.format(
|
||||
nodes=1,
|
||||
ntasks=1,
|
||||
node_id=node_id,
|
||||
n_gpus_per_node=n_gpus_per_node,
|
||||
cpus_per_task=cpus_per_task,
|
||||
mem_per_cpu=mem_per_cpu,
|
||||
|
@ -374,7 +377,6 @@ class SlurmLauncher:
|
|||
f.write(sbatch_script)
|
||||
|
||||
# Submit the job
|
||||
# FIXME: debug only
|
||||
try:
|
||||
output = (
|
||||
subprocess.check_output(["sbatch", sbatch_file_path])
|
||||
|
@ -540,39 +542,10 @@ class SlurmLauncher:
|
|||
)
|
||||
|
||||
|
||||
def slurm_args_parser():
|
||||
parser = argparse.ArgumentParser(description="Slurm Launcher for AReaL")
|
||||
parser.add_argument(
|
||||
"--sglang-server-base-port",
|
||||
type=int,
|
||||
required=False,
|
||||
default=27010,
|
||||
help="Base port for SGLang servers. SGLang servers on the same node will .",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainer-port",
|
||||
type=int,
|
||||
required=False,
|
||||
default=27009,
|
||||
help="Pytorch distributed initialization port for trainer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sglang-version",
|
||||
type=str,
|
||||
required=False,
|
||||
default="0.4.6.post4",
|
||||
help="SGLang version in your GPU inference image.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# usage: python -m arealite.launcher.slurm <entry_point> <config_path> [<args>]
|
||||
r = parse_cli_args(sys.argv[2:], parser=slurm_args_parser())
|
||||
config, config_file = parse_cli_args(sys.argv[2:])
|
||||
entry_point = sys.argv[1]
|
||||
config = r.config
|
||||
config_file = r.config_file
|
||||
args = r.additional_args
|
||||
|
||||
config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
|
||||
n_nodes = config.n_nodes
|
||||
|
@ -671,6 +644,7 @@ if __name__ == "__main__":
|
|||
sglang_addrs.extend(
|
||||
[f"{host}:{port}" for port in sglang_ports_on_node]
|
||||
)
|
||||
logger.info(f"Get SGLang addresses: {' '.join(sglang_addrs)}")
|
||||
assert len(sglang_addrs) == n_sglang_servers
|
||||
break
|
||||
time.sleep(10)
|
||||
|
|
|
@ -72,8 +72,8 @@ def gsm8k_reward_fn(
|
|||
|
||||
|
||||
def main_grpo():
|
||||
r = load_expr_config(sys.argv[1:], GRPOConfig)
|
||||
config: GRPOConfig = r.config
|
||||
config, _ = load_expr_config(sys.argv[1:], GRPOConfig)
|
||||
config: GRPOConfig
|
||||
|
||||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
|
|
|
@ -38,8 +38,8 @@ def get_gsm8k_dataset(split, tokenizer, rank, world_size):
|
|||
|
||||
|
||||
def main_sft():
|
||||
r = load_expr_config(sys.argv[1:], SFTConfig)
|
||||
config: SFTConfig = r.config
|
||||
config, _ = load_expr_config(sys.argv[1:], SFTConfig)
|
||||
config: SFTConfig
|
||||
|
||||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
|
|
|
@ -40,7 +40,7 @@ def execute_shell_command(command: str) -> subprocess.Popen:
|
|||
)
|
||||
|
||||
|
||||
def apply_sglang_path():
|
||||
def apply_sglang_patch():
|
||||
p = Path(os.path.dirname(__file__))
|
||||
patch_path = str(
|
||||
p.parent.parent
|
||||
|
@ -75,7 +75,7 @@ def launch_server_cmd(command: str, port: int = 30000):
|
|||
If no port is specified, a free port is reserved.
|
||||
"""
|
||||
if not ray.is_initialized():
|
||||
apply_sglang_path()
|
||||
apply_sglang_patch()
|
||||
assert port is not None
|
||||
full_command = f"{command} --port {port}"
|
||||
process = execute_shell_command(full_command)
|
||||
|
|
Loading…
Reference in New Issue