mirror of https://github.com/inclusionAI/AReaL
490 lines
20 KiB
Python
490 lines
20 KiB
Python
# Copyright 2025 Ant Group Inc.
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
import asyncio
|
|
import os
|
|
import shutil
|
|
import threading
|
|
import time
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import List
|
|
|
|
import aiohttp
|
|
import numpy as np
|
|
|
|
from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq
|
|
from realhf.api.core.system_api import ExpStatus
|
|
from realhf.api.core.system_api import GserverManager as GserverManagerConfig
|
|
from realhf.base import constants, logging, name_resolve, names, network, recover
|
|
from realhf.base.monitor import RolloutStat
|
|
from realhf.system.worker_base import AsyncWorker, PollResult, Worker
|
|
|
|
logger = logging.getLogger("Generation Manager", "system")
|
|
|
|
STALENESS_WARNED = defaultdict(lambda: False)
|
|
|
|
|
|
@dataclass
|
|
class AllocateRolloutInput:
|
|
qid: str
|
|
|
|
|
|
class GserverManager(AsyncWorker):
|
|
"""This worker has the following functionalities:
|
|
1. As a router, it schedules generation requests and returns the
|
|
best server urls to clients for submitting generation requests.
|
|
2. It manages the weight update requests of generation servers.
|
|
The weight update manager must be unique in each experiment.
|
|
|
|
This is currently a hack usage of SGLang. We can integrate the
|
|
functionalities into sgl-router and srt in the future.
|
|
"""
|
|
|
|
def _configure(self, config: GserverManagerConfig):
|
|
self.config = config
|
|
self.model_name = config.model_name
|
|
|
|
assert self.config.worker_info.worker_count == 1
|
|
|
|
self.threading_lock = threading.Lock()
|
|
self.rollout_stat = RolloutStat()
|
|
|
|
self.schedule_policy = config.schedule_policy
|
|
|
|
self._last_param_realloc_step = 0
|
|
|
|
self._qid_to_server_url = {}
|
|
|
|
self._server_token_usage = defaultdict(float)
|
|
self._server_request_counts = defaultdict(int)
|
|
|
|
self._last_thpt_output_time = time.time()
|
|
self._gen_tokens = 0
|
|
|
|
self.experiment_name = config.worker_info.experiment_name
|
|
self.trial_name = config.worker_info.trial_name
|
|
|
|
# manager server
|
|
self.manager_http_server = None
|
|
self.thread = None
|
|
|
|
self.server_urls = []
|
|
|
|
# recover info
|
|
self.__recover_run, self.__recover_info = recover.load_recover_info()
|
|
if self.__recover_run:
|
|
# update weights will be automatically triggered upon the first schedule_request
|
|
# self._last_param_realloc_step will also be updated
|
|
name = names.model_version(
|
|
constants.experiment_name(),
|
|
constants.trial_name(),
|
|
self.model_name.role,
|
|
)
|
|
name_resolve.add(name, self.__recover_info.last_step_info.global_step)
|
|
|
|
self._loaded_recover_weights = False
|
|
hist_rollouts = (
|
|
self.config.train_batch_size
|
|
* self.__recover_info.last_step_info.global_step
|
|
)
|
|
self.rollout_stat.submitted = hist_rollouts
|
|
self.rollout_stat.accepted = hist_rollouts
|
|
|
|
return config.worker_info
|
|
|
|
def _discover_servers(self, n_servers: int, timeout: int = 300) -> List[str]:
|
|
logger.info(f"Waiting for {n_servers} generation servers...")
|
|
name = names.gen_servers(self.experiment_name, self.trial_name)
|
|
cnt = 0
|
|
while len(name_resolve.find_subtree(name)) < n_servers:
|
|
time.sleep(1)
|
|
cnt += 1
|
|
if cnt >= timeout:
|
|
raise TimeoutError("Waiting generation servers timeout.")
|
|
urls = name_resolve.get_subtree(name)
|
|
assert len(set(urls)) == len(urls), (len(urls), len(set(urls)), urls)
|
|
return urls
|
|
|
|
def _get_recover_ckpt_path(self, role: str):
|
|
assert self.__recover_run
|
|
epoch = self.__recover_info.last_step_info.epoch + 1
|
|
epochstep = self.__recover_info.last_step_info.epoch_step + 1
|
|
globalstep = self.__recover_info.last_step_info.global_step + 1
|
|
save_root = os.path.join(
|
|
constants.MODEL_SAVE_ROOT,
|
|
constants.experiment_name(),
|
|
constants.trial_name(),
|
|
)
|
|
role_path = os.path.join(save_root, role)
|
|
if not os.path.exists(role_path):
|
|
raise RuntimeError(
|
|
f"Guessed checkpoint path {role_path} does not exist. "
|
|
"Skip loading checkpoints in the recovered run."
|
|
)
|
|
model_path = os.path.join(
|
|
role_path,
|
|
f"epoch{epoch}epochstep{epochstep}globalstep{globalstep}",
|
|
)
|
|
if not os.path.exists(model_path):
|
|
raise RuntimeError(
|
|
f"Guessed checkpoint path {model_path} does not exist. "
|
|
"Skip loading checkpoints in the recovered run."
|
|
)
|
|
return model_path
|
|
|
|
def check_new_params(self) -> str | None:
|
|
name = names.model_version(
|
|
constants.experiment_name(),
|
|
constants.trial_name(),
|
|
self.model_name.role,
|
|
)
|
|
try:
|
|
realloc_version = int(name_resolve.get(name))
|
|
except name_resolve.NameEntryNotFoundError:
|
|
return None
|
|
|
|
# Update the model weights after parameter realloction.
|
|
if realloc_version > self._last_param_realloc_step:
|
|
if self.__recover_run and not self._loaded_recover_weights:
|
|
realloc_dir = self._get_recover_ckpt_path(self.model_name.role)
|
|
self._loaded_recover_weights = True
|
|
else:
|
|
realloc_dir = os.path.join(
|
|
constants.PARAM_REALLOC_PATH,
|
|
constants.experiment_name(),
|
|
constants.trial_name(),
|
|
self.model_name.role,
|
|
str(realloc_version),
|
|
)
|
|
self._last_param_realloc_step = realloc_version
|
|
return realloc_dir
|
|
|
|
return None
|
|
|
|
async def flush_requests_and_update_weights(
|
|
self, server_url, new_param_path, update_weights_retries=5
|
|
):
|
|
server_index = self.server_urls.index(server_url)
|
|
success = False
|
|
for _ in range(update_weights_retries):
|
|
async with aiohttp.ClientSession(
|
|
server_url,
|
|
timeout=aiohttp.ClientTimeout(
|
|
total=self.config.flush_request_timeout,
|
|
sock_connect=self.config.flush_request_timeout,
|
|
),
|
|
) as session:
|
|
async with session.post(
|
|
f"/update_weights_from_disk",
|
|
json=dict(model_path=new_param_path, allow_interrupt=True),
|
|
) as resp:
|
|
if resp.status == 200:
|
|
res = await resp.json()
|
|
success = res["success"]
|
|
if success:
|
|
if "num_paused_requests" in res:
|
|
logger.info(
|
|
f"{res['num_paused_requests']} requests are interrupted "
|
|
f"during updateing weights for server {server_index}: {server_url}"
|
|
)
|
|
return
|
|
logger.warning(
|
|
f"Update weights failed: {res['message']}. Retrying."
|
|
)
|
|
logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
|
|
time.sleep(0.1)
|
|
raise RuntimeError("Update weights failed.")
|
|
|
|
def _round_robin_schedule(self, req_meta: GenReqMeta) -> int:
|
|
if not hasattr(self, "round_robin_idx"):
|
|
self.round_robin_idx = 0
|
|
r = self.round_robin_idx
|
|
self.round_robin_idx += 1
|
|
self.round_robin_idx %= self.config.n_servers
|
|
return r
|
|
|
|
def _least_requests_schedule(self, req_meta: GenReqMeta) -> int:
|
|
counts = [
|
|
self._server_request_counts[server_url] for server_url in self.server_urls
|
|
]
|
|
return int(np.argmin(counts))
|
|
|
|
def _least_token_usage_schedule(self, req_meta: GenReqMeta) -> int:
|
|
url = min(self.server_urls, key=lambda k: self._server_token_usage[k])
|
|
return self.server_urls.index(url)
|
|
|
|
async def _poll_async(self):
|
|
if not self.thread:
|
|
# Find addresses of generation servers
|
|
self.server_urls = self._discover_servers(self.config.n_servers)
|
|
self.thread = threading.Thread(
|
|
target=self._run_routing_service, daemon=True
|
|
)
|
|
self.thread.start()
|
|
time.sleep(3) # Wait briefly for server to start
|
|
# Write address for clients
|
|
name = names.gen_server_manager(self.experiment_name, self.trial_name)
|
|
name_resolve.add(name, self.manager_addr)
|
|
logger.info(
|
|
f"GserverManager HTTP service started in background thread at {self.manager_addr}"
|
|
)
|
|
|
|
# Check experiment finish.
|
|
name = names.experiment_status(
|
|
constants.experiment_name(), constants.trial_name()
|
|
)
|
|
try:
|
|
exp_status = name_resolve.wait(name, timeout=300)
|
|
if exp_status != str(ExpStatus.RUNNING):
|
|
self.exit()
|
|
return PollResult(0, 0)
|
|
except TimeoutError:
|
|
raise TimeoutError(
|
|
f"Waiting for experiment status timeout. "
|
|
"This indicates that the master worker is not running. Exit the worker."
|
|
)
|
|
|
|
# Check weights.
|
|
with self.threading_lock:
|
|
# FIXME: we create a sync point across servers to update weights,
|
|
# but we can acutally update them individually
|
|
new_param_path = self.check_new_params()
|
|
if new_param_path is not None:
|
|
tasks = [
|
|
self.flush_requests_and_update_weights(base_url, new_param_path)
|
|
for base_url in self.server_urls
|
|
]
|
|
await asyncio.gather(*tasks)
|
|
logger.info(f"Generaion server updated weights from: {new_param_path}")
|
|
|
|
if self.schedule_policy == "least_token_usage":
|
|
tasks = [
|
|
self._get_server_token_usage(server_url)
|
|
for server_url in self.server_urls
|
|
]
|
|
loop = asyncio.get_event_loop()
|
|
token_usages = loop.run_until_complete(asyncio.gather(*tasks))
|
|
with self.threading_lock:
|
|
for server_url, token_usage in zip(self.server_urls, token_usages):
|
|
self._server_token_usage[server_url] = token_usage
|
|
|
|
if time.time() - self._last_thpt_output_time > 30:
|
|
interval = time.time() - self._last_thpt_output_time
|
|
logger.info(
|
|
f"Generation throughput: {self._gen_tokens / interval:.2f} tokens/s"
|
|
)
|
|
self._last_thpt_output_time = time.time()
|
|
self._gen_tokens = 0
|
|
|
|
# clear old weights
|
|
realloc_root = os.path.join(
|
|
constants.PARAM_REALLOC_PATH,
|
|
constants.experiment_name(),
|
|
constants.trial_name(),
|
|
self.model_name.role,
|
|
)
|
|
if os.path.exists(realloc_root):
|
|
for realloc_version in os.listdir(realloc_root):
|
|
# Lock-free is safe here.
|
|
# Remain one checkpoint for recover.
|
|
if (
|
|
os.path.isdir(os.path.join(realloc_root, realloc_version))
|
|
and int(realloc_version) < self._last_param_realloc_step - 1
|
|
):
|
|
shutil.rmtree(os.path.join(realloc_root, realloc_version))
|
|
logger.info(
|
|
f"Removed previous reallocated "
|
|
f"checkpoint: {os.path.join(realloc_root, realloc_version)}"
|
|
)
|
|
|
|
time.sleep(5)
|
|
|
|
return PollResult(0, 0)
|
|
|
|
async def _get_server_token_usage(self, server_url):
|
|
async with aiohttp.ClientSession(
|
|
server_url,
|
|
timeout=aiohttp.ClientTimeout(
|
|
total=self.config.flush_request_timeout,
|
|
sock_connect=self.config.flush_request_timeout,
|
|
),
|
|
) as session:
|
|
async with session.get("/metrics") as resp:
|
|
resp.raise_for_status()
|
|
text = await resp.text()
|
|
for l in text.split("\n"):
|
|
if l.startswith("sglang:num_used_tokens"):
|
|
return float(l.split(" ")[1])
|
|
raise RuntimeError(f"Failed to get token usage metrics from {server_url}")
|
|
|
|
async def _get_server_num_running_requests(self, server_url):
|
|
async with aiohttp.ClientSession(
|
|
server_url,
|
|
timeout=aiohttp.ClientTimeout(
|
|
total=self.config.flush_request_timeout,
|
|
sock_connect=self.config.flush_request_timeout,
|
|
),
|
|
) as session:
|
|
async with session.get(f"/metrics") as resp:
|
|
resp.raise_for_status()
|
|
text = await resp.text()
|
|
for line in text.split("\n"):
|
|
if line.startswith("sglang:num_running_reqs"):
|
|
return float(line.split(" ")[1])
|
|
raise RuntimeError(
|
|
f"Failed to get num running requests metrics from {server_url}"
|
|
)
|
|
|
|
def get_training_sample_cnt(self):
|
|
name = names.training_samples(self.experiment_name, self.trial_name)
|
|
try:
|
|
return int(name_resolve.get(name))
|
|
except name_resolve.NameEntryNotFoundError:
|
|
return 0
|
|
|
|
def is_staled(self):
|
|
# Use counter written by the trainer, local counter is inaccurate
|
|
global_sample_cnt = self.get_training_sample_cnt() + self.rollout_stat.running
|
|
expected_version = global_sample_cnt // self.config.train_batch_size
|
|
version = self._last_param_realloc_step
|
|
staled = expected_version > self.config.max_head_offpolicyness + version
|
|
global STALENESS_WARNED
|
|
if staled and not STALENESS_WARNED[version]:
|
|
logger.warning(
|
|
f"expected version ({expected_version}) = "
|
|
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
|
|
f"current latest version {version}, "
|
|
f"offpolicyness {self.config.max_head_offpolicyness}. Staled? {staled}"
|
|
)
|
|
STALENESS_WARNED[version] = True
|
|
return staled
|
|
|
|
def _run_routing_service(self):
|
|
"""Expose an API for clients to find the destination server."""
|
|
import uvicorn
|
|
from fastapi import FastAPI
|
|
|
|
self.app = FastAPI()
|
|
|
|
@self.app.post("/schedule_request")
|
|
async def schedule_request(req_meta: GenReqMeta):
|
|
with self.threading_lock:
|
|
if (
|
|
req_meta.previous_server_url
|
|
and req_meta.previous_version == self._last_param_realloc_step
|
|
):
|
|
return dict(
|
|
url=req_meta.previous_server_url,
|
|
version=req_meta.previous_version,
|
|
)
|
|
|
|
if self.schedule_policy == "round_robin":
|
|
server_idx = self._round_robin_schedule(req_meta)
|
|
elif self.schedule_policy == "least_token_usage":
|
|
server_idx = self._least_token_usage_schedule(req_meta)
|
|
elif self.schedule_policy == "least_requests":
|
|
server_idx = self._least_requests_schedule(req_meta)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Unknown schedule policy {self.schedule_policy}"
|
|
)
|
|
|
|
server_url = self.server_urls[server_idx]
|
|
# qid prompt (n samples) use the same dst server
|
|
self._qid_to_server_url[req_meta.qid] = server_url
|
|
self._server_request_counts[server_url] += 1
|
|
self._server_token_usage[server_url] += (
|
|
req_meta.prompt_len
|
|
+ req_meta.new_token_budget * req_meta.group_size * 0.4
|
|
)
|
|
|
|
version = self._last_param_realloc_step
|
|
return dict(url=server_url, version=version)
|
|
|
|
@self.app.post("/get_model_version")
|
|
async def get_model_version(req: ModelVersionReq):
|
|
with self.threading_lock:
|
|
# FIXME: we may have different versions for different servers
|
|
version = self._last_param_realloc_step
|
|
return dict(version=version)
|
|
|
|
@self.app.post("/allocate_rollout")
|
|
async def allocate_rollout(req: AllocateRolloutInput):
|
|
with self.threading_lock:
|
|
has_capacity = (
|
|
self.rollout_stat.running < self.config.max_concurrent_rollouts
|
|
)
|
|
is_staled = self.is_staled()
|
|
reason = ""
|
|
if has_capacity and not is_staled:
|
|
self.rollout_stat.submitted += 1
|
|
self.rollout_stat.running += 1
|
|
logger.info(
|
|
f"Allocate rollout for qid {req.qid}. "
|
|
f"Submitted: {self.rollout_stat.submitted}, "
|
|
f"running: {self.rollout_stat.running}, "
|
|
f"accepted: {self.rollout_stat.accepted}."
|
|
)
|
|
return dict(success=True, reason=reason)
|
|
else:
|
|
if not has_capacity:
|
|
reason += f"capacity: {self.rollout_stat.running} >= {self.config.max_concurrent_rollouts}"
|
|
if is_staled:
|
|
global_sample_cnt = (
|
|
self.get_training_sample_cnt() + self.rollout_stat.running
|
|
)
|
|
expected_version = (
|
|
global_sample_cnt // self.config.train_batch_size
|
|
)
|
|
version = self._last_param_realloc_step
|
|
reason += (
|
|
f" and staled: expected version ({expected_version}) = "
|
|
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
|
|
f"current latest version {version}, "
|
|
f"offpolicyness {self.config.max_head_offpolicyness}."
|
|
)
|
|
return dict(success=False, reason=reason)
|
|
|
|
@self.app.post("/finish_rollout")
|
|
async def finish_rollout(resp_meta: GenRespMeta):
|
|
with self.threading_lock:
|
|
server_url = self._qid_to_server_url[resp_meta.qid]
|
|
self._server_request_counts[server_url] -= 1
|
|
assert (
|
|
self._server_request_counts[server_url] >= 0
|
|
), "server request count < 0"
|
|
self._qid_to_server_url.pop(resp_meta.qid)
|
|
self._gen_tokens += resp_meta.n_tokens
|
|
self.rollout_stat.running -= 1
|
|
if resp_meta.accepted:
|
|
self.rollout_stat.accepted += 1
|
|
logger.info(
|
|
f"Finish rollout for qid {resp_meta.qid}. "
|
|
f"Running: {self.rollout_stat.running}, "
|
|
f"running: {self.rollout_stat.running}, "
|
|
f"accepted: {self.rollout_stat.accepted}"
|
|
)
|
|
return dict(success=True)
|
|
|
|
port = network.find_free_port(
|
|
experiment_name=self.experiment_name,
|
|
trial_name=self.trial_name,
|
|
)
|
|
self.manager_addr = f"{network.gethostip()}:{port}"
|
|
|
|
config = uvicorn.Config(
|
|
self.app,
|
|
host=self.manager_addr.split(":")[0],
|
|
port=int(self.manager_addr.split(":")[1]),
|
|
log_level="warning",
|
|
)
|
|
self.manager_http_server = uvicorn.Server(config)
|
|
self.manager_http_server.run()
|
|
|
|
def _exit_hook(self, exit_status):
|
|
if self.manager_http_server:
|
|
self.manager_http_server.should_exit = True
|
|
if self.thread:
|
|
self.thread.join(timeout=3)
|
|
logger.info("Server stopped")
|