# AReaL/realhf/system/generation_server.py
#
# 232 lines
# 7.1 KiB
# Python

import os
import subprocess
import sys
import time
from pathlib import Path
import ray
import requests
from realhf.api.cli_args import SGLangConfig
from realhf.api.core.system_api import ExpStatus
from realhf.api.core.system_api import GenerationServer as GenerationServerConfig
from realhf.base import (
constants,
gpu_utils,
logging,
name_resolve,
names,
network,
pkg_version,
seeding,
)
from realhf.base.cluster import spec as cluster_spec
from realhf.system.worker_base import PollResult, Worker
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def execute_shell_command(command: str) -> subprocess.Popen:
    """Spawn ``command`` as a child process and return its ``Popen`` handle.

    The command string is not interpreted by a shell. Backslash-newline
    continuations, and then any remaining backslashes, are turned into
    spaces; the result is split on whitespace to form the argument vector,
    so quoting is not understood.

    The child's stdout is attached to this process's ``sys.stdout`` and its
    stderr is merged into the same stream.
    """
    # Flatten line continuations first, then drop every leftover backslash.
    flattened = command.replace("\\\n", " ")
    flattened = flattened.replace("\\", " ")
    argv = flattened.split()
    return subprocess.Popen(
        argv,
        stdout=sys.stdout,
        stderr=subprocess.STDOUT,
        text=True,
    )
def apply_sglang_path():
    """Apply the version-matched patch to an editable SGLang installation.

    The patch file is looked up at ``patch/sglang/v<version>.patch`` two
    directories above the directory containing this file, where ``<version>``
    comes from ``pkg_version.get_version('sglang')``. The target checkout is
    derived from the ``Editable project location`` line of ``pip show
    sglang`` (its parent directory is used as the working directory for
    ``git apply``).

    Nothing is done when pip does not report an editable location. A failing
    ``git apply`` is logged as a warning instead of being reported as a
    success; it does not raise, so launching stays best-effort.

    Raises:
        subprocess.CalledProcessError: if ``pip show sglang`` itself fails.
    """
    p = Path(os.path.dirname(__file__))
    patch_path = str(
        p.parent.parent
        / "patch"
        / "sglang"
        / f"v{pkg_version.get_version('sglang')}.patch"
    )

    # pip metadata may contain non-ASCII text (author names, summaries), so
    # decode leniently rather than crashing on it.
    sglang_meta = subprocess.check_output(
        "python3 -m pip show sglang", shell=True
    ).decode("utf-8", errors="replace")

    prefix = "Editable project location: "
    target_path = ""
    for line in sglang_meta.splitlines():
        line = line.strip()
        if line.startswith(prefix):
            # Keep everything after the prefix so a path that itself contains
            # ": " is not truncated.
            target_path = str(Path(line[len(prefix):]).parent)

    if not target_path:
        return

    proc = subprocess.Popen(
        ["git", "apply", patch_path],
        cwd=target_path,
        stderr=sys.stdout,
        stdout=sys.stdout,
    )
    returncode = proc.wait()
    if returncode != 0:
        # Typical causes: the patch is already applied or no patch exists for
        # this SGLang version. git's own message went to stdout above.
        logger.warning(
            f"Failed to apply SGLang patch {patch_path} at {target_path} "
            f"(git apply exited with code {returncode})"
        )
        return
    logger.info(f"Applied SGLang patch at {target_path}")
def launch_server_cmd(command: str, port: int = 30000):
    """Start the server described by ``command`` on the given ``port``.

    When Ray is not initialized, the SGLang patch is applied first. The port
    must not be ``None``; no free port is searched for here. ``--port <port>``
    is appended to ``command`` before it is executed.

    Returns:
        A ``(process, port)`` tuple: the spawned process handle and the port
        that was passed in.
    """
    needs_patch = not ray.is_initialized()
    if needs_patch:
        apply_sglang_path()

    assert port is not None
    server_proc = execute_shell_command(f"{command} --port {port}")
    return server_proc, port
def terminate_process(process, port=None):
    """Kill ``process`` together with its process tree.

    ``port`` is accepted for interface compatibility but is not used by this
    function.
    """
    # Imported lazily so that SGLang is only required when a server is
    # actually being shut down.
    from sglang.srt.utils import kill_process_tree as _kill_tree

    root_pid = process.pid
    _kill_tree(root_pid)
def wait_for_server(base_url: str, timeout: int = None) -> None:
    """Wait for the server to be ready by polling the /v1/models endpoint.

    The endpoint is polled about once per second until it answers with
    HTTP 200. After the first 200 response the function sleeps five more
    seconds before returning.

    Args:
        base_url: The base URL of the server
        timeout: Maximum time to wait in seconds. None (or 0) means wait
            forever.

    Raises:
        TimeoutError: if ``timeout`` is set and the server has not answered
            with HTTP 200 within that many seconds. The deadline is checked
            after every attempt, including attempts that fail with a
            connection error.
    """
    start_time = time.time()
    while True:
        try:
            response = requests.get(
                f"{base_url}/v1/models",
                headers={"Authorization": "Bearer None"},
                # Bound each attempt so a server that accepts the connection
                # but never answers cannot block this loop indefinitely.
                timeout=10,
            )
            if response.status_code == 200:
                time.sleep(5)
                return
        except requests.exceptions.RequestException:
            # Connection refused / timed out: the server is not up yet.
            pass

        # Checked outside the try block so that failed connection attempts
        # also count against the deadline.
        if timeout and time.time() - start_time > timeout:
            raise TimeoutError("Server did not become ready within timeout period")

        # Back off between attempts, for errors and non-200 answers alike.
        time.sleep(1)
# Window (passed to time.sleep, i.e. seconds) over which the servers of one
# node stagger their start-up: each server sleeps for its on-node index times
# this value divided by the number of servers per node before picking ports.
PORT_CLEARANCE_PERIOD = 90
class GenerationServer(Worker):
    """Worker that runs one SGLang server as a child process.

    The server is launched lazily on the first poll, its address is published
    through ``name_resolve``, and the worker then keeps polling the experiment
    status until it is no longer ``RUNNING``. The child process tree is killed
    in the exit hook.
    """

    def _configure(self, config: GenerationServerConfig):
        """Store the configuration, seed RNGs and choose the base GPU id.

        Returns:
            ``config.worker_info``.

        Raises:
            ValueError: if ``CUDA_VISIBLE_DEVICES`` is set (and Ray is not
                initialized) but cannot be interpreted as a device
                assignment for this server.
        """
        self.config = config
        self.worker_index = config.worker_info.worker_index
        self.worker_count = config.worker_info.worker_count
        self.experiment_name = config.worker_info.experiment_name
        self.trial_name = config.worker_info.trial_name
        seeding.set_random_seed(
            config.base_seed, f"generation_server{self.worker_index}"
        )

        # Cancel the effect of CUDA device isolation
        if ray.is_initialized():
            self.base_gpu_id = 0
        elif "CUDA_VISIBLE_DEVICES" in os.environ:
            self.base_gpu_id = self._base_gpu_from_visible_devices(
                os.environ["CUDA_VISIBLE_DEVICES"], self.config.tp_size
            )
            # Make every GPU visible again; the server is pointed at its
            # devices through ``base_gpu_id`` instead.
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                map(str, range(gpu_utils.gpu_count()))
            )
        else:
            idx_on_this_node, _ = self._index_on_node()
            self.base_gpu_id = idx_on_this_node * self.config.tp_size

        self.server_process = None
        self.server_addr = None
        self.server_port = None
        return config.worker_info

    @staticmethod
    def _base_gpu_from_visible_devices(visible: str, tp_size: int) -> int:
        """Return the first device id named by a ``CUDA_VISIBLE_DEVICES`` value.

        A single id is returned as-is. A comma-separated list is accepted
        only when it names exactly ``tp_size`` consecutive ids, in which case
        the first id is returned. Anything else raises ``ValueError``.
        """
        device_ids = [int(token) for token in visible.split(",")]
        if len(device_ids) == 1:
            return device_ids[0]

        first = device_ids[0]
        consecutive = device_ids == list(range(first, first + len(device_ids)))
        # Only a base id is handed to the server command, so the assigned
        # devices have to form one contiguous run of tp_size ids.
        # NOTE(review): presumably SGLang uses base_gpu_id + tp_rank -- confirm.
        if len(device_ids) != tp_size or not consecutive:
            raise ValueError(
                "CUDA_VISIBLE_DEVICES must name either a single device or "
                f"exactly tp_size={tp_size} consecutive devices, got {visible!r}"
            )
        return first

    def _index_on_node(self):
        """Return ``(index of this server on its node, servers per node)``."""
        servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size
        return self.worker_index % servers_per_node, servers_per_node

    def launch_server_subprocess(self):
        """Launch the SGLang server, wait until it is ready and publish it.

        Sets ``server_process``, ``server_port`` and ``server_addr``, then
        registers the address under ``names.gen_servers`` and under the
        ``names.metric_server`` key of this server.
        """
        config = self.config
        assert config.backend_type == "sglang"
        host_ip = network.gethostip()
        host = "localhost" if not config.backend_args.enable_metrics else host_ip

        # NOTE: Ports returned by `find_multiple_free_ports` are unique,
        # but SGLang servers still encounter conflicts.
        # Use a clearance period to hack over this issue.
        idx_on_this_node, servers_per_node = self._index_on_node()
        time.sleep(idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node)

        ports = network.find_multiple_free_ports(
            2,
            low=10000,
            high=60000,
            experiment_name=self.experiment_name,
            trial_name=self.trial_name,
        )
        server_port = ports[0]
        nccl_port = ports[1]

        cmd = SGLangConfig.build_cmd(
            config.backend_args,
            config.model_path,
            tp_size=config.tp_size,
            server_index=self.worker_index,
            base_gpu_id=self.base_gpu_id,
            dist_init_addr=f"{host}:{nccl_port}",
        )
        self.server_process, self.server_port = launch_server_cmd(cmd, port=server_port)
        self.server_addr = f"http://{host}:{self.server_port}"

        # Blocks until the server answers; no timeout is passed here.
        wait_for_server(self.server_addr)

        name = names.gen_servers(self.experiment_name, self.trial_name)
        name_resolve.add_subentry(name, self.server_addr)

        key = names.metric_server(
            self.experiment_name,
            self.trial_name,
            "sglang",
            f"server{self.worker_index}",
        )
        name_resolve.add(
            key, f"{host}:{self.server_port}", keepalive_ttl=None, delete_on_exit=True
        )

        logger.info(f"SGLang server launched at: {self.server_addr}")

    def _poll(self):
        """Launch the server if needed, then check the experiment status.

        Returns:
            ``PollResult(0, 0)`` in every non-raising case. When the status is
            not ``RUNNING`` the worker's ``exit()`` is called first; otherwise
            the method sleeps five seconds before returning.

        Raises:
            TimeoutError: if the experiment status cannot be read within 300
                seconds.
        """
        if self.server_process is None:
            self.launch_server_subprocess()

        # Check experiment finish.
        name = names.experiment_status(
            constants.experiment_name(), constants.trial_name()
        )
        try:
            exp_status = name_resolve.wait(name, timeout=300)
        except TimeoutError as err:
            raise TimeoutError(
                "Waiting for experiment status timeout. "
                "This indicates that the master worker is not running. Exit the worker."
            ) from err

        if exp_status != str(ExpStatus.RUNNING):
            self.exit()
            return PollResult(0, 0)

        time.sleep(5)
        return PollResult(0, 0)

    def _exit_hook(self, exit_status):
        """Kill the server process tree if a server was launched."""
        if self.server_process is not None and self.config.backend_type == "sglang":
            terminate_process(self.server_process)