mirror of https://github.com/inclusionAI/AReaL
Pull Request #38: Support SGLang generation based on the 24.07 Docker image

Merge branch fw/sglang of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/38?tab=comment#note_168700036
Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* format
* cleanup

parent 7b3e33430a
commit 767bb7bf47
@@ -3,10 +3,12 @@
# Licensed under the Apache License, Version 2.0 (the "License").

import abc
import asyncio
import dataclasses
import keyword
from typing import *

import aiohttp
import numpy as np
import torch
import torch.utils.data
@@ -106,6 +108,139 @@ class GenerationHyperparameters:
                f"To use CUDAGraph, ReaL's PyTorch version should be at least 2.3.0."
            )

    def new(self, **kwargs):
        args = dataclasses.asdict(self)
        args.update(kwargs)
        return GenerationHyperparameters(**args)


@dataclasses.dataclass
class APIGenerateInput:
    qid: Hashable
    group_idx: int
    prompt_ids: List[int]
    input_ids: List[int]
    gconfig: GenerationHyperparameters
    stop_token_ids: List[int] = dataclasses.field(default_factory=list)
    return_logprob: bool = False


@dataclasses.dataclass
class APIGenerateOutput:
    qid: Hashable
    group_idx: int
    prompt_ids: List[int]
    input_ids: List[int]
    output_ids: List[int] = dataclasses.field(default_factory=list)
    output_logprobs: List[float] = dataclasses.field(default_factory=list)
    no_eos: bool = True
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = dataclasses.field(
        default_factory=list
    )  # List of inter-token latencies
    error: str = ""

    @classmethod
    def from_input(cls, inp: APIGenerateInput):
        return cls(
            qid=inp.qid,
            group_idx=inp.group_idx,
            prompt_ids=inp.prompt_ids,
            input_ids=inp.input_ids,
        )

    @property
    def output_len(self):
        return len(self.output_ids)

    @property
    def input_len(self):
        return len(self.input_ids)

    @property
    def prompt_len(self):
        return len(self.prompt_ids)

    @property
    def gen_len(self):
        return self.output_len + self.input_len - self.prompt_len


@dataclasses.dataclass
class BundledGenerationOutputs:
    qid: Hashable
    prompt_ids: List[int]
    seqs: List[List[int]]
    no_eos: List[bool]

    @classmethod
    def from_single(cls, outputs: List[APIGenerateOutput]):
        assert len(set(o.qid for o in outputs)) == 1
        return cls(
            qid=outputs[0].qid,
            prompt_ids=outputs[0].prompt_ids,
            seqs=[o.input_ids + o.output_ids for o in outputs],
            no_eos=[o.no_eos for o in outputs],
        )

    @property
    def seqlens(self):
        return [len(seq) for seq in self.seqs]

    @property
    def prompt_len(self):
        return len(self.prompt_ids)


AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


class LLMAPIClient:
    def __init__(
        self, generate_url: str, update_weights_url: str, concurrency_limit: int = -1
    ):
        self.update_weights_url = update_weights_url
        self.generate_url = generate_url
        self.concurrency_limit = concurrency_limit

        self.session: aiohttp.ClientSession
        self.semaphore: asyncio.Semaphore

    async def __aenter__(self):
        conn = aiohttp.TCPConnector(limit=0, ttl_dns_cache=300)
        self.session = aiohttp.ClientSession(
            timeout=AIOHTTP_TIMEOUT,
            connector=conn,
            read_bufsize=1024 * 1024 * 10,
        )
        if self.concurrency_limit > 0:
            self.semaphore = asyncio.Semaphore(self.concurrency_limit)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def async_add_generate_request(
        self, req: APIGenerateInput, stream: bool = True
    ) -> APIGenerateOutput:
        if self.concurrency_limit > 0:
            async with self.semaphore:
                return await self._do_generate(req, stream=stream)
        else:
            return await self._do_generate(req, stream=stream)

    async def _do_generate(
        self, req: APIGenerateInput, stream: bool = True
    ) -> APIGenerateOutput:
        raise NotImplementedError()

    async def async_update_weights_from_disk(self, path):
        raise NotImplementedError()
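
For orientation, a minimal sketch of how this client API is meant to be driven. This is not part of the PR: the concrete `SGLangAPIClient` appears later in this diff, the URLs, qid, and token ids below are placeholders, and the remaining `GenerationHyperparameters` fields are assumed to keep their defaults.

import asyncio

from realhf.api.core.model_api import (
    APIGenerateInput,
    BundledGenerationOutputs,
    GenerationHyperparameters,
)
from realhf.impl.model.backend.sglang import SGLangAPIClient


async def sample_one_prompt(prompt_ids, n=2):
    # Placeholder URLs; the backend points these at its local SGLang server.
    async with SGLangAPIClient(
        generate_url="http://localhost:30000/generate",
        update_weights_url="http://localhost:30000/update_weights_from_disk",
        concurrency_limit=8,
    ) as client:
        gconfig = GenerationHyperparameters(max_new_tokens=16)
        reqs = [
            APIGenerateInput(
                qid=0,
                group_idx=i,
                prompt_ids=prompt_ids,
                input_ids=prompt_ids,
                gconfig=gconfig.new(n=1),
                return_logprob=True,
            )
            for i in range(n)
        ]
        outs = await asyncio.gather(
            *(client.async_add_generate_request(r, stream=True) for r in reqs)
        )
    # Group the n samples of the same prompt into one bundle.
    return BundledGenerationOutputs.from_single(outs)
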
@dataclasses.dataclass
class ReaLMoEConfig:
@@ -131,6 +131,48 @@ class vLLMConfig:
    additional_engine_args: Dict = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class SGLangConfig:
    disable_cuda_graph: bool = False
    disable_radix_cache: bool = False
    disable_jump_forward: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_mla: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    return_hidden_states: bool = False
    # NOTE: Use the triton attention backend to avoid an illegal memory access error.
    attention_backend: Optional[str] = "triton"
    sampling_backend: Optional[str] = None
    context_length: Optional[int] = None
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: Optional[int] = None
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0
    hybrid_train: bool = False
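
As a quick illustration (not part of the PR), a sketch of overriding a few of the fields above. The values are arbitrary examples; the asdict/pop step mirrors what the SGLang backend later in this diff does before building server arguments.

import dataclasses

from realhf.api.quickstart.model import SGLangConfig

sglang = SGLangConfig(
    mem_fraction_static=0.8,     # example value
    max_running_requests=256,    # example value
    attention_backend="triton",  # default kept explicit here
    hybrid_train=False,          # hybrid_train requires disable_cuda_graph=True
)

# The backend flattens the dataclass into SGLang server arguments and drops
# the AReaL-only hybrid_train flag.
server_kwargs = dataclasses.asdict(sglang)
server_kwargs.pop("hybrid_train")
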
@dataclasses.dataclass
class DistributedDataParallelConfig:
    """Configuration for Megatron DistributedDataParallel.
@@ -249,6 +291,7 @@ class ModelTrainEvalConfig:
    )
    megatron: MegatronConfig = dataclasses.field(default_factory=MegatronConfig)
    vllm: vLLMConfig = dataclasses.field(default_factory=vLLMConfig)
    sglang: SGLangConfig = dataclasses.field(default_factory=SGLangConfig)
    init_from_scratch: bool = False
    init_critic_from_actor: bool = False
@@ -80,6 +80,7 @@ TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
DATASET_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets"
PROFILER_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler"
PARAM_REALLOC_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc"
SGLANG_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang"
TORCH_EXTENSIONS_DIR = (
    f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions"
)
@@ -165,6 +166,7 @@ os.makedirs(DATASET_CACHE_PATH, exist_ok=True)
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True)
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True)
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)

# _model_name will be changed in the model_scope context manager
_model_name: "ModelName" = None
@@ -451,6 +453,10 @@ def prev_pipe_stage():
    ) % pipe_parallel_world_size()


def is_dp_head():
    return is_last_pipe_stage() and model_parallel_rank() == 0


def model_parallel_rank() -> int:
    """Return the rank inside the tensor parallelism group."""
    try:
@@ -6,12 +6,15 @@ import socket
from contextlib import closing


-def find_free_port():
+def find_free_port(low=1, high=65536):
    """From stackoverflow Issue 1365265."""
-    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
-        s.bind(("", 0))
-        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        return s.getsockname()[1]
+    while True:
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+            s.bind(("", 0))
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            port = s.getsockname()[1]
+            if low <= port <= high:
+                return port
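
A tiny usage sketch of the new signature (assuming the helper lives in realhf.base.network, which is how the SGLang backend later in this diff imports it): the loop simply re-binds until the OS hands back a port inside [low, high].

from realhf.base.network import find_free_port

port = find_free_port()                             # any free port
high_port = find_free_port(low=20000, high=65535)   # retry until in range
assert 20000 <= high_port <= 65535
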
def gethostname():
@@ -8,7 +8,7 @@ from typing import List
from packaging.version import Version

from realhf.api.quickstart.device_mesh import RPCAllocation
-from realhf.api.quickstart.model import ModelTrainEvalConfig, vLLMConfig
+from realhf.api.quickstart.model import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
from realhf.base import logging

logger = logging.getLogger(__name__)
@@ -35,6 +35,18 @@ def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation
    )


def check_valid_sglang(
    role: str, sglang: SGLangConfig, rpc_allocs: List[RPCAllocation]
):
    rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.role == role]
    if sglang.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
        logger.warning(
            "SGLang hybrid_train is enabled, but no training RPCs are found."
        )
    if sglang.hybrid_train and not sglang.disable_cuda_graph:
        raise ValueError("SGLang hybrid_train requires CUDA graph to be disabled.")
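
A minimal sketch of what check_valid_sglang enforces. The role string and the empty allocation list are placeholders; in the experiment code the allocations come from the device mesh.

from realhf.api.quickstart.model import SGLangConfig
from realhf.experiments.common.check import check_valid_sglang

rpc_allocs = []  # placeholder; normally a List[RPCAllocation]

# hybrid_train with CUDA graph enabled is rejected outright.
try:
    check_valid_sglang("actor", SGLangConfig(hybrid_train=True), rpc_allocs)
except ValueError as e:
    print(e)  # SGLang hybrid_train requires CUDA graph to be disabled.

# hybrid_train with CUDA graph disabled passes (only a warning is logged
# when no training RPCs are found).
check_valid_sglang(
    "actor", SGLangConfig(hybrid_train=True, disable_cuda_graph=True), rpc_allocs
)
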

def check_valid_optimizer(model: ModelTrainEvalConfig):
    if model.optimizer.min_lr_ratio < 0.0 or model.optimizer.min_lr_ratio > 1.0:
        raise ValueError(f"Invalid min_lr_ratio: {model.optimizer.min_lr_ratio}")
@@ -52,6 +52,7 @@ from realhf.experiments.common.check import (
    check_valid_model_and_path,
    check_valid_optimizer,
    check_valid_parallel_batch_size,
    check_valid_sglang,
    check_valid_vllm,
)
from realhf.experiments.common.utils import (
@@ -596,6 +597,8 @@ class CommonExperimentConfig(Experiment):

        if gen_backend_name == "vllm":
            check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
        elif gen_backend_name == "sglang":
            check_valid_sglang(model_name.role, model_cfg.sglang, rpc_allocs)

        shard_idx = shard_counter[model_name]
        dict_args: Dict[str, Any] = asdict(backend_cfg)
@@ -693,6 +696,9 @@ class CommonExperimentConfig(Experiment):
            rpc.is_generate() for rpc in rpcs
        ):
            assert len(rpcs) == 1 and rpcs[0].is_generate(), rpcs
            assert (
                not model_cfg.sglang.hybrid_train
            ), "vLLM and SGLang cannot be enabled at the same time"
            dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
            check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
            backend = ModelBackendAbstraction(
@@ -702,6 +708,12 @@ class CommonExperimentConfig(Experiment):
                    **dict_args,
                ),
            )
        elif model_cfg.sglang.hybrid_train and any(
            rpc.is_generate() for rpc in rpcs
        ):
            raise NotImplementedError(
                "SGLang hybrid_train=True is not supported yet."
            )
        else:
            backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
        if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [
@@ -138,7 +138,7 @@ def resolve_replica_ids(
        if not same_alloc or (
            alloc.rpc.is_generate()
            and main_alloc.rpc.is_train()
-            and (models[role].vllm.hybrid_train)
+            and (models[role].vllm.hybrid_train or models[role].sglang.hybrid_train)
        ):
            alloc.rpc.model_name = ModelName(role, i)
            i += 1
@@ -0,0 +1,452 @@
# Copyright 2025 Ant Group Inc.

import asyncio
import dataclasses
import json
import os
import sys
import time
import traceback
from importlib.metadata import version
from typing import Dict, List, Tuple

import aiohttp
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from packaging.version import Version
from tqdm.asyncio import tqdm

from realhf.api.core import data_api
from realhf.api.core.model_api import (
    APIGenerateInput,
    APIGenerateOutput,
    FinetuneSpec,
    GenerationHyperparameters,
    LLMAPIClient,
    Model,
    ModelBackend,
    PipelinableEngine,
    register_backend,
)
from realhf.api.quickstart.model import SGLangConfig
from realhf.base import cluster, constants, gpu_utils, logging, network, seeding

logger = logging.getLogger("SGLang backend")


def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text


class SGLangAPIClient(LLMAPIClient):

    async def _do_generate(
        self, req: APIGenerateInput, stream: bool = False
    ) -> APIGenerateOutput:
        gconfig = req.gconfig
        sample_params = {
            "n": gconfig.n,
            "top_p": gconfig.top_p,
            "top_k": gconfig.top_k,
            "max_new_tokens": gconfig.max_new_tokens,
            "temperature": 0.0 if gconfig.greedy else gconfig.temperature,
            "stop_token_ids": req.stop_token_ids,
        }
        payload = {
            "input_ids": req.prompt_ids,
            "sampling_params": sample_params,
            "return_logprob": req.return_logprob,
            "stream": stream,
        }

        output = APIGenerateOutput.from_input(req)

        # The following code is partially adapted from sglang/bench_serving.py
        output_ids = []
        output_logprobs = []
        finish_reason = {}
        ttft = 0.0
        latency = float("inf")
        st = time.perf_counter()
        most_recent_timestamp = st
        timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=None)
        try:
            async with self.session.post(
                url=self.generate_url,
                json=payload,
                timeout=timeout,
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion APIs may send a final
                            # usage-summary chunk without a token, so we check
                            # that a token was actually generated.
                            if data["token_ids"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                output_ids = data["token_ids"]
                                finish_reason = data["meta_info"]["finish_reason"]
                                output_logprobs = data["meta_info"][
                                    "output_token_logprobs"
                                ]

                    assert finish_reason["type"] in ["length", "stop"], finish_reason
                    output.output_logprobs = [x[0] for x in output_logprobs]
                    output.output_ids = output_ids
                    output.no_eos = finish_reason["type"] == "length"
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception as e:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))
            raise RuntimeError(
                f"SGLang generation request failed:\n{output.error}"
            ) from e

        return output

    async def async_update_weights_from_disk(self, path):
        timeout = aiohttp.ClientTimeout(total=300, connect=30, sock_read=None)
        async with self.session.post(
            url=self.update_weights_url,
            json=dict(model_path=path),
            timeout=timeout,
        ) as response:
            if response.status != 200:
                raise RuntimeError("Update weights failed.")
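
For reference, a sketch of the request body that _do_generate posts to the generate URL and of the fields it reads from each streamed chunk. The shapes mirror the parsing code above; all concrete values are illustrative.

request_payload = {
    "input_ids": [151644, 872, 198],       # prompt token ids (example)
    "sampling_params": {
        "n": 1,
        "top_p": 1.0,
        "top_k": 100000000,
        "max_new_tokens": 32,
        "temperature": 0.0,                # greedy
        "stop_token_ids": [151643],
    },
    "return_logprob": True,
    "stream": False,
}

# Every non-"[DONE]" chunk is JSON-decoded; the parser above reads:
chunk = {
    "token_ids": [312, 88, 1024],          # all output ids generated so far
    "meta_info": {
        "finish_reason": {"type": "stop"}, # "length" maps to no_eos=True
        # Only element [0] of each entry (the logprob) is kept.
        "output_token_logprobs": [[-0.11, 312], [-0.30, 88], [-0.05, 1024]],
    },
}
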
def sglang_server_process(server_args_dict):

    from sglang.srt.server_args import ServerArgs
    from sglang.srt.utils import kill_process_tree

    sglang_version = version("sglang")
    if Version(sglang_version) < Version("0.4.3"):
        from sglang.srt.server import launch_server

        server_args_dict.pop("enable_nccl_nvls")
        server_args_dict.pop("triton_attention_num_kv_splits")
        server_args_dict.pop("cuda_graph_bs")
        server_args_dict.pop("enable_memory_saver")
        server_args_dict.pop("allow_auto_truncate")
        server_args_dict.pop("return_hidden_states")
    else:
        from sglang.srt.entrypoints.http_server import launch_server

    server_args = ServerArgs(**server_args_dict)

    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        map(str, list(range(gpu_utils.gpu_count())))
    )

    try:
        launch_server(server_args)
    finally:
        kill_process_tree(os.getpid(), include_parent=False)


class SGLangGenerationEngine(PipelinableEngine):

    def __init__(
        self,
        server_args_dict: Dict,
        hybrid_train: bool,
        request_timeout: int = 1800,
    ):
        if constants.model_parallel_rank() != 0:
            dist.barrier(group=constants.model_parallel_group())
            return
        # Start the serving process
        self.server_proc = mp.Process(
            target=sglang_server_process,
            args=(server_args_dict,),
        )
        self.server_proc.start()

        self.base_url = f"http://{server_args_dict['host']}:{server_args_dict['port']}"

        self.api_urls = {
            "generate": f"{self.base_url}/generate",
            "offload_weights": f"{self.base_url}/offload_weights",
            "init_kv_cache": f"{self.base_url}/init_kv_cache",
            "clear_kv_cache": f"{self.base_url}/clear_kv_cache",
            "init_model_weights": f"{self.base_url}/init_model_weights",
            "update_weights_from_disk": f"{self.base_url}/update_weights_from_disk",
        }

        asyncio.run(self.wait_server())

        self.request_timeout = request_timeout

        # offload weights/cache
        self.hybrid_train = hybrid_train

        dist.barrier(group=constants.model_parallel_group())

    def __del__(self):
        if hasattr(self, "server_proc"):
            from sglang.srt.utils import kill_process_tree

            self.server_proc.terminate()

            kill_process_tree(os.getpid())

    # NOTE: A placeholder function.
    def train(self, mode: bool = True):
        return self

    # NOTE: A placeholder function.
    def eval(self):
        return self

    async def wait_server(self):
        # Wait until the server is launched
        from sglang.srt.utils import kill_process_tree
        from sglang.utils import get_exception_traceback

        success = False
        for _ in range(120):
            await asyncio.sleep(1)
            try:
                res = requests.get(
                    self.base_url + "/get_model_info", timeout=5, headers={}
                )
                assert res.status_code == 200, f"{res=}, {res.text=}"
                success = True
                break
            except (AssertionError, requests.exceptions.RequestException):
                last_traceback = get_exception_traceback()
                pass
        if not success:
            logger.error(f"Initialization failed. warmup error: {last_traceback}")
            kill_process_tree(os.getpid())
            return

    async def async_generate(
        self,
        input_: data_api.SequenceSample,
        mb_spec: data_api.MicroBatchSpec,
        tokenizer: transformers.PreTrainedTokenizerFast,
        gconfig: GenerationHyperparameters = dataclasses.field(
            default_factory=GenerationHyperparameters
        ),
        stream: bool = False,
        disable_tqdm: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | None:

        pbar = None if disable_tqdm else tqdm(total=input_.bs * gconfig.n)

        async with SGLangAPIClient(
            generate_url=self.api_urls["generate"],
            update_weights_url=self.api_urls["update_weights_from_disk"],
        ) as client:
            tasks = []
            input_queries = []
            for d in input_.unpack():
                if len(d.seqlens["packed_input_ids"]) > 1:
                    raise RuntimeError(
                        f"sglang backend does not support grouped generation "
                        f"for now. Group size {len(d.seqlens['packed_input_ids'])}."
                    )

                prompt_token_ids = d.data["packed_input_ids"].cpu().numpy().tolist()
                qid = d.ids[0]
                for group_idx in range(gconfig.n):
                    req = APIGenerateInput(
                        qid=qid,
                        group_idx=group_idx,
                        prompt_ids=prompt_token_ids,
                        input_ids=prompt_token_ids,
                        gconfig=gconfig.new(n=1),
                        stop_token_ids=[tokenizer.pad_token_id, tokenizer.eos_token_id],
                        return_logprob=True,
                    )
                    input_queries.append((qid, group_idx))
                    tasks.append(
                        client.async_add_generate_request(
                            req,
                            stream=stream,
                        )
                    )

            outputs = {}
            for r in asyncio.as_completed(tasks):
                out = await r
                outputs[(out.qid, out.group_idx)] = out
                if pbar:
                    pbar.update(1)

            if pbar is not None:
                pbar.close()

            results: List[APIGenerateOutput] = [outputs[key] for key in input_queries]

        # Build the output: generated token ids, generated token scores,
        # and logits mask (which will always be None in sglang).
        batch_token_ids = []
        batch_logprobs = []
        max_seqlen = -1
        for x in results:
            max_seqlen = max(max_seqlen, len(x.output_ids))
            batch_token_ids.append(x.output_ids)
            batch_logprobs.append(x.output_logprobs)

        # To be consistent with our internal implementation,
        # we should pad generated tokens and logprobs
        batch_token_ids = [
            t + [tokenizer.pad_token_id] * (max_seqlen - len(t))
            for t in batch_token_ids
        ]
        batch_logprobs = [p + [0.0] * (max_seqlen - len(p)) for p in batch_logprobs]

        return (
            torch.tensor(
                batch_token_ids, dtype=torch.long, device=constants.current_device()
            ),
            torch.tensor(
                batch_logprobs, dtype=torch.float32, device=constants.current_device()
            ),
            None,
        )

    def generate(
        self,
        input_: data_api.SequenceSample,
        mb_spec: data_api.MicroBatchSpec,
        tokenizer: transformers.PreTrainedTokenizerFast,
        gconfig: GenerationHyperparameters = dataclasses.field(
            default_factory=GenerationHyperparameters
        ),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | None:
        if gconfig.min_new_tokens != 0:
            raise RuntimeError(
                "NOTE: passing in an arbitrary `min_new_tokens` will lead to a bug for SGLang v0.4.3 "
                "because we force `skip_tokenizer_init`."
            )
        if constants.model_parallel_rank() != 0:
            dist.barrier(group=constants.model_parallel_group())
            return None, None, None

        results = asyncio.run(
            self.async_generate(
                input_=input_,
                mb_spec=mb_spec,
                tokenizer=tokenizer,
                gconfig=gconfig,
            )
        )
        dist.barrier(group=constants.model_parallel_group())
        return results

    def update_weights_from_disk(self, path):
        if constants.model_parallel_rank() != 0:
            dist.barrier(group=constants.model_parallel_group())
            return

        async def _fn():
            async with SGLangAPIClient(
                generate_url=self.api_urls["generate"],
                update_weights_url=self.api_urls["update_weights_from_disk"],
            ) as client:
                await client.async_update_weights_from_disk(path)

        asyncio.run(_fn())
        dist.barrier(group=constants.model_parallel_group())
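
To make the calling convention explicit, a hedged sketch of driving the engine from tensor-parallel rank 0 (ranks > 0 only join the barrier inside generate and receive (None, None, None)). The SequenceSample construction follows the unit test at the end of this diff; the prompt ids are placeholders.

import torch

from realhf.api.core import data_api, model_api


def run_generation(engine: SGLangGenerationEngine, tokenizer, prompt_ids):
    x = data_api.SequenceSample.from_default(
        seqlens=[len(prompt_ids)],
        ids=[0],
        data=dict(packed_input_ids=torch.tensor(prompt_ids, device="cuda")),
    )
    gconfig = model_api.GenerationHyperparameters(
        n=1,
        max_new_tokens=32,
        min_new_tokens=0,  # non-zero values are rejected above
        greedy=True,
        top_p=1.0,
        top_k=int(1e8),
        temperature=1.0,
        use_cuda_graph=False,
    )
    tokens, logprobs, _ = engine.generate(
        input_=x,
        mb_spec=data_api.MicroBatchSpec(),
        tokenizer=tokenizer,
        gconfig=gconfig,
    )
    return tokens, logprobs
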
@dataclasses.dataclass
class SGLangGenerationBackend(ModelBackend, SGLangConfig):
    model_path: str = ""
    dtype: str = "float16"

    def _initialize(self, model: Model, spec: FinetuneSpec) -> Model:
        if constants.pipe_parallel_world_size() != 1:
            raise RuntimeError("SGLang does not support pipe parallel size > 1.")
        if constants.model_parallel_world_size() > cluster.spec.n_gpus_per_node:
            raise RuntimeError(
                "AReaL's SGLang integration does not support model parallel size > n_gpus_per_node."
            )

        additional_args = dataclasses.asdict(self)
        additional_args.pop("hybrid_train")
        additional_args["random_seed"] = seeding.get_seed()

        # For simplicity, we let all DP ranks have different ports.
        ports = [None for _ in range(constants.data_parallel_world_size())]
        while any(port is None for port in ports) or len(set(ports)) != len(ports):
            dist.all_gather_object(
                ports, network.find_free_port(), group=constants.data_parallel_group()
            )
        additional_args["port"] = ports[constants.data_parallel_rank()]

        server_args_dict = dict(
            host="localhost",
            # Model and tokenizer
            tokenizer_path=self.model_path,
            tokenizer_mode="auto",
            load_format="auto",
            trust_remote_code=True,
            kv_cache_dtype="auto",
            device="cuda",
            served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{constants.model_name().role}",
            is_embedding=False,
            skip_tokenizer_init=True,
            # Other runtime options
            tp_size=constants.model_parallel_world_size(),
            # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
            base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
            file_storage_pth=os.path.join(
                constants.SGLANG_CACHE_PATH,
                f"sglang_storage{constants.data_parallel_rank()}",
            ),
            # Data parallelism
            dp_size=1,  # TODO: check whether we require SGLang dp
            load_balance_method="round_robin",
            # Expert parallelism
            ep_size=1,  # TODO: check
            # Multi-node distributed serving
            dist_init_addr=None,
            nnodes=1,
            node_rank=0,
            **additional_args,
        )

        model.module = SGLangGenerationEngine(
            server_args_dict,
            hybrid_train=self.hybrid_train,
        )
        model.backend_name = "sglang"
        return model

    def load(self, model: Model, load_dir: str):
        model.module.update_weights_from_disk(load_dir)


register_backend("sglang", SGLangGenerationBackend)
@@ -0,0 +1,242 @@
# Copyright 2025 Ant Group Inc.

import functools
import multiprocessing as mp
import os
import uuid

import numpy as np
import torch
import torch.distributed as dist
import transformers

from realhf.api.core import data_api, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.data_api import MicroBatchSpec
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.testing import init_global_constants

logger = logging.getLogger("test sglang backend")


def check_sequences_consistency(
    batched_seq1: torch.LongTensor, batched_seq2: torch.LongTensor
):
    matched_tokens = 0
    matched_seqs = 0
    total_tokens = 0
    assert len(batched_seq1) == len(batched_seq2)
    for i in range(len(batched_seq1)):
        a = batched_seq1[i]
        b = batched_seq2[i]
        assert torch.is_tensor(a) and torch.is_tensor(b)
        assert a.dim() == 1 and b.dim() == 1, (a.shape, b.shape)
        gen_len = a.shape[0] if a.shape[0] < b.shape[0] else b.shape[0]
        b = b[:gen_len]
        a = a[:gen_len]
        for j in range(gen_len):
            if a[j] != b[j]:
                logger.info(f"Mismatch at sequence {i} position {j}")
                break
            matched_tokens += 1
        else:
            matched_seqs += 1
        total_tokens += gen_len
    logger.info(
        f"Matched {matched_seqs}/{len(batched_seq1)} "
        f"sequences and {matched_tokens}/{total_tokens} tokens"
    )
    return (
        matched_seqs,
        matched_tokens,
        float(matched_tokens) / total_tokens,
        float(matched_seqs) / len(batched_seq1),
    )
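
For example, on two toy pairs of sequences (assuming the function above is in scope), the helper reports per-token and per-sequence agreement:

import torch

seqs_a = [torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7])]
seqs_b = [torch.tensor([1, 2, 9, 4]), torch.tensor([5, 6, 7, 8])]

matched_seqs, matched_tokens, token_ratio, seq_ratio = check_sequences_consistency(
    seqs_a, seqs_b
)
# Sequence 0 diverges at position 2; sequence 1 matches over the compared prefix:
# matched_seqs == 1, matched_tokens == 5, token_ratio == 5 / 7, seq_ratio == 0.5
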
def test_fn(
    rank: int,
    world_size: int,
    path: str,
    model_family_name: str,
    dp: int,
    pp: int,
    tp: int,
):
    assert not torch.cuda.is_initialized()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    torch.cuda.set_device(0)
    assert world_size == (
        dp * pp * tp
    ), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
    # Initialize distributed environment.
    dist.init_process_group(
        "nccl",
        rank=rank,
        world_size=world_size,
        init_method="tcp://localhost:7777",
    )
    torch.cuda.set_device(0)
    model_name = ModelName("default", 0)
    constants.set_experiment_trial_names("sglang-test", str(uuid.uuid4()))
    init_global_constants(
        num_dp=dp,
        num_mp=tp,
        num_pp=pp,
        sequence_parallel=False,
        model_name=model_name,
        max_prompt_len=128,
    )

    from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions

    mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
        transformers.AutoConfig.from_pretrained(
            path,
            trust_remote_code=True,
            force_download=True,
        )
    )
    with constants.model_scope(model_name):
        module = ReaLModel(mconfig, dtype=torch.float16, device="cuda")
        module._instantiation_hooks.append(
            lambda: getattr(module, f"from_{model_family_name}")(
                load_dir=path, init_critic_from_actor=False
            )
        )
        add_helper_functions(module)
        module.instantiate()
        module.eval()
        tokenizer = data_api.load_hf_tokenizer(path)

        from realhf.impl.model.backend.sglang import SGLangGenerationBackend

        backend = SGLangGenerationBackend(model_path=path)
        model = model_api.Model(
            name=model_name,
            module=module,
            tokenizer=tokenizer,
            device=module.device,
            dtype=module.dtype,
        )
        ft_spec = model_api.FinetuneSpec(
            total_train_epochs=1,
            dataset_size=100,
            train_batch_size=1,
        )
        model = backend.initialize(model, ft_spec)

        gconfig = model_api.GenerationHyperparameters(
            n=1,
            max_new_tokens=32,
            min_new_tokens=0,
            greedy=True,
            top_p=1.0,
            top_k=int(1e8),
            temperature=1.0,
            use_cuda_graph=False,
        )

        bs = 8
        for i in range(1):
            seqlens = [torch.randint(5, 10, (1,)).cuda() for _ in range(bs)]

            for s in seqlens:
                dist.broadcast(s, src=0)
            seqlens = [int(s) for s in seqlens]

            token_ids = (
                torch.randint(0, mconfig.vocab_size, (sum(seqlens),)).long().cuda()
            )
            dist.broadcast(token_ids, src=0)

            max_seqlen = max(seqlens)
            cu_seqlens = torch.nn.functional.pad(
                torch.tensor(seqlens, device="cuda").cumsum(0),
                (1, 0),
            ).int()

            res = module.generate(
                tokenizer=tokenizer,
                packed_input_ids=token_ids,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                gconfig=gconfig,
            )
            gen_tokens1 = res.sequences
            logprobs1 = res.scores

            x = data_api.SequenceSample.from_default(
                seqlens=seqlens,
                ids=list(range(bs)),
                data=dict(packed_input_ids=token_ids),
            )
            gen_tokens2, logprobs2, _ = model.module.generate(
                input_=x,
                mb_spec=MicroBatchSpec(),
                tokenizer=tokenizer,
                gconfig=gconfig,
            )
            if constants.model_parallel_rank() == 0:
                # The outputs are None for tp_rank > 0 in SGLang
                _, _, token_match_percent, seq_match_percent = (
                    check_sequences_consistency(gen_tokens1, gen_tokens2)
                )
                assert token_match_percent > 0.8, token_match_percent
                assert seq_match_percent > 0.8, seq_match_percent

        print("success")

    # Clean up the distributed environment.
    dist.destroy_process_group()


def test_sglang_consistency(tp: int, dp: int, path: str, model_family_name: str):
    mp.set_start_method("spawn", force=True)
    world_size = dp * tp
    procs = [
        mp.Process(
            target=test_fn,
            args=(
                i,
                world_size,
            ),
            kwargs=dict(
                path=path,
                model_family_name=model_family_name,
                dp=dp,
                pp=1,
                tp=tp,
            ),
        )
        for i in range(world_size)
    ]
    try:
        for p in procs:
            p.start()
        for p in procs:
            p.join()
    except KeyboardInterrupt:
        [p.terminate() for p in procs]
        [p.join() for p in procs]


if __name__ == "__main__":
    path = "/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
    model_family_name = "qwen2"
    # test_fn(
    #     rank=0,
    #     world_size=1,
    #     path=path,
    #     model_family_name=model_family_name,
    #     dp=1,
    #     pp=1,
    #     tp=1,
    # )
    test_sglang_consistency(
        tp=2,
        dp=2,
        path=path,
        model_family_name=model_family_name,
    )