mirror of https://github.com/inclusionAI/AReaL
PullRequest: 331 [lite] Support remote sglang engine with corresponding testcases.
Merge branch fw/lite of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/331
Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* add test for sglang remote engine
* fix

This commit is contained in:
parent 57b9b945ab
commit 8771778995
@@ -83,7 +83,6 @@ def main_grpo():
    # or asynchronous rollout with filtering and off-policyness control
    # rollout_batch = rollout.prepare_batch(batch,
    #     workflow=MyRolloutWorkflow(rollout_config.workflow),
    #     offpolicyness=4,
    #     should_accept=lambda x: x['rewards'].mean() > 0)

    # In the single-controller mode
@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass, field
from typing import List
from typing import Dict, List, Optional


@dataclass
@@ -70,8 +70,157 @@ class GenerationHyperparameters:
        return GenerationHyperparameters(**args)


@dataclass
class SGLangConfig:
    """Configuration for SGLang runtime. Refer to:
    https://github.com/sgl-project/sglang for detailed documentation.
    """

    disable_cuda_graph: bool = False
    disable_radix_cache: bool = False
    disable_cuda_graph_padding: bool = False
    enable_nccl_nvls: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
    enable_ep_moe: bool = False
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
    cuda_graph_bs: Optional[List[int]] = None
    torchao_config: str = ""
    enable_nan_detection: bool = False
    enable_p2p_check: bool = False
    triton_attention_reduce_in_fp32: bool = False
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    enable_memory_saver: bool = False
    allow_auto_truncate: bool = False
    # NOTE: to avoid the illegal memory access error
    attention_backend: Optional[str] = "flashinfer"
    sampling_backend: Optional[str] = None
    context_length: Optional[int] = 32768
    mem_fraction_static: Optional[float] = 0.9
    max_running_requests: Optional[int] = None
    # NOTE: chunked_prefill_size is by default 8192 on GPUs with 80GB mem in SGLang,
    # but we disable it to avoid precision issues
    chunked_prefill_size: Optional[int] = -1
    max_prefill_tokens: int = 32768
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0

    dtype: str = "float16"
    kv_cache_dtype: str = "auto"

    # logging
    log_level: str = "warning"
    log_level_http: Optional[str] = "warning"
    log_requests: bool = False
    log_requests_level: int = 0
    show_time_cost: bool = False
    enable_metrics: bool = True  # Exports Prometheus-like metrics
    # The interval (in decoding iterations) to log throughput
    # and update prometheus metrics
    decode_log_interval: int = 1

    # Use staticmethod to make OmegaConf happy.
    @staticmethod
    def build_cmd(
        sglang_config: "SGLangConfig",
        model_path,
        tp_size,
        base_gpu_id,
        dist_init_addr: Optional[str] = None,
        served_model_name: Optional[str] = None,
        skip_tokenizer_init: bool = True,
    ):
        from realhf.base import network, pkg_version, seeding
        from realhf.experiments.common.utils import asdict as conf_as_dict

        args: Dict = conf_as_dict(sglang_config)
        args["random_seed"] = seeding.get_seed()

        if served_model_name is None:
            served_model_name = model_path
        host_ip = network.gethostip()
        host = "localhost" if not sglang_config.enable_metrics else host_ip
        args = dict(
            host=host,
            model_path=model_path,
            # Model and tokenizer
            tokenizer_path=model_path,
            tokenizer_mode="auto",
            load_format="auto",
            trust_remote_code=True,
            device="cuda",
            served_model_name=served_model_name,
            is_embedding=False,
            skip_tokenizer_init=skip_tokenizer_init,
            # Other runtime options
            tp_size=tp_size,
            # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
            base_gpu_id=base_gpu_id,
            nnodes=1,
            node_rank=0,
            dist_init_addr=dist_init_addr,
            **args,
        )

        if pkg_version.is_version_less("sglang", "0.4.4"):
            args.pop("log_requests_level")
        if pkg_version.is_version_less("sglang", "0.4.3"):
            args.pop("enable_nccl_nvls")
            args.pop("triton_attention_num_kv_splits")
            args.pop("cuda_graph_bs")
            args.pop("enable_memory_saver")
            args.pop("allow_auto_truncate")
            args.pop("file_storage_path")

        flags = []
        for k, v in args.items():
            if v is None or v is False or v == "":
                continue
            if v is True:
                flags.append(f"--{k.replace('_','-')} ")
                continue
            if isinstance(v, list):
                values = " ".join(map(str, v))
                flags.append(f"--{k.replace('_','-')} {values}")
                continue
            flags.append(f"--{k.replace('_','-')} {v}")
        flags = " ".join(flags)
        return f"python3 -m sglang.launch_server {flags}"


@dataclass
class InferenceEngineConfig:
    experiment_name: str
    trial_name: str
    max_concurrent_rollouts: None | int = field(
        default=None,
        metadata={
            "help": "Maximum number of concurrent rollouts to the inference engine. Defaults to consumer_batch_size."
        },
    )
    queue_size: None | int = field(
        default=None,
        metadata={"help": "Input/Output queue size for async rollout."},
    )
    consumer_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size for consuming rollouts from the queue."},
    )
    max_head_offpolicyness: int = field(
        default=0,
        metadata={
            "help": "Maximum off-policyness for the head. "
            "If the current version is more than this many versions behind, "
            "the request will not be accepted.",
        },
    )
    # Used by remote inference engines.
    server_addrs: List[str] = field(
        default_factory=list,
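
A note on how build_cmd consumes the fields above: each config entry becomes a --kebab-case server flag, booleans become bare switches, lists are space-joined, and None/False/empty values are dropped. The sketch below reproduces only those conversion rules in isolation; the sample values are illustrative, not AReaL's defaults.

def args_to_flags(args: dict) -> str:
    # Minimal sketch of the flag-conversion rules used by SGLangConfig.build_cmd.
    flags = []
    for k, v in args.items():
        name = f"--{k.replace('_', '-')}"
        if v is None or v is False or v == "":
            continue  # omitted entirely
        if v is True:
            flags.append(name)  # bare switch
        elif isinstance(v, list):
            flags.append(f"{name} {' '.join(map(str, v))}")
        else:
            flags.append(f"{name} {v}")
    return " ".join(flags)


cmd = "python3 -m sglang.launch_server " + args_to_flags(
    {
        "model_path": "/models/qwen",   # hypothetical path
        "tp_size": 2,
        "enable_metrics": True,
        "cuda_graph_bs": [1, 2, 4],
        "dist_init_addr": None,
    }
)
# -> python3 -m sglang.launch_server --model-path /models/qwen --tp-size 2
#    --enable-metrics --cuda-graph-bs 1 2 4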
@@ -109,6 +109,10 @@ class InferenceEngine(abc.ABC):
        """Initialize environments for distributed inference and load models."""
        raise NotImplementedError()

    def destroy(self):
        """Destroy the engine and release GPU memory."""
        pass

    def update_weights(self, meta: WeightUpdateMeta) -> Future:
        """Update weights in the inference engine."""
        raise NotImplementedError()
@@ -121,7 +125,7 @@ class InferenceEngine(abc.ABC):
        """Asynchronously submit a request to the inference engine. Exits immediately."""
        raise NotImplementedError()

    def wait(self, count: int, timeout: int) -> TensorDict:
    def wait(self, count: int, timeout: float) -> TensorDict:
        """Wait for a specified number of requests to complete, with a timeout."""
        raise NotImplementedError()

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
class RolloutWorkflow:

    async def arun_episode(
        self, engine: InferenceEngine, data: Dict[str, Any]
        self, engine: "InferenceEngine", data: Dict[str, Any]
    ) -> TensorDict:
        """Run a single episode of the workflow.

@@ -3,8 +3,8 @@ import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

import aiohttp
import torch.distributed as dist
@@ -18,7 +18,7 @@ from arealite.api.io_struct import (
    RolloutStat,
    WeightUpdateMeta,
)
from realhf.base import logging, name_resolve, names, pkg_version
from realhf.base import logging, pkg_version

if TYPE_CHECKING:
    from arealite.api.workflow_api import RolloutWorkflow

@@ -31,19 +31,27 @@ if pkg_version.is_available("sglang"):
    SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"

ROLLOUT_POLL_WAIT_TIME = 0.4
RID_CACHE_SIZE = 128


class RemoteSGLangEngine(InferenceEngine):

    def __init__(self, config: InferenceEngineConfig):
        config.max_concurrent_rollouts = (
            config.max_concurrent_rollouts or config.consumer_batch_size
        )
        self.config = config

        self.rid_to_address = {}
        # Maintain the addresses for the recent 128 requests
        self.rid_queue = []

        self.addresses = config.server_addrs
        self.server_idx = 0

        self.input_queue = Queue(maxsize=config.max_concurrent_rollouts)
        self.output_queue = Queue(maxsize=config.max_concurrent_rollouts)
        qsize = config.queue_size or config.max_concurrent_rollouts * 10
        self.input_queue = Queue(maxsize=qsize)
        self.output_queue = Queue(maxsize=qsize)
        self.result_cache = []

        self.exiting = threading.Event()
@@ -51,32 +59,35 @@ class RemoteSGLangEngine(InferenceEngine):

        self.rollout_stat = RolloutStat()

    def _get_model_version(self) -> int:
        name = names.model_version(
            self.config.experiment_name,
            self.config.trial_name,
            "actor",
        )
        try:
            return int(name_resolve.get(name))
        except name_resolve.NameEntryNotFoundError:
            return 0
        self._version = 0

    def initialize(self, addr: str | None, ft_spec: Optional[Dict[str, Any]] = None):
        self.rollout_thread = threading.Thread(target=self._rollout_thread)
        self.rollout_thread.start()

    def destroy(self):
        self.exiting.set()
        self.rollout_thread.join()

    def set_version(self, version):
        with self.lock:
            self._version = version

    def get_version(self):
        with self.lock:
            return self._version

    def _rollout_thread(self):
        """Thread that runs the rollout loop."""
        try:
            asyncio.run_coroutine_threadsafe(self._rollout_thread_async())
        finally:
            self.exiting.set()
            asyncio.run(self._rollout_thread_async())
        except Exception as e:
            traceback.print_exc()

    async def _rollout_thread_async(self):
        data = None

        rollout_tasks: Dict[int, asyncio.Task] = {}
        rollout_tasks: Dict[str, asyncio.Task] = {}
        rid = 0

        try:
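
For context on the threading change above: the rollout loop is a coroutine driven by a dedicated thread that owns its own event loop via asyncio.run(), rather than run_coroutine_threadsafe (which needs an already-running loop). A minimal standalone sketch of that pattern, with illustrative names that are not the engine's actual members:

import asyncio
import threading
import traceback

exiting = threading.Event()

async def rollout_loop():
    # Stand-in for _rollout_thread_async: poll until asked to exit.
    while not exiting.is_set():
        await asyncio.sleep(0.1)

def rollout_thread():
    # The background thread starts and owns its own event loop.
    try:
        asyncio.run(rollout_loop())
    except Exception:
        traceback.print_exc()

t = threading.Thread(target=rollout_thread)
t.start()
exiting.set()   # destroy(): signal the loop, then join the thread
t.join()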
@@ -85,7 +96,7 @@ class RemoteSGLangEngine(InferenceEngine):
                if data is None:
                    try:
                        data, workflow = self.input_queue.get_nowait()
                        logger.debug(f"Get data from puller: {data}")
                        logger.info(f"Get data from puller: {data}")
                    except Empty:
                        logger.debug(f"No data from puller stream.")

@@ -104,17 +115,17 @@ class RemoteSGLangEngine(InferenceEngine):
                    )

                # Staleness control
                version = self._get_model_version()
                version = self.get_version()
                ofp = self.config.max_head_offpolicyness
                with self.lock:
                    sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
                    expected_version = sample_cnt // self.train_batch_size
                    expected_version = sample_cnt // self.config.consumer_batch_size
                    not_staled = expected_version <= ofp + version
                    can_rollout &= not_staled
                    if not not_staled:
                        cannot_rollout_reason.append(
                            f"Staled: expected version ({expected_version}) = "
                            f"global sample cnt ({sample_cnt}) // batch size ({self.train_batch_size}), "
                            f"global sample cnt ({sample_cnt}) // batch size ({self.config.consumer_batch_size}), "
                            f"current latest version {version}, "
                            f"offpolicyness {self.config.max_head_offpolicyness}."
                        )
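
To make the staleness rule above concrete, here is a small self-contained sketch of the acceptance check that max_head_offpolicyness implements (plain Python, illustrative names):

def can_submit(sample_cnt: int, consumer_batch_size: int,
               current_version: int, max_head_offpolicyness: int) -> bool:
    # Version the new sample would belong to if training consumed exactly
    # consumer_batch_size samples per model version.
    expected_version = sample_cnt // consumer_batch_size
    # Accept only if that version is at most max_head_offpolicyness ahead
    # of the weights currently being served.
    return expected_version <= current_version + max_head_offpolicyness


# With consumer_batch_size=2 and offpolicyness=1 at version 0, the engine
# accepts 4 samples (expected versions 0, 0, 1, 1) and rejects the 5th.
assert [can_submit(i, 2, 0, 1) for i in range(5)] == [True, True, True, True, False]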
@@ -130,12 +141,12 @@ class RemoteSGLangEngine(InferenceEngine):
                    task = asyncio.create_task(
                        workflow.arun_episode(self, data), name=str(rid)
                    )
                    rollout_tasks[rid] = task
                    rollout_tasks[str(rid)] = task

                    with self.lock:
                        self.rollout_stat.submitted += 1
                        self.rollout_stat.running += 1
                        logger.debug(
                        logger.info(
                            f"Submit rollout rid {rid}. "
                            f"Submit: {self.rollout_stat.submitted}, "
                            f"running: {self.rollout_stat.running}, "
@@ -163,12 +174,18 @@ class RemoteSGLangEngine(InferenceEngine):
                    traj: TensorDict
                    task_rid = task.get_name()
                    rollout_tasks.pop(task_rid)
                    self.rollout_stat.accepted += 1

                    self.output_queue.put(traj)
                    try:
                        self.output_queue.put_nowait(traj)
                    except Full:
                        raise RuntimeError(
                            "Output queue full. Please increase queue_size."
                        )

                    with self.lock:
                        self.rollout_stat.running -= 1
                        logger.debug(
                        logger.info(
                            f"Finish rollout {task_rid}. "
                            f"Submit: {self.rollout_stat.submitted}, "
                            f"running: {self.rollout_stat.running}, "
@@ -256,7 +273,11 @@ class RemoteSGLangEngine(InferenceEngine):
        gconfig = req.gconfig
        stop_token_ids = gconfig.stop_token_ids

        assert gconfig.n_samples == 1
        if gconfig.n_samples != 1:
            raise ValueError(
                "RemoteSGLangEngine does not support n_samples > 1. "
                "Please call generate for multiple times with n_samples = 1."
            )
        sample_params = {
            "top_p": gconfig.top_p,
            "top_k": gconfig.top_k,
@@ -265,8 +286,8 @@ class RemoteSGLangEngine(InferenceEngine):
            "stop_token_ids": stop_token_ids,
        }

        # NOTE: rid should NOT be passed in payload
        payload = {
            "rid": req.rid,
            "text": req.text,
            "sampling_params": sample_params,
            "return_logprob": True,
@@ -287,6 +308,17 @@ class RemoteSGLangEngine(InferenceEngine):
        completions = ""
        stop_reason = "length"

        if req.rid in self.rid_to_address:
            server_addr = self.rid_to_address[req.rid]
        else:
            server_addr = self.choose_server()
            if len(self.rid_queue) >= RID_CACHE_SIZE:
                # Remove the oldest entry if cache is full
                oldest_rid = self.rid_queue.pop(0)
                self.rid_to_address.pop(oldest_rid, None)
            self.rid_to_address[req.rid] = server_addr
            self.rid_queue.append(req.rid)

        while (
            stop_reason != "stop"
            and len(accumulated_output_tokens) < gconfig.max_new_tokens
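
The rid-to-address bookkeeping above is a bounded FIFO cache that keeps requests with the same rid pinned to the same server. A minimal standalone sketch of the same idea (illustrative names, using an OrderedDict instead of a dict plus a list):

from collections import OrderedDict

class RidAddressCache:
    """Remember the server address of the most recent N request ids."""

    def __init__(self, capacity: int = 128):
        self.capacity = capacity
        self._cache: "OrderedDict[str, str]" = OrderedDict()

    def get_or_assign(self, rid: str, choose_server) -> str:
        if rid in self._cache:
            return self._cache[rid]
        addr = choose_server()
        if len(self._cache) >= self.capacity:
            self._cache.popitem(last=False)  # evict the oldest rid
        self._cache[rid] = addr
        return addr


cache = RidAddressCache(capacity=2)
assert cache.get_or_assign("a", lambda: "host1:8000") == "host1:8000"
assert cache.get_or_assign("a", lambda: "host2:8000") == "host1:8000"  # sticky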
@@ -298,6 +330,7 @@ class RemoteSGLangEngine(InferenceEngine):
                method="POST",
                max_retries=3,
                timeout=self.config.request_timeout,
                target_addr=server_addr,
            )
            result = await response.json()

@@ -369,9 +402,12 @@ class RemoteSGLangEngine(InferenceEngine):
        )

    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
        self.input_queue.put((workflow, data))
        try:
            self.input_queue.put_nowait((data, workflow))
        except Full:
            raise RuntimeError("Input queue full. Please increase queue_size.")

    def wait(self, count: int, timeout: int, should_accept: Callable) -> TensorDict:
    def wait(self, count: int, timeout: float, should_accept: Callable) -> TensorDict:
        tik = time.perf_counter()
        accepted = len(self.result_cache)
        while (
@@ -384,8 +420,9 @@ class RemoteSGLangEngine(InferenceEngine):
                if should_accept(result):
                    self.result_cache.append(result)
                    accepted += 1
                else:
                    with self.lock:
                        self.rollout_stat.accepted += 1
                        self.rollout_stat.accepted -= 1
            except Empty:
                time.sleep(ROLLOUT_POLL_WAIT_TIME)
                if self.exiting.is_set():
@@ -399,3 +436,15 @@ class RemoteSGLangEngine(InferenceEngine):
            self.result_cache[count:],
        )
        return TensorDict.cat(results, dim=0)

    def rollout(
        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
    ) -> TensorDict:
        """Submit a batch of requests to the inference engine and wait for the results."""
        for item in data:
            self.submit(item, workflow)
        return self.wait(
            count=len(data),
            timeout=self.config.request_timeout,
            should_accept=lambda x: True,
        )
@@ -0,0 +1,188 @@
import os
import subprocess
import sys
import time
import uuid

import pytest
import requests
import torch
from tensordict import TensorDict

from arealite.api.cli_args import (
    GenerationHyperparameters,
    InferenceEngineConfig,
    SGLangConfig,
)
from arealite.api.io_struct import FinetuneSpec, LLMRequest, LLMResponse
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import name_resolve, network, seeding

EXPR_NAME = "test_sglang_engine"
TRIAL_NAME = "trial_0"
MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = "Qwen/Qwen2-0.5B"
PORT = 13887
DIST_PORT = 15887
HOST = network.gethostip()


def check_server_health(base_url):
    # Check server endpoint
    try:
        response = requests.get(
            f"{base_url}/metrics",
            timeout=30,
        )
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False


@pytest.fixture(scope="module")
def sglang_server():
    from realhf.base import seeding

    seeding.set_random_seed(1, EXPR_NAME)
    cmd = SGLangConfig.build_cmd(
        sglang_config=SGLangConfig(mem_fraction_static=0.3),
        model_path=MODEL_PATH,
        tp_size=1,
        base_gpu_id=0,
        dist_init_addr=f"{HOST}:{DIST_PORT}",
        served_model_name=MODEL_PATH,
        skip_tokenizer_init=False,
    )
    # Launch process
    full_command = f"{cmd} --port {PORT}"
    full_command = full_command.replace("\\\n", " ").replace("\\", " ")
    process = subprocess.Popen(
        full_command.split(),
        text=True,
        stdout=sys.stdout,
        stderr=sys.stdout,
    )
    base_url = f"http://{HOST}:{PORT}"
    tik = time.time()
    while time.time() - tik < 90:
        if check_server_health(base_url):
            break
        time.sleep(1)
    if time.time() - tik > 90:
        raise RuntimeError("server launch failed")
    yield
    process.terminate()


@pytest.mark.skip("")
@pytest.mark.asyncio
async def test_remote_sglang_generate(sglang_server):
    from arealite.engine.sglang_remote import RemoteSGLangEngine

    config = InferenceEngineConfig(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
    config.server_addrs = [f"{HOST}:{PORT}"]
    engine = RemoteSGLangEngine(config)
    req = LLMRequest(
        rid=str(uuid.uuid4()),
        text="hello! how are you today",
        gconfig=GenerationHyperparameters(max_new_tokens=16),
    )
    resp = await engine.agenerate(req)
    assert isinstance(resp, LLMResponse)
    assert resp.input_tokens == req.input_ids
    assert (
        len(resp.output_logprobs)
        == len(resp.output_tokens)
        == len(resp.output_versions)
    )
    assert isinstance(resp.completions, str)


@pytest.mark.skip("")
@pytest.mark.parametrize("n_samples", [1, 2, 4])
def test_remote_sglang_rollout(sglang_server, n_samples):
    from arealite.engine.sglang_remote import RemoteSGLangEngine
    from arealite.workflow.rlvr import RLVRWorkflow

    config = InferenceEngineConfig(
        experiment_name=EXPR_NAME,
        trial_name=TRIAL_NAME,
        max_concurrent_rollouts=2,
        consumer_batch_size=2,
    )
    config.server_addrs = [f"{HOST}:{PORT}"]
    engine = RemoteSGLangEngine(config)
    engine.initialize(None, None)

    gconfig = GenerationHyperparameters(
        max_new_tokens=16, greedy=False, n_samples=n_samples
    )
    tokenizer = load_hf_tokenizer(MODEL_PATH)

    workflow = RLVRWorkflow(
        reward_fn=lambda **kwargs: 1.0,  # Dummy reward function
        gconfig=gconfig,
        tokenizer=tokenizer,
    )

    data = {
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
    }
    result = engine.rollout([data] * 2, workflow=workflow)
    assert isinstance(result, TensorDict)
    bs = result.batch_size
    assert bs == torch.Size([2 * n_samples])
    engine.destroy()


@pytest.mark.parametrize("ofp", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("bs", [2, 4])
@pytest.mark.parametrize("n_samples", [2, 1])
def test_remote_sglang_staleness_control(sglang_server, bs, ofp, n_samples):
    from arealite.engine.sglang_remote import RemoteSGLangEngine
    from arealite.workflow.rlvr import RLVRWorkflow

    config = InferenceEngineConfig(
        experiment_name=EXPR_NAME,
        trial_name=TRIAL_NAME,
        consumer_batch_size=bs,
        max_head_offpolicyness=ofp,
    )
    config.server_addrs = [f"{HOST}:{PORT}"]
    engine = RemoteSGLangEngine(config)
    engine.initialize(None, None)

    gconfig = GenerationHyperparameters(
        max_new_tokens=16, greedy=False, n_samples=n_samples
    )
    tokenizer = load_hf_tokenizer(MODEL_PATH)

    workflow = RLVRWorkflow(
        reward_fn=lambda **kwargs: 1.0,  # Dummy reward function
        gconfig=gconfig,
        tokenizer=tokenizer,
    )
    data = {
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
    }
    for _ in range(bs * 2):
        engine.submit(data, workflow=workflow)

    # wait for some time
    time.sleep(15)
    assert engine.output_queue.qsize() == min(bs * 2, bs * (ofp + 1))

    # Update model version
    engine.set_version(1)
    print("Updated model version", flush=True)

    # submit again
    for _ in range(bs * 2):
        engine.submit(data, workflow=workflow)
    # wait for some time
    time.sleep(15)
    assert engine.output_queue.qsize() == min(bs * 4, bs * (ofp + 2))

    # exit
    engine.destroy()
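
To see why the two queue-size assertions above hold, here is a quick check of the expected values for one parameter setting (bs=2, ofp=1); this is plain arithmetic for illustration, not part of the test suite:

bs, ofp = 2, 1

# Before the version bump: 2*bs submissions, but the engine only lets samples
# through while expected_version = count // bs <= 0 + ofp.
first_round = min(bs * 2, bs * (ofp + 1))
assert first_round == 4

# After set_version(1): the budget grows by one more batch worth of samples.
second_round = min(bs * 4, bs * (ofp + 2))
assert second_round == 6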
@@ -0,0 +1,48 @@
from typing import Any, Dict, List

import torch
from tensordict import TensorDict


def concat_padded_tensors(
    tensor_dicts: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
    """Concatenate and pad tensors from multiple padded tensor dictionaries."""
    if not tensor_dicts:
        return TensorDict()

    batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
    new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]

    # Find max sequence length across all dictionaries
    assert all("attention_mask" in td for td in tensor_dicts)
    max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
    result = {}
    # Process each key
    for key in tensor_dicts[0].keys():
        tensors_to_concat = []
        for tensor_dict in tensor_dicts:
            tensor = tensor_dict[key]
            # Skip 1D tensors like rewards
            if len(tensor.shape) == 1:
                tensors_to_concat.append(tensor)
                continue
            current_length = tensor.shape[1]
            if current_length < max_length:
                # Pad tensor to max_length
                pad_width = max_length - current_length
                if key == "attention_mask":
                    # Pad attention mask with 0s
                    padding = torch.zeros(
                        (tensor.shape[0], pad_width), dtype=tensor.dtype
                    )
                else:
                    # Pad feature tensors with pad_value
                    padding = torch.full(
                        (tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype
                    )
                tensor = torch.cat([tensor, padding], dim=1)
            tensors_to_concat.append(tensor)

        result[key] = torch.cat(tensors_to_concat, dim=0)
    return TensorDict(result, batch_size=new_batch_size)
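
A small usage sketch of the helper above (assuming the module lands at arealite.utils.padding, as the later import in rlvr.py suggests): two single-sequence TensorDicts of different lengths are padded to the longer one and stacked along the batch dimension.

import torch
from tensordict import TensorDict
from arealite.utils.padding import concat_padded_tensors

a = TensorDict(
    {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.ones(1, 3, dtype=torch.long),
        "rewards": torch.tensor([1.0]),
    },
    batch_size=[1],
)
b = TensorDict(
    {
        "input_ids": torch.tensor([[4, 5]]),
        "attention_mask": torch.ones(1, 2, dtype=torch.long),
        "rewards": torch.tensor([0.0]),
    },
    batch_size=[1],
)

merged = concat_padded_tensors([a, b])
assert merged.batch_size == torch.Size([2])
assert merged["input_ids"].shape == (2, 3)          # second row padded with 0
assert merged["attention_mask"][1].tolist() == [1, 1, 0]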
@@ -1,3 +1,4 @@
import asyncio
import uuid

import torch
@@ -7,6 +8,7 @@ from transformers import PreTrainedTokenizerFast
from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
from arealite.utils.padding import concat_padded_tensors


class RLVRWorkflow(RolloutWorkflow):
@@ -24,33 +26,38 @@ class RLVRWorkflow(RolloutWorkflow):
        text = self.tokenizer.apply_chat_template(
            data["messages"], tokenize=False, add_generation_prompt=True
        )
        n_samples = self.gconfig.n_samples
        req = LLMRequest(
            rid=uuid.uuid4().hex,
            text=text,
            gconfig=self.gconfig,
            gconfig=self.gconfig.new(n_samples=1),
        )
        resp = await engine.agenerate(req)
        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

        seq = resp.input_tokens + resp.output_tokens
        logprobs = [0] * resp.input_len + resp.output_logprobs
        prompt_mask = [1] * resp.input_len + [0] * resp.output_len
        versions = [-1] * resp.input_len + resp.output_versions
        results = []
        for resp in resps:
            seq = resp.input_tokens + resp.output_tokens
            logprobs = [0] * resp.input_len + resp.output_logprobs
            prompt_mask = [1] * resp.input_len + [0] * resp.output_len
            versions = [-1] * resp.input_len + resp.output_versions

        reward = self.reward_fn(
            prompt=req.text,
            completions=resp.completions,
            prompt_ids=resp.input_tokens,
            completion_ids=resp.output_tokens,
            **data,
        )
        res = dict(
            # unsqueeze to add an additional batch dimension
            input_ids=torch.tensor(seq).unsqueeze(0),
            prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
            logprobs=torch.tensor(logprobs).unsqueeze(0),
            versions=torch.tensor(versions).unsqueeze(0),
            # reward
            rewards=torch.tensor([reward]),
        )
            reward = self.reward_fn(
                prompt=req.text,
                completions=resp.completions,
                prompt_ids=resp.input_tokens,
                completion_ids=resp.output_tokens,
                **data,
            )
            res = dict(
                # unsqueeze to add an additional batch dimension
                input_ids=torch.tensor(seq).unsqueeze(0),
                prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
                logprobs=torch.tensor(logprobs).unsqueeze(0),
                versions=torch.tensor(versions).unsqueeze(0),
                attention_mask=torch.ones(len(seq)).unsqueeze(0),
                # reward
                rewards=torch.tensor([reward]),
            )
            results.append(TensorDict(res, batch_size=[1]))

        return TensorDict(res, batch_size=[1])
        return concat_padded_tensors(results)
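
The group-sampling change above replaces one n_samples request with n_samples independent single-sample calls gathered concurrently. A minimal standalone sketch of that fan-out pattern, using an illustrative stand-in for engine.agenerate:

import asyncio

async def agenerate(req_text: str) -> str:
    # Stand-in for RemoteSGLangEngine.agenerate: one request, one sample.
    await asyncio.sleep(0.01)
    return f"completion of: {req_text}"

async def sample_group(req_text: str, n_samples: int) -> list[str]:
    # Fire n_samples identical single-sample requests concurrently and
    # collect the responses in submission order.
    return await asyncio.gather(*[agenerate(req_text) for _ in range(n_samples)])

resps = asyncio.run(sample_group("hello", n_samples=4))
assert len(resps) == 4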
@@ -53,6 +53,9 @@ dependencies = [
    "hydra-core==1.4.0.dev1",
    "packaging",
    "tabulate",
    "torchdata",
    "gymnasium",
    "tensordict",

    # Monitoring and logging
    "wandb",
@@ -69,4 +69,7 @@ word2number
Pebble
timeout-decorator
prettytable
swanlab[dashboard]
swanlab[dashboard]
torchdata
gymnasium
tensordict