PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .
Authored by 博惟 on 2025-07-14 15:20:17 +08:00; committed by 晓雷
parent d8038b2669
commit 724628eaf0
9 changed files with 370 additions and 208 deletions

View File

@ -92,7 +92,7 @@ def main_grpo():
future.result()
# synchronous rollout
rollout_batch = rollout.rollout(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
rollout_batch = rollout.rollout_batch(batch, workflow=MyRolloutWorkflow(rollout_config.workflow))
# or asynchronous rollout with filtering and off-policyness control
# rollout_batch = rollout.prepare_batch(batch,
# workflow=MyRolloutWorkflow(rollout_config.workflow),
@ -697,7 +697,7 @@ reward = TrainController(Critic())
rollout_controller = RolloutController(...)
for _ in range(epochs):
for _ in range(steps_per_epoch):
data = rollout_controller.rollout(prompt)
data = rollout_controller.rollout_batch(prompt)
data['reward'] = reward.compute_values(data)
...
```

View File

@ -199,6 +199,9 @@ class SGLangConfig:
https://github.com/sgl-project/sglang for detailed documentation.
"""
model_path: str = ""
random_seed: int = 1
skip_tokenizer_init: bool = False
disable_cuda_graph: bool = False
disable_radix_cache: bool = False
disable_cuda_graph_padding: bool = False
@ -234,10 +237,8 @@ class SGLangConfig:
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
dtype: str = "float16"
kv_cache_dtype: str = "auto"
# logging
log_level: str = "warning"
log_level_http: Optional[str] = "warning"
@ -253,55 +254,60 @@ class SGLangConfig:
@staticmethod
def build_cmd(
sglang_config: "SGLangConfig",
model_path,
tp_size,
base_gpu_id,
host,
port,
dist_init_addr: Optional[str] = None,
served_model_name: Optional[str] = None,
skip_tokenizer_init: bool = True,
sglang_version: Optional[str] = None,
):
from realhf.base import network, pkg_version, seeding
from realhf.base import pkg_version
from realhf.experiments.common.utils import asdict as conf_as_dict
args: Dict = conf_as_dict(sglang_config)
args["random_seed"] = seeding.get_seed()
if served_model_name is None:
served_model_name = model_path
host_ip = network.gethostip()
host = "localhost" if not sglang_config.enable_metrics else host_ip
args = dict(
host=host,
model_path=model_path,
port=port,
# Model and tokenizer
tokenizer_path=model_path,
tokenizer_path=sglang_config.model_path,
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
device="cuda",
served_model_name=served_model_name,
is_embedding=False,
skip_tokenizer_init=skip_tokenizer_init,
# Other runtime options
tp_size=tp_size,
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=base_gpu_id,
nnodes=1,
node_rank=0,
# initialization addresses and ports
dist_init_addr=dist_init_addr,
**args,
)
if pkg_version.is_version_less("sglang", "0.4.4"):
if sglang_version:
version_less_than_0_4_4 = (
pkg_version.compare_versions(sglang_version, "0.4.4") < 0
)
version_less_than_0_4_3 = (
pkg_version.compare_versions(sglang_version, "0.4.3") < 0
)
elif pkg_version.is_available("sglang"):
version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
else:
raise ValueError(
"A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
)
if version_less_than_0_4_4:
args.pop("log_requests_level")
if pkg_version.is_version_less("sglang", "0.4.3"):
if version_less_than_0_4_3:
args.pop("enable_nccl_nvls")
args.pop("triton_attention_num_kv_splits")
args.pop("cuda_graph_bs")
args.pop("enable_memory_saver")
args.pop("allow_auto_truncate")
args.pop("file_storage_path")
flags = []
for k, v in args.items():
if v is None or v is False or v == "":
@ -320,8 +326,8 @@ class SGLangConfig:
@dataclass
class InferenceEngineConfig:
experiment_name: str
trial_name: str
experiment_name: str = MISSING
trial_name: str = MISSING
max_concurrent_rollouts: None | int = field(
default=None,
metadata={
@ -345,27 +351,20 @@ class InferenceEngineConfig:
},
)
# Used by remote inference engines.
server_addrs: List[str] = field(
default_factory=list,
metadata={"help": "List of server addresses for inference."},
)
enable_rollout_tracing: bool = field(default=False)
schedule_policy: str = field(
default="round_robin",
metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
)
setup_timeout: float = field(default=90.0)
request_timeout: float = field(
default=30.0, metadata={"help": "Timeout for HTTP requests."}
default=3600, metadata={"help": "Timeout for HTTP requests."}
)
request_retries: int = field(
default=3, metadata={"help": "Number of retries for failed requests."}
)
@dataclass
class SGLangEngineConfig:
pass
@dataclass
class _Timer:
experiment_name: str = MISSING
@ -595,42 +594,53 @@ class BaseExperimentConfig:
evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)
server_only: bool = False
sglang: SGLangConfig = field(default_factory=SGLangConfig)
@dataclass
class SFTConfig(BaseExperimentConfig):
model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
def parse_cli_args(argv: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", help="The path of the main configuration file", required=True
)
args, overrides = parser.parse_known_args(argv)
# Initialize hydra config
config_file = Path(args.config).absolute()
assert config_file.exists()
# hydra only recognizes relative paths
relpath = Path(
os.path.relpath(str(config_file), (Path(__file__).parent).absolute())
)
relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None)
cfg = hydra_compose(
config_name=str(relpath.name).rstrip(".yaml"),
config_name=str(relpath.name).split(".yaml")[0],
overrides=overrides,
)
return cfg, config_file
def to_structured_cfg(cfg, config_cls):
# Merge with the default configuration.
# The YAML file and command line may omit default values that are defined in the Python dataclasses.
default_cfg = OmegaConf.structured(config_cls)
cfg = OmegaConf.merge(default_cfg, cfg)
return cfg
def load_expr_config(argv: List[str], config_cls):
cfg, config_file = parse_cli_args(argv)
cfg = to_structured_cfg(cfg, config_cls=config_cls)
cfg = OmegaConf.to_object(cfg)
assert isinstance(cfg, BaseExperimentConfig)
# Setup environment
from realhf.base import constants, name_resolve
from realhf.base import constants, name_resolve, names
constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name)
name_resolve.reconfigure(cfg.cluster.name_resolve)
name_resolve.clear_subtree(
names.trial_root(experiment_name=cfg.experiment_name, trial_name=cfg.trial_name)
)
return cfg, str(config_file)
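
For orientation, here is a minimal sketch of how the split entry points above compose. It assumes these helpers live in `arealite.api.cli_args` (as other imports in this PR suggest); the YAML path and the override string are placeholders.

```
# Minimal sketch, assuming the helpers live in arealite.api.cli_args;
# the YAML path and override below are placeholders.
import sys

from arealite.api.cli_args import (
    BaseExperimentConfig,
    load_expr_config,
    parse_cli_args,
    to_structured_cfg,
)

def main():
    # One-call path: parse `--config <yaml>` plus overrides, merge onto the
    # dataclass defaults, set experiment/trial names, and clear the trial's
    # name_resolve subtree.
    cfg, config_file = load_expr_config(sys.argv[1:], BaseExperimentConfig)
    print(cfg.experiment_name, cfg.trial_name, config_file)

def manual_flow():
    # Equivalent two-step path when only the merged OmegaConf object is needed.
    raw_cfg, _ = parse_cli_args(["--config", "configs/grpo.yaml", "experiment_name=my-exp"])
    return to_structured_cfg(raw_cfg, config_cls=BaseExperimentConfig)

if __name__ == "__main__":
    main()
```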

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.io_struct import (
FinetuneSpec,
@ -77,9 +78,9 @@ class TrainEngine(abc.ABC):
def train_batch(
self,
input_: Dict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> Dict[str, float]:
"""Update the model with a batch of data and a loss function."""
raise NotImplementedError()
@ -87,9 +88,9 @@ class TrainEngine(abc.ABC):
@torch.no_grad()
def eval_batch(
self,
input_: Dict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> torch.Tensor | None:
"""Evaluate the model using the forward pass and loss function."""
raise NotImplementedError()
@ -97,9 +98,9 @@ class TrainEngine(abc.ABC):
@torch.no_grad()
def forward(
self,
input_: Dict,
input_: TensorDict,
output_seqlens: List[List[int]] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Run the forward pass or inference on the model. Note that it is gradient-free."""
@ -127,12 +128,33 @@ class InferenceEngine(abc.ABC):
"""Asynchronously submit a request to the inference engine. Exits immediately."""
raise NotImplementedError()
def wait(self, count: int, timeout: float) -> TensorDict:
def wait(
self,
count: int,
timeout: float | None = None,
should_accept: Callable | None = None,
) -> TensorDict:
"""Wait for a specified number of requests to complete, with a timeout."""
raise NotImplementedError()
def rollout(
def rollout_batch(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""
raise NotImplementedError()
def prepare_batch(
self,
dataloader: StatefulDataLoader,
workflow: "RolloutWorkflow",
):
"""Asynchronously submit and wait until a full batch is ready."""
raise NotImplementedError()
def pause(self):
"""Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
raise NotImplementedError()
def resume(self):
"""Resume request submission for async rollout."""
raise NotImplementedError()
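
A minimal, non-authoritative sketch of how a training loop might drive this expanded interface; `engine`, `workflow`, and `dataloader` are placeholders for an initialized `InferenceEngine`, a `RolloutWorkflow`, and a `StatefulDataLoader`.

```
# Sketch only; `engine`, `workflow`, and `dataloader` are placeholders.

# Synchronous path: submit one batch and block until every result is back.
batch = next(iter(dataloader))
rollout_data = engine.rollout_batch(batch, workflow=workflow)

# Lower-level path: submit items individually and filter results on arrival.
for item in batch:
    engine.submit(item, workflow)
rollout_data = engine.wait(count=len(batch), should_accept=lambda traj: True)

# Asynchronous path: keep servers saturated and return once a full batch of
# accepted trajectories is ready.
rollout_data = engine.prepare_batch(dataloader, workflow=workflow)

# Stop submitting new rollouts while evaluating, then resume afterwards.
engine.pause()
# ... run evaluation ...
engine.resume()
```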

View File

@ -16,7 +16,6 @@ from arealite.api.cli_args import GenerationHyperparameters
@dataclass
class LLMRequest:
rid: str = field(default_factory=lambda: str(uuid.uuid4()))
text: Optional[str] = None
input_ids: List[int] = field(default_factory=list)
gconfig: GenerationHyperparameters = field(
default_factory=GenerationHyperparameters
@ -28,7 +27,6 @@ class LLMRequest:
@dataclass
class LLMResponse:
# outputs
completions: str
input_tokens: List[int] = field(default_factory=list)
output_tokens: List[int] = field(default_factory=list)
output_logprobs: List[float] = field(default_factory=list)
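
With `text` and `completions` removed, requests and responses are token-ID only. A hedged sketch of the round trip, mirroring the test changes later in this PR; the model path, prompt, and server address are illustrative.

```
# Sketch of the token-ID-only round trip; model path, prompt, and server
# address are illustrative and assume a running SGLang server.
import asyncio
import os
import uuid

from arealite.api.cli_args import GenerationHyperparameters, InferenceEngineConfig
from arealite.api.io_struct import LLMRequest
from arealite.engine.sglang_remote import RemoteSGLangEngine
from realhf.api.core.data_api import load_hf_tokenizer

os.environ["AREAL_LLM_SERVER_ADDRS"] = "localhost:13887"  # assumed server address
tokenizer = load_hf_tokenizer("Qwen/Qwen2-0.5B")
engine = RemoteSGLangEngine(
    InferenceEngineConfig(experiment_name="demo", trial_name="trial_0")
)

req = LLMRequest(
    rid=uuid.uuid4().hex,
    input_ids=tokenizer.encode("hello! how are you today"),
    gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = asyncio.run(engine.agenerate(req))
# The completion string is recovered by decoding the returned token IDs.
completion_text = tokenizer.decode(resp.output_tokens)
```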

View File

@ -130,11 +130,6 @@ class FSDPEngine(TrainEngine):
)
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
param_dtype=dtype,
@ -318,7 +313,9 @@ class FSDPEngine(TrainEngine):
self.config.trial_name,
meta.model_version,
)
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
name_resolve.add(
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
)
else:
raise ValueError(f"Unknown weight update type {meta.type}")

View File

@ -1,23 +1,30 @@
import asyncio
import os
import random
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional
import aiohttp
import requests
import torch.distributed as dist
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import InferenceEngineConfig
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import (
FinetuneSpec,
LLMRequest,
LLMResponse,
RolloutStat,
WeightUpdateMeta,
)
from arealite.utils.padding import concat_padded_tensors
from realhf.base import logging, name_resolve, names, pkg_version
if TYPE_CHECKING:
@ -30,7 +37,7 @@ if pkg_version.is_available("sglang"):
else:
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
ROLLOUT_POLL_WAIT_TIME = 0.4
ROLLOUT_POLL_WAIT_TIME = 0.1
RID_CACHE_SIZE = 128
@ -46,22 +53,51 @@ class RemoteSGLangEngine(InferenceEngine):
# Maintain the addresses for the recent 128 requests
self.rid_queue = []
self.addresses = config.server_addrs
self.server_idx = 0
env_addrs = os.getenv("AREAL_LLM_SERVER_ADDRS", "")
if not env_addrs:
raise RuntimeError("No configured SGLang servers.")
self.addresses = env_addrs.split(",")
logger.info("Waiting for server ready...")
for addr in self.addresses:
self._wait_for_server(addr)
logger.info("Servers are all ready!")
qsize = config.queue_size or config.max_concurrent_rollouts * 10
self.server_idx = random.randint(0, len(self.addresses) - 1)
qsize = config.queue_size or config.max_concurrent_rollouts * 16
self.input_queue = Queue(maxsize=qsize)
self.output_queue = Queue(maxsize=qsize)
self.result_cache = []
self.exiting = threading.Event()
self.paused = threading.Event()
self.lock = threading.Lock()
self.rollout_stat = RolloutStat()
self._version = 0
def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
def _wait_for_server(self, address):
base_url = f"http://{address}"
tik = time.time()
while time.time() - tik < self.config.setup_timeout:
if self.check_health(base_url):
return
time.sleep(1)
raise RuntimeError("server launch failed")
def check_health(self, base_url):
# Check server endpoint
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException as e:
return False
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
self.rollout_tasks: Dict[str, asyncio.Task] = {}
self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start()
@ -85,79 +121,45 @@ class RemoteSGLangEngine(InferenceEngine):
traceback.print_exc()
async def _rollout_thread_async(self):
data = None
rollout_tasks: Dict[str, asyncio.Task] = {}
pending_data = []
rollout_tasks = self.rollout_tasks
rid = 0
try:
while not self.exiting.is_set():
# Load next data from controller
if data is None:
while True:
try:
data, workflow = self.input_queue.get_nowait()
logger.info(f"Get data from puller: {data}")
logger.debug(f"Get data from puller: {data}")
pending_data.append(data)
except Empty:
logger.debug(f"No data from puller stream.")
break
# Check capacity
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
cannot_rollout_reason = []
capacity = max(1, self.config.max_concurrent_rollouts // world_size)
can_rollout = len(rollout_tasks) < capacity
if not can_rollout:
cannot_rollout_reason.append(
f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
)
# Staleness control
version = self.get_version()
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
expected_version = sample_cnt // self.config.consumer_batch_size
not_staled = expected_version <= ofp + version
can_rollout &= not_staled
if not not_staled:
cannot_rollout_reason.append(
f"Staled: expected version ({expected_version}) = "
f"global sample cnt ({sample_cnt}) // batch size ({self.config.consumer_batch_size}), "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
if not can_rollout:
logger.debug(
f"Cannot submit new rollouts. "
+ "\n".join(cannot_rollout_reason)
)
capacity = self.get_capacity()
# Create new rollout task
if can_rollout and data is not None:
while capacity > 0 and pending_data and not self.paused.is_set():
task = asyncio.create_task(
workflow.arun_episode(self, data), name=str(rid)
workflow.arun_episode(self, pending_data.pop(0)), name=str(rid)
)
rollout_tasks[str(rid)] = task
with self.lock:
rollout_tasks[str(rid)] = task
self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
if self.config.enable_rollout_tracing:
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
capacity -= 1
rid += 1
data = None
# Wait for rollout completion
tasks = list(rollout_tasks.values())
with self.lock:
tasks = list(rollout_tasks.values())
done = []
if tasks:
done, _ = await asyncio.wait(
@ -165,16 +167,19 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
if not done:
await asyncio.sleep(1)
else:
await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
await asyncio.sleep(1)
# Collect done results
for task in done:
traj = await task
traj: TensorDict
task_rid = task.get_name()
rollout_tasks.pop(task_rid)
self.rollout_stat.accepted += 1
with self.lock:
rollout_tasks.pop(task_rid)
self.rollout_stat.accepted += 1
try:
self.output_queue.put_nowait(traj)
@ -185,21 +190,25 @@ class RemoteSGLangEngine(InferenceEngine):
with self.lock:
self.rollout_stat.running -= 1
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
if self.config.enable_rollout_tracing:
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
except Exception:
traceback.print_exc()
finally:
# Cancel remaining tasks
for task in rollout_tasks.values():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
with self.lock:
for task in rollout_tasks.values():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def choose_server(self) -> str:
if self.config.schedule_policy == "round_robin":
@ -236,8 +245,7 @@ class RemoteSGLangEngine(InferenceEngine):
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=timeout,
sock_connect=30,
sock_read=timeout,
sock_connect=timeout,
)
) as session:
if method.upper() == "GET":
@ -252,7 +260,7 @@ class RemoteSGLangEngine(InferenceEngine):
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response
return await response.json()
except (
aiohttp.ClientError,
@ -288,15 +296,11 @@ class RemoteSGLangEngine(InferenceEngine):
# NOTE: rid should NOT be passed in payload
payload = {
"text": req.text,
"input_ids": req.input_ids.copy(),
"sampling_params": sample_params,
"return_logprob": True,
"stream": False,
}
if req.text:
payload["text"] = req.text
else:
payload["input_ids"] = req.input_ids
# Make request
start_time = time.perf_counter()
@ -324,7 +328,7 @@ class RemoteSGLangEngine(InferenceEngine):
and len(accumulated_output_tokens) < gconfig.max_new_tokens
):
# loop until the generation is complete
response = await self.arequest_with_retry(
result = await self.arequest_with_retry(
endpoint="/generate",
payload=payload,
method="POST",
@ -332,10 +336,8 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=self.config.request_timeout,
target_addr=server_addr,
)
result = await response.json()
# Parse response
completions += result["text"]
meta_info = result["meta_info"]
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
@ -350,12 +352,11 @@ class RemoteSGLangEngine(InferenceEngine):
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
payload["text"] += result["text"]
payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
latency = time.perf_counter() - start_time
return LLMResponse(
completions=completions,
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
@ -376,10 +377,10 @@ class RemoteSGLangEngine(InferenceEngine):
update_name = names.update_weights_from_disk(
self.config.experiment_name, self.config.trial_name, meta.model_version
)
save_timestamp = int(name_resolve.wait(update_name, timeout=120))
load_timestamp = time.time_ns()
save_timestamp = float(name_resolve.wait(update_name, timeout=120))
load_timestamp = datetime.now().timestamp()
logger.info(
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp)/1e6:.2f} ms"
f"Begin update weights from {meta.path}, responded in {(load_timestamp - save_timestamp):.2f}s"
)
try:
jobs = [
@ -393,14 +394,14 @@ class RemoteSGLangEngine(InferenceEngine):
finally:
loop.close()
logger.info(
f"Loading weights done in {(time.time_ns() - load_timestamp)/1e6:.2f} ms"
f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
)
self.set_version(meta.model_version)
else:
raise NotImplementedError(f"Unsupported weight update type: {meta.type}")
async def aupdate_weights_from_disk(self, addr, path: str):
response = await self.arequest_with_retry(
res = await self.arequest_with_retry(
endpoint="/update_weights_from_disk",
payload=dict(model_path=str(path), allow_interrupt=True),
method="POST",
@ -408,7 +409,6 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=self.config.request_timeout,
target_addr=addr,
)
res = await response.json()
assert res["success"]
if "num_paused_requests" in res:
logger.info(
@ -416,15 +416,40 @@ class RemoteSGLangEngine(InferenceEngine):
f"during updating weights for server {addr}"
)
def get_capacity(self):
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
max_concurrent_rollouts = max(
1, self.config.max_concurrent_rollouts // world_size
)
capacity = max_concurrent_rollouts - len(self.rollout_tasks)
# Staleness control
version = self.get_version()
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
consumer_bs = max(1, self.config.consumer_batch_size // world_size)
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
return capacity
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
try:
self.input_queue.put_nowait((data, workflow))
except Full:
raise RuntimeError("Input queue full. Please increase queue_size.")
def wait(self, count: int, timeout: float, should_accept: Callable) -> TensorDict:
def wait(
self,
count: int,
timeout: float | None = None,
should_accept: Callable | None = None,
) -> TensorDict:
tik = time.perf_counter()
accepted = len(self.result_cache)
timeout = timeout or float(7 * 24 * 3600)
while (
accepted < count
and not self.exiting.is_set()
@ -432,14 +457,14 @@ class RemoteSGLangEngine(InferenceEngine):
):
try:
result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
if should_accept(result):
if should_accept is None or should_accept(result):
self.result_cache.append(result)
accepted += 1
else:
with self.lock:
self.rollout_stat.accepted -= 1
except Empty:
time.sleep(ROLLOUT_POLL_WAIT_TIME)
pass
if self.exiting.is_set():
raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
if accepted < count:
@ -450,16 +475,39 @@ class RemoteSGLangEngine(InferenceEngine):
self.result_cache[:count],
self.result_cache[count:],
)
return TensorDict.cat(results, dim=0)
return concat_padded_tensors(results)
def rollout(
def rollout_batch(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""
for item in data:
self.submit(item, workflow)
return self.wait(
count=len(data),
timeout=self.config.request_timeout,
should_accept=lambda x: True,
)
return self.wait(count=len(data))
def prepare_batch(
self,
data_generator: Iterator,
dataloader: StatefulDataLoader,
workflow: "RolloutWorkflow",
):
assert dataloader.batch_size is not None
while True:
if self.get_capacity() + dataloader.batch_size > 0:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(dataloader)
data = next(data_generator)
for item in data:
self.submit(item, workflow=workflow)
try:
return self.wait(dataloader.batch_size, timeout=1)
except TimeoutError:
pass
def pause(self):
self.paused.set()
def resume(self):
self.paused.clear()
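
For intuition on the new `get_capacity` guard, a worked example with illustrative numbers (not taken from any real run): new rollouts are admitted only while both the per-process concurrency budget and the staleness budget have headroom.

```
# Illustrative numbers only.
world_size = 4                                # training processes
max_concurrent = max(1, 256 // world_size)    # 64 rollouts allowed per process
running_tasks = 50
capacity = max_concurrent - running_tasks     # 14 left by concurrency

version = 3                                   # current weight version
ofp = 1                                       # max_head_offpolicyness
accepted, running = 300, 50
sample_cnt = accepted + running               # 350 samples started so far
consumer_bs = max(1, 128 // world_size)       # 32 per-process consumer batch size

# At most (ofp + version + 1) * consumer_bs samples may ever have been started.
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
print(capacity)  # (1 + 3 + 1) * 32 - 350 = -190 -> staleness guard blocks new rollouts
```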

View File

@ -5,7 +5,6 @@ import time
import uuid
import pytest
import requests
import torch
from tensordict import TensorDict
@ -15,62 +14,43 @@ from arealite.api.cli_args import (
SGLangConfig,
)
from arealite.api.io_struct import LLMRequest, LLMResponse, WeightUpdateMeta
from arealite.utils import network
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import network
EXPR_NAME = "test_sglang_engine"
TRIAL_NAME = "trial_0"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13887
DIST_PORT = 15887
PORT, DIST_PORT = network.find_free_ports(2)
HOST = network.gethostip()
def check_server_health(base_url):
# Check server endpoint
try:
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
@pytest.fixture(scope="module")
def sglang_server():
from realhf.base import seeding
seeding.set_random_seed(1, EXPR_NAME)
cmd = SGLangConfig.build_cmd(
sglang_config=SGLangConfig(mem_fraction_static=0.3),
model_path=MODEL_PATH,
sglang_config=SGLangConfig(
skip_tokenizer_init=True,
model_path=MODEL_PATH,
mem_fraction_static=0.3,
),
host=HOST,
port=PORT,
tp_size=1,
base_gpu_id=0,
dist_init_addr=f"{HOST}:{DIST_PORT}",
served_model_name=MODEL_PATH,
skip_tokenizer_init=False,
)
# Launch process
full_command = f"{cmd} --port {PORT}"
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
cmd = cmd.replace("\\\n", " ").replace("\\", " ")
process = subprocess.Popen(
full_command.split(),
cmd.split(),
text=True,
stdout=sys.stdout,
stderr=sys.stdout,
)
base_url = f"http://{HOST}:{PORT}"
tik = time.time()
while time.time() - tik < 90:
if check_server_health(base_url):
break
time.sleep(1)
if time.time() - tik > 90:
raise RuntimeError("server launch failed")
yield
process.terminate()
@ -80,11 +60,12 @@ async def test_remote_sglang_generate(sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
tokenizer = load_hf_tokenizer(MODEL_PATH)
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
req = LLMRequest(
rid=str(uuid.uuid4()),
text="hello! how are you today",
input_ids=tokenizer.encode("hello! how are you today"),
gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = await engine.agenerate(req)
@ -95,7 +76,6 @@ async def test_remote_sglang_generate(sglang_server):
== len(resp.output_tokens)
== len(resp.output_versions)
)
assert isinstance(resp.completions, str)
@pytest.mark.parametrize("n_samples", [1, 2, 4])
@ -109,7 +89,7 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
max_concurrent_rollouts=2,
consumer_batch_size=2,
)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
engine.initialize(None, None)
@ -122,12 +102,13 @@ def test_remote_sglang_rollout(sglang_server, n_samples):
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
gconfig=gconfig,
tokenizer=tokenizer,
enable_thinking=False,
)
data = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
result = engine.rollout([data] * 2, workflow=workflow)
result = engine.rollout_batch([data] * 2, workflow=workflow)
assert isinstance(result, TensorDict)
bs = result.batch_size
assert bs == torch.Size([2 * n_samples])
@ -147,7 +128,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
consumer_batch_size=bs,
max_head_offpolicyness=ofp,
)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
engine = RemoteSGLangEngine(config)
engine.initialize(None, None)
@ -160,6 +141,7 @@ def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
reward_fn=lambda **kwargs: 1.0, # Dummy reward function
gconfig=gconfig,
tokenizer=tokenizer,
enable_thinking=False,
)
data = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
@ -220,7 +202,7 @@ def test_disk_update_weights_from_fsdp_engine(tmp_path_factory, sglang_server):
from arealite.engine.sglang_remote import RemoteSGLangEngine
config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
config.server_addrs = [f"{HOST}:{PORT}"]
os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{HOST}:{PORT}"
inf_engine = RemoteSGLangEngine(config)
# test update weights
path = tmp_path_factory.mktemp("upload_weights_from_disk")
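
The fixture above also shows the new `build_cmd` shape, where `model_path` and `skip_tokenizer_init` now travel inside `SGLangConfig`. A condensed launch sketch mirroring that fixture; the model path is illustrative and the ports are picked at runtime.

```
# Condensed server-launch sketch; the model path is illustrative.
import subprocess
import sys

from arealite.api.cli_args import SGLangConfig
from arealite.utils import network

host = network.gethostip()
port, dist_port = network.find_free_ports(2)
cmd = SGLangConfig.build_cmd(
    sglang_config=SGLangConfig(
        model_path="Qwen/Qwen2-0.5B",
        skip_tokenizer_init=True,
        mem_fraction_static=0.3,
    ),
    host=host,
    port=port,
    tp_size=1,
    base_gpu_id=0,
    dist_init_addr=f"{host}:{dist_port}",
    served_model_name="Qwen/Qwen2-0.5B",
)
# build_cmd emits a multi-line shell command; flatten it before spawning.
cmd = cmd.replace("\\\n", " ").replace("\\", " ")
server = subprocess.Popen(cmd.split(), text=True, stdout=sys.stdout, stderr=sys.stdout)
```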

arealite/utils/network.py (new file, 100 lines)
View File

@ -0,0 +1,100 @@
import random
import socket
from typing import List, Set
def gethostname():
return socket.gethostname()
def gethostip():
return socket.gethostbyname(socket.gethostname())
def find_free_ports(
count: int, port_range: tuple = (1024, 65535), exclude_ports: Set[int] | None = None
) -> List[int]:
"""
Find multiple free ports within a specified range.
Args:
count: Number of free ports to find
port_range: Tuple of (min_port, max_port) to search within
exclude_ports: Set of ports to exclude from search
Returns:
List of free port numbers
Raises:
ValueError: If unable to find requested number of free ports
"""
if exclude_ports is None:
exclude_ports = set()
min_port, max_port = port_range
free_ports = []
attempted_ports = set()
# Calculate available port range
available_range = max_port - min_port + 1 - len(exclude_ports)
if count > available_range:
raise ValueError(
f"Cannot find {count} ports in range {port_range}. "
f"Only {available_range} ports available."
)
max_attempts = count * 10 # Reasonable limit to avoid infinite loops
attempts = 0
while len(free_ports) < count and attempts < max_attempts:
# Generate random port within range
port = random.randint(min_port, max_port)
# Skip if port already attempted or excluded
if port in attempted_ports or port in exclude_ports:
attempts += 1
continue
attempted_ports.add(port)
if is_port_free(port):
free_ports.append(port)
attempts += 1
if len(free_ports) < count:
raise ValueError(
f"Could only find {len(free_ports)} free ports "
f"out of {count} requested after {max_attempts} attempts"
)
return sorted(free_ports)
def is_port_free(port: int) -> bool:
"""
Check if a port is free by attempting to bind to it.
Args:
port: Port number to check
Returns:
True if port is free, False otherwise
"""
# Check TCP
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.bind(("", port))
sock.close()
except OSError:
return False
# Check UDP
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
sock.bind(("", port))
sock.close()
return True
except OSError:
return False
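
A brief usage note for the new helper above; the port range and the excluded port are illustrative.

```
# Reserve two free ports while avoiding a port already claimed elsewhere.
from arealite.utils.network import find_free_ports, gethostip

ports = find_free_ports(2, port_range=(20000, 40000), exclude_ports={29500})
server_addr = f"{gethostip()}:{ports[0]}"
dist_init_addr = f"{gethostip()}:{ports[1]}"
```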

View File

@ -17,19 +17,24 @@ class RLVRWorkflow(RolloutWorkflow):
reward_fn,
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
enable_thinking: bool,
):
self.reward_fn = reward_fn
self.gconfig = gconfig
self.tokenizer = tokenizer
self.enable_thinking = enable_thinking
async def arun_episode(self, engine, data):
text = self.tokenizer.apply_chat_template(
data["messages"], tokenize=False, add_generation_prompt=True
input_ids = self.tokenizer.apply_chat_template(
data["messages"],
tokenize=True,
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
n_samples = self.gconfig.n_samples
req = LLMRequest(
rid=uuid.uuid4().hex,
text=text,
input_ids=input_ids,
gconfig=self.gconfig.new(n_samples=1),
)
resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])
@ -42,8 +47,8 @@ class RLVRWorkflow(RolloutWorkflow):
versions = [-1] * resp.input_len + resp.output_versions
reward = self.reward_fn(
prompt=req.text,
completions=resp.completions,
prompt=self.tokenizer.decode(input_ids),
completions=self.tokenizer.decode(resp.output_tokens),
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,