PullRequest: 38 Support SGLang generation based on the 24.07 docker image

Merge branch fw/sglang of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/38?tab=comment#note_168700036

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* format
* .
* .
* cleanup
博惟 2025-03-17 15:45:00 +08:00
parent 7b3e33430a
commit 767bb7bf47
9 changed files with 912 additions and 7 deletions
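A minimal usage sketch (not part of the changed files): it shows how the request/response dataclasses and the API client added in this change are intended to fit together. The server URLs, prompt token ids, and qid are placeholders, and GenerationHyperparameters fields not listed keep their defaults.

import asyncio

from realhf.api.core.model_api import APIGenerateInput, GenerationHyperparameters
from realhf.impl.model.backend.sglang import SGLangAPIClient


async def demo():
    # A single-sample request; the token ids here are placeholders.
    gconfig = GenerationHyperparameters(n=1, max_new_tokens=32, greedy=True)
    req = APIGenerateInput(
        qid="q0",
        group_idx=0,
        prompt_ids=[1, 2, 3],
        input_ids=[1, 2, 3],
        gconfig=gconfig,
        return_logprob=True,
    )
    # The client is an async context manager that owns the aiohttp session.
    async with SGLangAPIClient(
        generate_url="http://localhost:30000/generate",  # placeholder address
        update_weights_url="http://localhost:30000/update_weights_from_disk",
    ) as client:
        out = await client.async_add_generate_request(req, stream=True)
        print(out.output_ids, out.output_logprobs, out.no_eos)


asyncio.run(demo())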

View File

@@ -3,10 +3,12 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import abc
import asyncio
import dataclasses
import keyword
from typing import *
import aiohttp
import numpy as np
import torch
import torch.utils.data
@@ -106,6 +108,139 @@ class GenerationHyperparameters:
f"To use CUDAGraph, ReaL's PyTorch version should be at least 2.3.0."
)
def new(self, **kwargs):
args = dataclasses.asdict(self)
args.update(kwargs)
return GenerationHyperparameters(**args)
@dataclasses.dataclass
class APIGenerateInput:
qid: Hashable
group_idx: int
prompt_ids: List[int]
input_ids: List[int]
gconfig: GenerationHyperparameters
stop_token_ids: List[int] = dataclasses.field(default_factory=list)
return_logprob: bool = False
@dataclasses.dataclass
class APIGenerateOutput:
qid: Hashable
group_idx: int
prompt_ids: List[int]
input_ids: List[int]
output_ids: List[int] = dataclasses.field(default_factory=list)
output_logprobs: List[float] = dataclasses.field(default_factory=list)
no_eos: bool = True
success: bool = False
latency: float = 0.0
ttft: float = 0.0 # Time to first token
itl: List[float] = dataclasses.field(
default_factory=list
) # List of inter-token latencies
error: str = ""
@classmethod
def from_input(cls, inp: APIGenerateInput):
return cls(
qid=inp.qid,
group_idx=inp.group_idx,
prompt_ids=inp.prompt_ids,
input_ids=inp.input_ids,
)
@property
def output_len(self):
return len(self.output_ids)
@property
def input_len(self):
return len(self.input_ids)
@property
def prompt_len(self):
return len(self.prompt_ids)
@property
def gen_len(self):
return self.output_len + self.input_len - self.prompt_len
@dataclasses.dataclass
class BundledGenerationOutputs:
qid: Hashable
prompt_ids: List[int]
seqs: List[List[int]]
no_eos: List[bool]
@classmethod
def from_single(cls, outputs: List[APIGenerateOutput]):
assert len(set(o.qid for o in outputs)) == 1
return cls(
qid=outputs[0].qid,
prompt_ids=outputs[0].prompt_ids,
seqs=[o.input_ids + o.output_ids for o in outputs],
no_eos=[o.no_eos for o in outputs],
)
@property
def seqlens(self):
return [len(seq) for seq in self.seqs]
@property
def prompt_len(self):
return len(self.prompt_ids)
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
class LLMAPIClient:
def __init__(
self, generate_url: str, update_weights_url: str, concurrency_limit: int = -1
):
self.update_weights_url = update_weights_url
self.generate_url = generate_url
self.concurrency_limit = concurrency_limit
self.session: aiohttp.ClientSession
self.semaphore: asyncio.Semaphore
async def __aenter__(self):
conn = aiohttp.TCPConnector(limit=0, ttl_dns_cache=300)
self.session = aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT,
connector=conn,
read_bufsize=1024 * 1024 * 10,
)
if self.concurrency_limit > 0:
self.semaphore = asyncio.Semaphore(self.concurrency_limit)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.session:
await self.session.close()
async def async_add_generate_request(
self, req: APIGenerateInput, stream: bool = True
) -> APIGenerateOutput:
if self.concurrency_limit > 0:
async with self.semaphore:
return await self._do_generate(req, stream=stream)
else:
return await self._do_generate(req, stream=stream)
async def _do_generate(
self, req: APIGenerateInput, stream: bool = True
) -> APIGenerateOutput:
raise NotImplementedError()
async def async_update_weights_from_disk(self, path):
raise NotImplementedError()
@dataclasses.dataclass
class ReaLMoEConfig:

View File

@@ -131,6 +131,48 @@ class vLLMConfig:
additional_engine_args: Dict = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class SGLangConfig:
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_jump_forward: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_mla: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
torchao_config: str = ""
enable_nan_detection: bool = False
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
triton_attention_num_kv_splits: int = 8
num_continuous_decode_steps: int = 1
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
return_hidden_states: bool = False
# NOTE: default to the Triton attention backend to avoid an illegal memory access error
attention_backend: Optional[str] = "triton"
sampling_backend: Optional[str] = None
context_length: Optional[int] = None
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
hybrid_train: bool = False
@dataclasses.dataclass
class DistributedDataParallelConfig:
"""Configuration for Megatron DistributedDataParallel.
@@ -249,6 +291,7 @@ class ModelTrainEvalConfig:
)
megatron: MegatronConfig = dataclasses.field(default_factory=MegatronConfig)
vllm: vLLMConfig = dataclasses.field(default_factory=vLLMConfig)
sglang: SGLangConfig = dataclasses.field(default_factory=SGLangConfig)
init_from_scratch: bool = False
init_critic_from_actor: bool = False

View File

@@ -80,6 +80,7 @@ TRITON_CACHE_PATH = f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/triton"
DATASET_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/datasets"
PROFILER_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/profiler"
PARAM_REALLOC_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/param_realloc"
SGLANG_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/sglang"
TORCH_EXTENSIONS_DIR = (
f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/torch/extensions"
)
@@ -165,6 +166,7 @@ os.makedirs(DATASET_CACHE_PATH, exist_ok=True)
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True)
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True)
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)
# _model_name will be changed in the model_scope context manager
_model_name: "ModelName" = None
@@ -451,6 +453,10 @@ def prev_pipe_stage():
) % pipe_parallel_world_size()
def is_dp_head():
return is_last_pipe_stage() and model_parallel_rank() == 0
def model_parallel_rank() -> int:
"""Return the rank inside the tensor parallelism group."""
try:

View File

@@ -6,12 +6,15 @@ import socket
from contextlib import closing
def find_free_port():
def find_free_port(low=1, high=65536):
"""From stackoverflow Issue 1365265."""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
while True:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]
if low <= port <= high:
return port
def gethostname():

View File

@@ -8,7 +8,7 @@ from typing import List
from packaging.version import Version
from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.api.quickstart.model import ModelTrainEvalConfig, vLLMConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig, SGLangConfig, vLLMConfig
from realhf.base import logging
logger = logging.getLogger(__name__)
@@ -35,6 +35,18 @@ def check_valid_vllm(role: str, vllm: vLLMConfig, rpc_allocs: List[RPCAllocation
)
def check_valid_sglang(
role: str, sglang: SGLangConfig, rpc_allocs: List[RPCAllocation]
):
rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.role == role]
if sglang.hybrid_train and not any(rpc.is_train() for rpc in rpcs):
logger.warning(
"SGLang hybrid_train is enabled, but no training RPCs are found."
)
if sglang.hybrid_train and not sglang.disable_cuda_graph:
raise ValueError("SGLang hybrid_train requires CUDA graph to be disabled.")
def check_valid_optimizer(model: ModelTrainEvalConfig):
if model.optimizer.min_lr_ratio < 0.0 or model.optimizer.min_lr_ratio > 1.0:
raise ValueError(f"Invalid min_lr_ratio: {model.optimizer.min_lr_ratio}")

View File

@@ -52,6 +52,7 @@ from realhf.experiments.common.check import (
check_valid_model_and_path,
check_valid_optimizer,
check_valid_parallel_batch_size,
check_valid_sglang,
check_valid_vllm,
)
from realhf.experiments.common.utils import (
@@ -596,6 +597,8 @@ class CommonExperimentConfig(Experiment):
if gen_backend_name == "vllm":
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
elif gen_backend_name == "sglang":
check_valid_sglang(model_name.role, model_cfg.sglang, rpc_allocs)
shard_idx = shard_counter[model_name]
dict_args: Dict[str, Any] = asdict(backend_cfg)
@@ -693,6 +696,9 @@ class CommonExperimentConfig(Experiment):
rpc.is_generate() for rpc in rpcs
):
assert len(rpcs) == 1 and rpcs[0].is_generate(), rpcs
assert (
not model_cfg.sglang.hybrid_train
), "vLLM and SGLang cannot be enabled at the same time"
dict_args: Dict[str, Any] = asdict(model_cfg.vllm)
check_valid_vllm(model_name.role, model_cfg.vllm, rpc_allocs)
backend = ModelBackendAbstraction(
@@ -702,6 +708,12 @@
**dict_args,
),
)
elif model_cfg.sglang.hybrid_train and any(
rpc.is_generate() for rpc in rpcs
):
raise NotImplementedError(
"SGLang hybrid_train=True is not supported yet."
)
else:
backend = make_inf_backend_config(model_cfg, rpc_alloc.parallel)
if any(rpc.is_generate() for rpc in rpcs) and backend.type_ not in [

View File

@@ -138,7 +138,7 @@ def resolve_replica_ids(
if not same_alloc or (
alloc.rpc.is_generate()
and main_alloc.rpc.is_train()
and (models[role].vllm.hybrid_train)
and (models[role].vllm.hybrid_train or models[role].sglang.hybrid_train)
):
alloc.rpc.model_name = ModelName(role, i)
i += 1

View File

@@ -0,0 +1,452 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import dataclasses
import json
import os
import sys
import time
import traceback
from importlib.metadata import version
from typing import Dict, List, Tuple
import aiohttp
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from packaging.version import Version
from tqdm.asyncio import tqdm
from realhf.api.core import data_api
from realhf.api.core.model_api import (
APIGenerateInput,
APIGenerateOutput,
FinetuneSpec,
GenerationHyperparameters,
LLMAPIClient,
Model,
ModelBackend,
PipelinableEngine,
register_backend,
)
from realhf.api.quickstart.model import SGLangConfig
from realhf.base import cluster, constants, gpu_utils, logging, network, seeding
logger = logging.getLogger("SGLang backend")
def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
class SGLangAPIClient(LLMAPIClient):
async def _do_generate(
self, req: APIGenerateInput, stream: bool = False
) -> APIGenerateOutput:
gconfig = req.gconfig
sample_params = {
"n": gconfig.n,
"top_p": gconfig.top_p,
"top_k": gconfig.top_k,
"max_new_tokens": gconfig.max_new_tokens,
"temperature": 0.0 if gconfig.greedy else gconfig.temperature,
"stop_token_ids": req.stop_token_ids,
}
payload = {
"input_ids": req.prompt_ids,
"sampling_params": sample_params,
"return_logprob": req.return_logprob,
"stream": stream,
}
output = APIGenerateOutput.from_input(req)
# The following code is partially adapted from sglang/bench_serving.py
output_ids = []
output_logprobs = []
finish_reason = {}
ttft = 0.0
latency = float("inf")
st = time.perf_counter()
most_recent_timestamp = st
timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=None)
try:
async with self.session.post(
url=self.generate_url,
json=payload,
timeout=timeout,
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
# NOTE: Some completion APIs may send a final usage-summary
# response without a token, so check that a token was
# actually generated.
if data["token_ids"]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
output_ids = data["token_ids"]
finish_reason = data["meta_info"]["finish_reason"]
output_logprobs = data["meta_info"][
"output_token_logprobs"
]
assert finish_reason["type"] in ["length", "stop"], finish_reason
output.output_logprobs = [x[0] for x in output_logprobs]
output.output_ids = output_ids
output.no_eos = finish_reason["type"] == "length"
output.success = True
output.latency = latency
else:
output.error = response.reason or ""
output.success = False
except Exception as e:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
raise RuntimeError(
f"SGLang generation request fails:\n{output.error}"
) from e
return output
async def async_update_weights_from_disk(self, path):
timeout = aiohttp.ClientTimeout(total=300, connect=30, sock_read=None)
async with self.session.post(
url=self.update_weights_url,
json=dict(model_path=path),
timeout=timeout,
) as response:
if response.status != 200:
raise RuntimeError("Update weights failed.")
def sglang_server_process(server_args_dict):
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
sglang_version = version("sglang")
if Version(sglang_version) < Version("0.4.3"):
from sglang.srt.server import launch_server
server_args_dict.pop("enable_nccl_nvls")
server_args_dict.pop("triton_attention_num_kv_splits")
server_args_dict.pop("cuda_graph_bs")
server_args_dict.pop("enable_memory_saver")
server_args_dict.pop("allow_auto_truncate")
server_args_dict.pop("return_hidden_states")
else:
from sglang.srt.entrypoints.http_server import launch_server
server_args = ServerArgs(**server_args_dict)
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, list(range(gpu_utils.gpu_count())))
)
try:
launch_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)
class SGLangGenerationEngine(PipelinableEngine):
def __init__(
self,
server_args_dict: Dict,
hybrid_train: bool,
request_timeout: int = 1800,
):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_group())
return
# Start the serving process
self.server_proc = mp.Process(
target=sglang_server_process,
args=(server_args_dict,),
)
self.server_proc.start()
self.base_url = f"http://{server_args_dict['host']}:{server_args_dict['port']}"
self.api_urls = {
"generate": f"{self.base_url}/generate",
"offload_weights": f"{self.base_url}/offload_weights",
"init_kv_cache": f"{self.base_url}/init_kv_cache",
"clear_kv_cache": f"{self.base_url}/clear_kv_cache",
"init_model_weights": f"{self.base_url}/init_model_weights",
"update_weights_from_disk": f"{self.base_url}/update_weights_from_disk",
}
asyncio.run(self.wait_server())
self.request_timeout = request_timeout
# offload weights/cache
self.hybrid_train = hybrid_train
dist.barrier(group=constants.model_parallel_group())
def __del__(self):
if hasattr(self, "server_proc"):
from sglang.srt.utils import kill_process_tree
self.server_proc.terminate()
kill_process_tree(os.getpid())
# NOTE: A placeholder function.
def train(self, mode: bool = True):
return self
# NOTE: A placeholder function.
def eval(self):
return self
async def wait_server(self):
# Wait until the server is launched
from sglang.srt.utils import kill_process_tree
from sglang.utils import get_exception_traceback
success = False
for _ in range(120):
await asyncio.sleep(1)
try:
res = requests.get(
self.base_url + "/get_model_info", timeout=5, headers={}
)
assert res.status_code == 200, f"{res=}, {res.text=}"
success = True
break
except (AssertionError, requests.exceptions.RequestException):
last_traceback = get_exception_traceback()
if not success:
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_process_tree(os.getpid())
return
async def async_generate(
self,
input_: data_api.SequenceSample,
mb_spec: data_api.MicroBatchSpec,
tokenizer: transformers.PreTrainedTokenizerFast,
gconfig: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
),
stream: bool = False,
disable_tqdm: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | None:
pbar = None if disable_tqdm else tqdm(total=input_.bs * gconfig.n)
async with SGLangAPIClient(
generate_url=self.api_urls["generate"],
update_weights_url=self.api_urls["update_weights_from_disk"],
) as client:
tasks = []
input_queries = []
for d in input_.unpack():
if len(d.seqlens["packed_input_ids"]) > 1:
raise RuntimeError(
f"sglang backend does not support grouped generation "
f"for now. Group size {len(d.seqlens['packed_input_ids'])}."
)
prompt_token_ids = d.data["packed_input_ids"].cpu().numpy().tolist()
qid = d.ids[0]
for group_idx in range(gconfig.n):
req = APIGenerateInput(
qid=qid,
group_idx=group_idx,
prompt_ids=prompt_token_ids,
input_ids=prompt_token_ids,
gconfig=gconfig.new(n=1),
stop_token_ids=[tokenizer.pad_token_id, tokenizer.eos_token_id],
return_logprob=True,
)
input_queries.append((qid, group_idx))
tasks.append(
client.async_add_generate_request(
req,
stream=stream,
)
)
outputs = {}
for r in asyncio.as_completed(tasks):
out = await r
outputs[(out.qid, out.group_idx)] = out
if pbar:
pbar.update(1)
if pbar is not None:
pbar.close()
results: List[APIGenerateOutput] = [outputs[key] for key in input_queries]
# Build the output: generated token ids, generated token scores,
# and logits mask (which will always be None in sglang).
batch_token_ids = []
batch_logprobs = []
max_seqlen = -1
for x in results:
max_seqlen = max(max_seqlen, len(x.output_ids))
batch_token_ids.append(x.output_ids)
batch_logprobs.append(x.output_logprobs)
# To be consistent with our internal implementation,
# we should pad generated tokens and logprobs
batch_token_ids = [
t + [tokenizer.pad_token_id] * (max_seqlen - len(t))
for t in batch_token_ids
]
batch_logprobs = [p + [0.0] * (max_seqlen - len(p)) for p in batch_logprobs]
return (
torch.tensor(
batch_token_ids, dtype=torch.long, device=constants.current_device()
),
torch.tensor(
batch_logprobs, dtype=torch.float32, device=constants.current_device()
),
None,
)
def generate(
self,
input_: data_api.SequenceSample,
mb_spec: data_api.MicroBatchSpec,
tokenizer: transformers.PreTrainedTokenizerFast,
gconfig: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | None:
if gconfig.min_new_tokens != 0:
raise RuntimeError(
"NOTE: passing in an arbitrary `min_new_tokens` will lead to a bug for SGLang v0.4.3 "
"because we force to skip_tokenizer_init."
)
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_group())
return None, None, None
results = asyncio.run(
self.async_generate(
input_=input_,
mb_spec=mb_spec,
tokenizer=tokenizer,
gconfig=gconfig,
)
)
dist.barrier(group=constants.model_parallel_group())
return results
def update_weights_from_disk(self, path):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_group())
return
async def _fn():
async with SGLangAPIClient(
generate_url=self.api_urls["generate"],
update_weights_url=self.api_urls["update_weights_from_disk"],
) as client:
await client.async_update_weights_from_disk(path)
asyncio.run(_fn())
dist.barrier(group=constants.model_parallel_group())
@dataclasses.dataclass
class SGLangGenerationBackend(ModelBackend, SGLangConfig):
model_path: str = ""
dtype: str = "float16"
def _initialize(self, model: Model, spec: FinetuneSpec) -> Model:
if constants.pipe_parallel_world_size() != 1:
raise RuntimeError("SGLang does not support pipe parallel size > 1.")
if constants.model_parallel_world_size() > cluster.spec.n_gpus_per_node:
raise RuntimeError(
"AReaL's SGLang integration does not support model parallel size > n_gpus_per_node."
)
additional_args = dataclasses.asdict(self)
additional_args.pop("hybrid_train")
additional_args["random_seed"] = seeding.get_seed()
# For simplicity, we let all DP ranks have different ports.
ports = [None for _ in range(constants.data_parallel_world_size())]
while any(port is None for port in ports) or len(set(ports)) != len(ports):
dist.all_gather_object(
ports, network.find_free_port(), group=constants.data_parallel_group()
)
additional_args["port"] = ports[constants.data_parallel_rank()]
server_args_dict = dict(
host="localhost",
# Model and tokenizer
tokenizer_path=self.model_path,
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
kv_cache_dtype="auto",
device="cuda",
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{constants.model_name().role}",
is_embedding=False,
skip_tokenizer_init=True,
# Other runtime options
tp_size=constants.model_parallel_world_size(),
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
file_storage_pth=os.path.join(
constants.SGLANG_CACHE_PATH,
f"sglang_storage{constants.data_parallel_rank()}",
),
# Data parallelism
dp_size=1, # TODO: check whether we require SGLang dp
load_balance_method="round_robin",
# Expert parallelism
ep_size=1, # TODO: check
# Multi-node distributed serving
dist_init_addr=None,
nnodes=1,
node_rank=0,
**additional_args,
)
model.module = SGLangGenerationEngine(
server_args_dict,
hybrid_train=self.hybrid_train,
)
model.backend_name = "sglang"
return model
def load(self, model: Model, load_dir: str):
model.module.update_weights_from_disk(load_dir)
register_backend("sglang", SGLangGenerationBackend)

View File

@@ -0,0 +1,242 @@
# Copyright 2025 Ant Group Inc.
import functools
import multiprocessing as mp
import os
import uuid
import numpy as np
import torch
import torch.distributed as dist
import transformers
from realhf.api.core import data_api, model_api
from realhf.api.core.config import ModelName
from realhf.api.core.data_api import MicroBatchSpec
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import constants, logging
from realhf.base.testing import init_global_constants
logger = logging.getLogger("test sglang backend")
def check_sequences_consistency(
batched_seq1: torch.LongTensor, batched_seq2: torch.LongTensor
):
matched_tokens = 0
matched_seqs = 0
total_tokens = 0
assert len(batched_seq1) == len(batched_seq2)
for i in range(len(batched_seq1)):
a = batched_seq1[i]
b = batched_seq2[i]
assert torch.is_tensor(a) and torch.is_tensor(b)
assert a.dim() == 1 and b.dim() == 1, (a.shape, b.shape)
gen_len = a.shape[0] if a.shape[0] < b.shape[0] else b.shape[0]
b = b[:gen_len]
a = a[:gen_len]
for j in range(gen_len):
if a[j] != b[j]:
logger.info(f"Mismatch at sequence {i} position {j}")
break
matched_tokens += 1
else:
matched_seqs += 1
total_tokens += gen_len
logger.info(
f"Matched {matched_seqs}/{len(batched_seq1)} "
f"sequences and {matched_tokens}/{total_tokens} tokens"
)
return (
matched_seqs,
matched_tokens,
float(matched_tokens) / total_tokens,
float(matched_seqs) / len(batched_seq1),
)
def test_fn(
rank: int,
world_size: int,
path: str,
model_family_name: str,
dp: int,
pp: int,
tp: int,
):
assert not torch.cuda.is_initialized()
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
torch.cuda.set_device(0)
assert world_size == (
dp * pp * tp
), f"dp={dp}, pp={pp}, tp={tp}, world_size={world_size}"
# Initialize distributed environment.
dist.init_process_group(
"nccl",
rank=rank,
world_size=world_size,
init_method="tcp://localhost:7777",
)
torch.cuda.set_device(0)
model_name = ModelName("default", 0)
constants.set_experiment_trial_names("slang-test", str(uuid.uuid4()))
init_global_constants(
num_dp=dp,
num_mp=tp,
num_pp=pp,
sequence_parallel=False,
model_name=model_name,
max_prompt_len=128,
)
from realhf.impl.model.nn.real_llm_api import ReaLModel, add_helper_functions
mconfig: ReaLModelConfig = getattr(ReaLModel, f"config_from_{model_family_name}")(
transformers.AutoConfig.from_pretrained(
path,
trust_remote_code=True,
force_download=True,
)
)
with constants.model_scope(model_name):
module = ReaLModel(mconfig, dtype=torch.float16, device="cuda")
module._instantiation_hooks.append(
lambda: getattr(module, f"from_{model_family_name}")(
load_dir=path, init_critic_from_actor=False
)
)
add_helper_functions(module)
module.instantiate()
module.eval()
tokenizer = data_api.load_hf_tokenizer(path)
from realhf.impl.model.backend.sglang import SGLangGenerationBackend
backend = SGLangGenerationBackend(model_path=path)
model = model_api.Model(
name=model_name,
module=module,
tokenizer=tokenizer,
device=module.device,
dtype=module.dtype,
)
ft_spec = model_api.FinetuneSpec(
total_train_epochs=1,
dataset_size=100,
train_batch_size=1,
)
model = backend.initialize(model, ft_spec)
gconfig = model_api.GenerationHyperparameters(
n=1,
max_new_tokens=32,
min_new_tokens=0,
greedy=True,
top_p=1.0,
top_k=int(1e8),
temperature=1.0,
use_cuda_graph=False,
)
bs = 8
for i in range(1):
seqlens = [torch.randint(5, 10, (1,)).cuda() for _ in range(bs)]
for s in seqlens:
dist.broadcast(s, src=0)
seqlens = [int(s) for s in seqlens]
token_ids = (
torch.randint(0, mconfig.vocab_size, (sum(seqlens),)).long().cuda()
)
dist.broadcast(token_ids, src=0)
max_seqlen = max(seqlens)
cu_seqlens = torch.nn.functional.pad(
torch.tensor(seqlens, device="cuda").cumsum(0),
(1, 0),
).int()
res = module.generate(
tokenizer=tokenizer,
packed_input_ids=token_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
gconfig=gconfig,
)
gen_tokens1 = res.sequences
logprobs1 = res.scores
x = data_api.SequenceSample.from_default(
seqlens=seqlens,
ids=list(range(bs)),
data=dict(packed_input_ids=token_ids),
)
gen_tokens2, logprobs2, _ = model.module.generate(
input_=x,
mb_spec=MicroBatchSpec(),
tokenizer=tokenizer,
gconfig=gconfig,
)
if constants.model_parallel_rank() == 0:
# The outputs are None for tp_rank > 0 in SGLang
_, _, token_match_percent, seq_match_percent = (
check_sequences_consistency(gen_tokens1, gen_tokens2)
)
assert token_match_percent > 0.8, token_match_percent
assert seq_match_percent > 0.8, seq_match_percent
print("success")
# Clean up the distributed environment
dist.destroy_process_group()
def test_sglang_consistency(tp: int, dp: int, path: str, model_family_name: str):
mp.set_start_method("spawn", force=True)
world_size = dp * tp
procs = [
mp.Process(
target=test_fn,
args=(
i,
world_size,
),
kwargs=dict(
path=path,
model_family_name=model_family_name,
dp=dp,
pp=1,
tp=tp,
),
)
for i in range(world_size)
]
try:
for p in procs:
p.start()
for p in procs:
p.join()
except KeyboardInterrupt:
[p.terminate() for p in procs]
[p.join() for p in procs]
if __name__ == "__main__":
path = "/storage/openpsi/models/Qwen__Qwen2-1.5B-Instruct/"
model_family_name = "qwen2"
# test_fn(
# rank=0,
# world_size=1,
# path=path,
# model_family_name=model_family_name,
# dp=1,
# pp=1,
# tp=1,
# )
test_sglang_consistency(
tp=2,
dp=2,
path=path,
model_family_name=model_family_name,
)