[lite] [fix] Fix a performance issue and several minor issues before release (#203)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing
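
A sketch of what the new flag does on the engine's HuggingFace model (the call below is the standard transformers API; wiring it to `TrainEngineConfig.gradient_checkpointing` is the gist of the PR, exact integration details assumed):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
# Gradient checkpointing trades compute for memory: activations are
# recomputed during backward instead of being cached during forward.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```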

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix destroy process group
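
The process-group fix above is likely the usual teardown guard; a minimal sketch of the idiom (the surrounding engine code is not shown here):

```python
import torch.distributed as dist

def destroy():
    # Only destroy if init_process_group actually ran, and do it once.
    if dist.is_initialized():
        dist.destroy_process_group()
```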

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix loss mask
* fix

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher

Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/370

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix

* PullRequest: 392 [lite] Fix several bugs regarding RL training and add an example to reproduce boba-math results.

Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/392

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support fsdp engine and sglang remote engine
* minor fix
* refactor trainer
* add close
* rm mb_spec
* qwen2 grpo works
* async works
* slurm launcher not tested
* fix arg parse
* sglang server wrapper
* slurm run
* ready for boba
* debug
* 32k run
* refactor train engine
* fix update weight error
* match train
* format
* seems to work

* PullRequest: 408 [Feature] Bump SGLang version to v0.4.9.post2

Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/408

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* bump arealite to sglang 0.4.9.post2
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`

* PullRequest: 422 [lite] Fix tests and scripts after updating sgl to 0.4.9

Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/422

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* bump arealite to sglang 0.4.9.post2
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`

* PullRequest: 423 [lite] Remove the boba example for the GitHub release.

Merge branch fw/remove-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/423

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* update readme

* PullRequest: 431 [Fix] Fix environment of lite

Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/431

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* change requirements

* PullRequest: 440 [FIX] fix update weight from disk

Merge branch sxj/lite-fix-disk-update of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/440

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* [FIX] fix update weight from disk

* PullRequest: 442 [lite] Refactor `RemoteSGLangEngine` into two parts: `RemoteSGLangEngine` and `WorkflowExecutor`.

Merge branch mzy/workflow-executor of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/442

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* refactor workflow executor
* fix tests and eval
* revert workflow executor into remote sglang engine

* PullRequest: 456 [lite] [Bug] Use `ProcessPoolExecutor` to calculate rewards to avoid slowing down rollout

Merge branch mzy/lite/fix-reward of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/456?tab=comment

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix reward

* PullRequest: 460 [lite][fix] add a warning when reward computation times out

Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/460

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* add a warning when reward computation times out

* PullRequest: 465 [lite][fix] Fix issues raised by tsao

Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/465

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
Wei Fu 2025-07-31 19:29:55 +08:00 committed by GitHub
parent 7fb6a80e48
commit e507ce281c
19 changed files with 384 additions and 310 deletions

View File

@@ -195,7 +195,10 @@ class TrainEngineConfig:
gradient_checkpointing: bool = field(
default=True, metadata={"help": "Enable gradient checkpointing"}
)
dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
dtype: str = field(default="bfloat16", metadata={"help": "Parameter dtype."})
grad_reduce_dtype: str = field(
default="float32", metadata={"help": "Gradient reduce dtype."}
)
optimizer: Optional[OptimizerConfig] = field(
default=None, metadata={"help": "Optimizer configuration"}
)
@@ -262,7 +265,7 @@ class PPOActorConfig(TrainEngineConfig):
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: bool = field(
default=True, metadata={"help": "Enable advantage normalization"}
default=True, metadata={"help": "Enable advantage normalization globally"}
)
# KL Control
@@ -275,7 +278,9 @@ class PPOActorConfig(TrainEngineConfig):
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
metadata={
"help": "Use the decoupled loss. Implicitly enable recompute_logprob."
},
)
behav_imp_weight_cap: Optional[float] = field(
default=None,
@@ -329,7 +334,7 @@ class SGLangConfig:
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
dtype: str = "float16"
dtype: str = "bfloat16"
kv_cache_dtype: str = "auto"
# logging
log_level: str = "warning"
@@ -407,31 +412,8 @@ class SGLangConfig:
dist_init_addr=dist_init_addr,
**args,
)
sglang_version = pkg_version.get_version("sglang")
if sglang_version:
version_less_than_0_4_4 = (
pkg_version.compare_versions(sglang_version, "0.4.4") < 0
)
version_less_than_0_4_3 = (
pkg_version.compare_versions(sglang_version, "0.4.3") < 0
)
elif pkg_version.is_available("sglang"):
version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
else:
raise ValueError(
"A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
)
if version_less_than_0_4_4:
args.pop("log_requests_level")
if version_less_than_0_4_3:
args.pop("enable_nccl_nvls")
args.pop("triton_attention_num_kv_splits")
args.pop("cuda_graph_bs")
args.pop("enable_memory_saver")
args.pop("allow_auto_truncate")
args.pop("file_storage_path")
if not pkg_version.is_version_greater_or_equal("sglang", "0.4.9.post2"):
raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.")
return args
@@ -466,7 +448,7 @@ class InferenceEngineConfig:
default="round_robin",
metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
)
setup_timeout: float = field(default=90.0)
setup_timeout: float = field(default=120.0)
request_timeout: float = field(
default=3600, metadata={"help": "Timeout for HTTP requests."}
)

View File

@@ -174,6 +174,7 @@ class InferenceEngine(abc.ABC):
self,
dataloader: StatefulDataLoader,
workflow: "RolloutWorkflow",
should_accept: Callable | None = None,
):
"""Asynchronously submit and wait until a full batch is ready."""
raise NotImplementedError()
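
The new `should_accept` hook lets `prepare_batch` filter finished trajectories before they count toward the batch. A hedged sketch of such a filter (the `rewards` key is illustrative, not a fixed schema):

```python
from tensordict import TensorDict

def should_accept(traj: TensorDict) -> bool:
    # Hypothetical filter: drop trajectories whose rewards are all identical,
    # since they produce zero advantage under group normalization.
    rewards = traj["rewards"]
    return bool((rewards.max() - rewards.min()).item() > 0)

# engine.prepare_batch(dataloader, workflow, should_accept=should_accept)
```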

View File

@@ -1,10 +1,30 @@
from typing import TYPE_CHECKING, Any, Dict
import asyncio
import itertools
import queue
import threading
import time
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, List
import torch.distributed as dist
import uvloop
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import InferenceEngineConfig
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import RolloutStat
from arealite.utils.data import concat_padded_tensors
from realhf.base import logging
if TYPE_CHECKING:
from arealite.api.engine_api import InferenceEngine
logger = logging.getLogger("arealite.workflow_api")
ROLLOUT_POLL_WAIT_TIME = 0.05
class RolloutWorkflow:
@@ -16,3 +36,226 @@ class RolloutWorkflow:
See concrete example implementations under the `arealite/workflow` directory.
"""
raise NotImplementedError()
class WorkflowExecutor:
def __init__(
self,
config: InferenceEngineConfig,
inference_engine: "InferenceEngine",
):
config.max_concurrent_rollouts = (
config.max_concurrent_rollouts or config.consumer_batch_size
)
self.config = config
self.exiting = threading.Event()
self.paused = threading.Event()
self.lock = threading.Lock()
self.inference_engine = inference_engine
qsize = config.queue_size or config.max_concurrent_rollouts * 16
self.input_queue = queue.Queue(maxsize=qsize)
self.output_queue = queue.Queue(maxsize=qsize)
self.result_cache: List[TensorDict] = []
self.rollout_stat = RolloutStat()
def initialize(self):
self.rollout_tasks: Dict[str, asyncio.Task] = {}
self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start()
def destroy(self):
self.exiting.set()
self.rollout_thread.join()
def get_capacity(self):
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
with self.lock:
max_concurrent_rollouts = max(
1, self.config.max_concurrent_rollouts // world_size
)
capacity = max_concurrent_rollouts - len(self.rollout_tasks)
# Staleness control
version = self.inference_engine.get_version()
ofp = self.config.max_head_offpolicyness
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
consumer_bs = max(1, self.config.consumer_batch_size // world_size)
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
return capacity
def _rollout_thread(self):
"""Thread that runs the rollout loop."""
try:
uvloop.run(self._rollout_thread_async())
except Exception:
traceback.print_exc()
async def _rollout_thread_async(self):
rollout_tasks = self.rollout_tasks
rid = 0
try:
while not self.exiting.is_set():
# Check capacity
capacity = self.get_capacity()
# Create new rollout task
self.lock.acquire()
while (
capacity > 0
and not self.paused.is_set()
and self.input_queue.qsize() > 0
):
data, workflow = self.input_queue.get_nowait()
logger.debug(f"Get data from puller: {data}")
task = asyncio.create_task(
workflow.arun_episode(self.inference_engine, data),
name=str(rid),
)
rollout_tasks[str(rid)] = task
self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
if self.config.enable_rollout_tracing:
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
capacity -= 1
rid += 1
tasks = list(rollout_tasks.values())
self.lock.release()
# Wait for rollout completion
done = []
if tasks:
done, _ = await asyncio.wait(
tasks,
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
# Collect done results
for task in done:
traj = await task
traj: TensorDict
task_rid = task.get_name()
with self.lock:
rollout_tasks.pop(task_rid)
self.rollout_stat.accepted += 1
self.rollout_stat.running -= 1
if self.config.enable_rollout_tracing:
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
try:
self.output_queue.put_nowait(traj)
except queue.Full:
raise RuntimeError(
"Output queue full. Please increase queue_size."
)
await asyncio.sleep(1)
except Exception:
traceback.print_exc()
finally:
# Cancel remaining tasks
with self.lock:
for task in rollout_tasks.values():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
try:
self.input_queue.put_nowait((data, workflow))
except queue.Full:
raise RuntimeError("Input queue full. Please increase queue_size.")
def wait(
self,
count: int,
timeout: float | None = None,
should_accept: Callable | None = None,
) -> TensorDict:
tik = time.perf_counter()
accepted = len(self.result_cache)
timeout = timeout or float(7 * 24 * 3600)
while (
accepted < count
and not self.exiting.is_set()
and time.perf_counter() - tik < timeout
):
try:
result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
if should_accept is None or should_accept(result):
self.result_cache.append(result)
accepted += 1
else:
with self.lock:
self.rollout_stat.accepted -= 1
except queue.Empty:
pass
if self.exiting.is_set():
raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
if accepted < count:
raise TimeoutError(
f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
)
results, self.result_cache = (
self.result_cache[:count],
self.result_cache[count:],
)
return concat_padded_tensors(results)
def rollout_batch(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""
for item in data:
self.submit(item, workflow)
return self.wait(count=len(data))
def prepare_batch(
self,
dataloader: StatefulDataLoader,
workflow: "RolloutWorkflow",
should_accept: Callable | None = None,
):
if not hasattr(self, "data_generator"):
self.data_generator = itertools.cycle(dataloader)
assert dataloader.batch_size is not None
while True:
# Submit at least two batches to allow maximum overlap
if (
self.get_capacity() + dataloader.batch_size > 0
and self.input_queue.qsize() + dataloader.batch_size
< self.input_queue.maxsize
):
data = next(self.data_generator)
for item in data:
self.submit(item, workflow=workflow)
try:
return self.wait(
dataloader.batch_size, timeout=1, should_accept=should_accept
)
except TimeoutError:
pass
def pause(self):
self.paused.set()
def resume(self):
self.paused.clear()
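
Taken together, `get_capacity` implements the staleness control: with `max_head_offpolicyness=0`, `consumer_batch_size=32`, and weight `version=0`, at most `(0 + 0 + 1) * 32 = 32` samples may be running or accepted until the trainer bumps the version. A hedged usage sketch of the executor on its own (normally `RemoteSGLangEngine` constructs and drives it internally):

```python
# Minimal sketch, assuming `engine` is an initialized RemoteSGLangEngine,
# `workflow` a RolloutWorkflow, and `cfg` an InferenceEngineConfig.
executor = WorkflowExecutor(config=cfg, inference_engine=engine)
executor.initialize()
try:
    items = [{"messages": [{"role": "user", "content": "1+1=?"}]}]
    for item in items:  # each item: Dict[str, Any]
        executor.submit(item, workflow)
    batch = executor.wait(count=len(items), timeout=3600)
finally:
    executor.destroy()  # sets `exiting` and joins the rollout thread
```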

View File

@@ -55,7 +55,7 @@ class FSDPEngine(BaseHFEngine):
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
param_dtype=getattr(torch, self.config.dtype),
reduce_dtype=torch.float32,
reduce_dtype=getattr(torch, self.config.grad_reduce_dtype),
cast_forward_inputs=True,
)
self.device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
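
Keeping the gradient all-reduce in float32 while parameters stay in bfloat16 avoids compounding rounding error across data-parallel ranks. A minimal sketch of how the two config fields map onto the policy (import path per torch >= 2.6 FSDP2; the literal values are the new defaults):

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

dtype, grad_reduce_dtype = "bfloat16", "float32"  # new TrainEngineConfig defaults
policy = MixedPrecisionPolicy(
    param_dtype=getattr(torch, dtype),               # forward/backward in bf16
    reduce_dtype=getattr(torch, grad_reduce_dtype),  # all-reduce grads in fp32
    cast_forward_inputs=True,
)
```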

View File

@@ -128,7 +128,7 @@ class PPOActor:
advantages = torch.stack(advantages_reversed[::-1], dim=1)
# Optionally perform advantage normalization.
if self.adv_norm:
if self.adv_norm or self.group_adv_norm:
if self.group_adv_norm:
adv_list = []
for i in range(0, bs, self.group_size):

View File

@@ -2,17 +2,13 @@ import asyncio
import os
import random
import shutil
import threading
import time
import traceback
from concurrent.futures import Future, ProcessPoolExecutor
from datetime import datetime
from queue import Empty, Full, Queue
from typing import TYPE_CHECKING, Any, Callable, Dict, List
from typing import Any, Callable, Dict, List
import aiohttp
import requests
import torch.distributed as dist
import uvloop
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
@@ -28,25 +24,19 @@ from arealite.api.io_struct import (
VLMResponse,
WeightUpdateMeta,
)
from arealite.utils.data import concat_padded_tensors
from arealite.api.workflow_api import RolloutWorkflow, WorkflowExecutor
from arealite.utils.http import arequest_with_retry, get_default_connector
from realhf.base import logging, name_resolve, names
if TYPE_CHECKING:
from arealite.api.workflow_api import RolloutWorkflow
logger = logging.getLogger(__name__)
ROLLOUT_POLL_WAIT_TIME = 0.05
RID_CACHE_SIZE = 128
class RemoteSGLangEngine(InferenceEngine):
def __init__(self, config: InferenceEngineConfig):
config.max_concurrent_rollouts = (
config.max_concurrent_rollouts or config.consumer_batch_size
)
self.config = config
self.rid_to_address = {}
@@ -54,29 +44,20 @@ class RemoteSGLangEngine(InferenceEngine):
self.rid_queue = []
self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
self.session = None
if not self.addresses:
raise RuntimeError("No configured SGLang servers.")
logger.info("Waiting for server ready...")
for addr in self.addresses:
self._wait_for_server(addr)
logger.info("Servers are all ready!")
self.server_idx = random.randint(0, len(self.addresses) - 1)
qsize = config.queue_size or config.max_concurrent_rollouts * 16
self.input_queue = Queue(maxsize=qsize)
self.output_queue = Queue(maxsize=qsize)
self.result_cache = []
self.exiting = threading.Event()
self.paused = threading.Event()
self.lock = threading.Lock()
self.rollout_stat = RolloutStat()
self.distributed_weight_update_initialized = False
self._version = 0
self.workflow_executor = WorkflowExecutor(
config=config,
inference_engine=self,
)
def _wait_for_server(self, address):
base_url = f"http://{address}"
tik = time.time()
@@ -94,137 +75,46 @@ class RemoteSGLangEngine(InferenceEngine):
except requests.exceptions.RequestException as e:
return False
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
self.rollout_tasks: Dict[str, asyncio.Task] = {}
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None = None):
logger.info("Waiting for server ready...")
for addr_ in self.addresses:
self._wait_for_server(addr_)
logger.info("Servers are all ready!")
self.executor = ProcessPoolExecutor(max_workers=1)
self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start()
self.workflow_executor.initialize()
def destroy(self):
self.executor.shutdown()
self.exiting.set()
self.rollout_thread.join()
def set_version(self, version):
with self.lock:
self._version = version
self._version = version
def get_version(self):
with self.lock:
return self._version
def _rollout_thread(self):
"""Thread that runs the rollout loop."""
try:
uvloop.run(self._rollout_thread_async())
except Exception:
traceback.print_exc()
async def _rollout_thread_async(self):
rollout_tasks = self.rollout_tasks
rid = 0
# NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
self.session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.config.request_timeout,
sock_connect=self.config.request_timeout,
connect=self.config.request_timeout,
),
read_bufsize=1024 * 1024 * 10,
connector=get_default_connector(),
)
try:
while not self.exiting.is_set():
# Check capacity
capacity = self.get_capacity()
# Create new rollout task
while (
capacity > 0
and not self.paused.is_set()
and self.input_queue.qsize() > 0
):
data, workflow = self.input_queue.get_nowait()
logger.debug(f"Get data from puller: {data}")
task = asyncio.create_task(
workflow.arun_episode(self, data), name=str(rid)
)
with self.lock:
rollout_tasks[str(rid)] = task
self.rollout_stat.submitted += 1
self.rollout_stat.running += 1
if self.config.enable_rollout_tracing:
logger.info(
f"Submit rollout rid {rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
capacity -= 1
rid += 1
# Wait for rollout completion
with self.lock:
tasks = list(rollout_tasks.values())
done = []
if tasks:
done, _ = await asyncio.wait(
tasks,
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
# Collect done results
for task in done:
traj = await task
traj: TensorDict
task_rid = task.get_name()
with self.lock:
rollout_tasks.pop(task_rid)
self.rollout_stat.accepted += 1
try:
self.output_queue.put_nowait(traj)
except Full:
raise RuntimeError(
"Output queue full. Please increase queue_size."
)
with self.lock:
self.rollout_stat.running -= 1
if self.config.enable_rollout_tracing:
logger.info(
f"Finish rollout {task_rid}. "
f"Submit: {self.rollout_stat.submitted}, "
f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}."
)
await asyncio.sleep(1)
except Exception:
traceback.print_exc()
finally:
# Cancel remaining tasks
with self.lock:
for task in rollout_tasks.values():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
return self._version
def choose_server(self) -> str:
with self.lock:
if self.config.schedule_policy == "round_robin":
server = self.addresses[self.server_idx]
self.server_idx = (self.server_idx + 1) % len(self.addresses)
return server
if self.config.schedule_policy == "round_robin":
server = self.addresses[self.server_idx]
self.server_idx = (self.server_idx + 1) % len(self.addresses)
return server
raise NotImplementedError("Only round-robin scheduling is implemented.")
async def agenerate(
self, req: LLMRequest | VLMRequest
) -> LLMResponse | VLMResponse:
"""Async version of generate using aiohttp."""
if self.session is None:
# NOTE: Lazily initialize aiohttp.ClientSession since it needs to be initialized
# inside asyncio loop in WorkflowExecutor
self.session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.config.request_timeout,
sock_connect=self.config.request_timeout,
connect=self.config.request_timeout,
),
read_bufsize=1024 * 1024 * 10,
connector=get_default_connector(),
)
# Prepare request payload
gconfig = req.gconfig
stop_token_ids = gconfig.stop_token_ids
@@ -396,30 +286,8 @@ class RemoteSGLangEngine(InferenceEngine):
fut.add_done_callback(callback)
return fut
def get_capacity(self):
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
max_concurrent_rollouts = max(
1, self.config.max_concurrent_rollouts // world_size
)
capacity = max_concurrent_rollouts - len(self.rollout_tasks)
# Staleness control
version = self.get_version()
ofp = self.config.max_head_offpolicyness
with self.lock:
sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
consumer_bs = max(1, self.config.consumer_batch_size // world_size)
capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
return capacity
def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
try:
self.input_queue.put_nowait((data, workflow))
except Full:
raise RuntimeError("Input queue full. Please increase queue_size.")
def submit(self, data: Dict[str, Any], workflow: RolloutWorkflow) -> None:
return self.workflow_executor.submit(data, workflow)
def wait(
self,
@@ -427,77 +295,31 @@ class RemoteSGLangEngine(InferenceEngine):
timeout: float | None = None,
should_accept: Callable | None = None,
) -> TensorDict:
tik = time.perf_counter()
accepted = len(self.result_cache)
timeout = timeout or float(7 * 24 * 3600)
while (
accepted < count
and not self.exiting.is_set()
and time.perf_counter() - tik < timeout
):
try:
result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
if should_accept is None or should_accept(result):
self.result_cache.append(result)
accepted += 1
else:
with self.lock:
self.rollout_stat.accepted -= 1
except Empty:
pass
if self.exiting.is_set():
raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
if accepted < count:
raise TimeoutError(
f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
)
results, self.result_cache = (
self.result_cache[:count],
self.result_cache[count:],
return self.workflow_executor.wait(
count,
timeout=timeout,
should_accept=should_accept,
)
return concat_padded_tensors(results)
def rollout_batch(
self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
) -> TensorDict:
"""Submit a batch of requests to the inference engine and wait for the results."""
for item in data:
self.submit(item, workflow)
return self.wait(count=len(data))
return self.workflow_executor.rollout_batch(data, workflow)
def prepare_batch(
self,
dataloader: StatefulDataLoader,
workflow: "RolloutWorkflow",
workflow: RolloutWorkflow,
):
if not hasattr(self, "data_generator"):
self.data_generator = iter(dataloader)
assert dataloader.batch_size is not None
while True:
# Submit at least two batches to allow maximum overlap
if (
self.get_capacity() + dataloader.batch_size > 0
and self.input_queue.qsize() + dataloader.batch_size
< self.input_queue.maxsize
):
try:
data = next(self.data_generator)
except StopIteration:
self.data_generator = iter(dataloader)
data = next(self.data_generator)
for item in data:
self.submit(item, workflow=workflow)
try:
return self.wait(dataloader.batch_size, timeout=1)
except TimeoutError:
pass
return self.workflow_executor.prepare_batch(dataloader, workflow)
def pause(self):
self.paused.set()
"""Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
return self.workflow_executor.pause()
def resume(self):
self.paused.clear()
"""Resume request submission for async rollout."""
return self.workflow_executor.resume()
def update_weights_from_disk(

View File

@@ -94,6 +94,7 @@ class SGLangEngine(InferenceEngine):
asyncio.run(self._rollout_thread_async())
except Exception as e:
traceback.print_exc()
raise e
async def _rollout_thread_async(self):
data = None

View File

@@ -324,7 +324,6 @@ def ray_main():
)
allocation_mode = config.allocation_mode
allocation_mode = AllocationMode.from_str(allocation_mode)
sglang_cmds = []
sglang_addrs = []
n_sglang_nodes = 0
if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:

View File

@@ -61,6 +61,7 @@ async def arequest_with_retry(
ctx = _session.delete(url, timeout=timeo)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
async with ctx as response:
if verbose:
logger.info("http requests return")

View File

@@ -31,7 +31,7 @@ class MultiTurnWorkflow(RolloutWorkflow):
messages = data["messages"]
# Run multi-turn rollout until correct
t = reward = 0
discount = 0
discount = 1
rid = uuid.uuid4().hex
while reward == 0 and t < self.max_turns:
# Amend a prompt if the previous answer is incorrect
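
The one-character change matters because the discount is multiplicative across turns: seeded with `1` it decays as `turn_discount ** t`, while the old `0` seed zeroed every reward regardless of when the model answered correctly. A hedged sketch of the corrected arithmetic (the grading step is a stand-in):

```python
turn_discount, max_turns = 0.9, 5
discount, reward = 1.0, 0.0
for t in range(max_turns):
    reward = 1.0 if t == 2 else 0.0  # pretend the third attempt is correct
    if reward > 0:
        break
    discount *= turn_discount        # 1.0 -> 0.9 -> 0.81 -> ...
print(reward * discount)             # 0.81 with the fix; always 0.0 before it
```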

View File

@@ -1,6 +1,8 @@
import asyncio
import functools
import os
import uuid
from concurrent.futures import ProcessPoolExecutor
import colorama
import torch
@@ -12,6 +14,11 @@ from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
from arealite.utils.data import concat_padded_tensors
from realhf.base import logging
logger = logging.getLogger("RLVR workflow")
REWARD_TIMEOUT_SECONDS = 15
class RLVRWorkflow(RolloutWorkflow):
@@ -28,6 +35,7 @@ class RLVRWorkflow(RolloutWorkflow):
self.tokenizer = tokenizer
self.enable_thinking = enable_thinking
self.dump_dir = dump_dir
self.rw_executor = ProcessPoolExecutor(max_workers=4)
if self.dump_dir is not None and not os.path.exists(self.dump_dir):
os.makedirs(self.dump_dir, exist_ok=True)
@@ -54,6 +62,7 @@ class RLVRWorkflow(RolloutWorkflow):
seqlens = []
results = []
loop = asyncio.get_event_loop()
for resp in resps:
seq = resp.input_tokens + resp.output_tokens
logprobs = [0.0] * resp.input_len + resp.output_logprobs
@@ -65,13 +74,26 @@ class RLVRWorkflow(RolloutWorkflow):
prompt_strs.append(prompt_str)
completions_strs.append(completions_str)
seqlens.append(len(seq))
reward = self.reward_fn(
prompt=prompt_str,
completions=completions_str,
prompt_ids=resp.input_tokens,
completion_ids=resp.output_tokens,
**data,
)
try:
reward = await asyncio.wait_for(
loop.run_in_executor(
self.rw_executor,
functools.partial(
self.reward_fn,
prompt_str,
completions_str,
resp.input_tokens,
resp.output_tokens,
**data,
),
),
timeout=REWARD_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.warning(
f"Computing reward timeout after {REWARD_TIMEOUT_SECONDS}s. Set reward to 0."
)
reward = 0
rewards.append(reward)
res = dict(
# unsqueeze to add an additional batch dimension
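
Math reward checkers are CPU-bound and occasionally hang on pathological inputs; moving them into a `ProcessPoolExecutor` keeps the asyncio rollout loop responsive, and `asyncio.wait_for` caps the damage of a stuck call. A self-contained sketch of the same pattern (the reward function is a toy stand-in):

```python
import asyncio
import functools
from concurrent.futures import ProcessPoolExecutor

REWARD_TIMEOUT_SECONDS = 15  # mirrors the constant in the diff above

def slow_reward(prompt: str, completion: str) -> float:
    # Stand-in for a CPU-bound checker such as a latex2sympy-based verifier.
    return float(completion.strip().endswith("42"))

async def score(executor: ProcessPoolExecutor, prompt: str, completion: str) -> float:
    loop = asyncio.get_running_loop()
    try:
        return await asyncio.wait_for(
            loop.run_in_executor(
                executor, functools.partial(slow_reward, prompt, completion)
            ),
            timeout=REWARD_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        return 0.0  # a stuck checker costs at most the timeout, not the rollout

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=4) as ex:
        print(asyncio.run(score(ex, "What is 6*7?", "The answer is 42")))
```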

View File

@@ -233,7 +233,7 @@ def prepare_batch(
workflow: "RolloutWorkflow",
):
if not hasattr(self, "data_generator"):
self.data_generator = iter(dataloader)
self.data_generator = itertools.cycle(dataloader)
assert dataloader.batch_size is not None
while True:
# Submit at least two batches to allow maximum overlap
@@ -242,11 +242,7 @@ def prepare_batch(
and self.input_queue.qsize() + dataloader.batch_size
< self.input_queue.maxsize
):
try:
data = next(self.data_generator)
except StopIteration:
self.data_generator = iter(dataloader)
data = next(self.data_generator)
data = next(self.data_generator)
for item in data:
# submit data into input_queue
self.submit(item, workflow=workflow)
@@ -264,18 +260,13 @@ rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize()
eval_rollout = ...
data_generator = iter(train_dataloader)
data_generator = itertools.cycle(train_dataloader)
for global_step in range(max_steps):
# rollout batched training data for current step
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
```
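
One behavioral note on this simplification: `itertools.cycle` caches the items from the first pass and replays them verbatim, so a dataloader with `shuffle=True` repeats the same batch order every epoch (and the cache holds a reference to every batch). If per-epoch reshuffling matters, a tiny generator keeps the old semantics without the `StopIteration` boilerplate (a hedged alternative, not what the code above uses):

```python
def cycle_dataloader(dataloader):
    # Re-iterates the dataloader each epoch, unlike itertools.cycle,
    # so shuffle=True yields a fresh order on every pass.
    while True:
        yield from dataloader

data_generator = cycle_dataloader(train_dataloader)
```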
If you want to use rollout workflows with custom reward functions or agentic tool
@@ -375,17 +366,12 @@ Now a complete GRPO training step in AReaLite is done! The core logic of our exa
training script can be summarized as:
```python
data_generator = iter(train_dataloader)
data_generator = itertools.cycle(train_dataloader)
for global_step in range(max_steps):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
logp = actor.compute_logp(batch)
batch["prox_logp"] = logp

View File

@@ -75,7 +75,7 @@ and converting it into an `LLMRequest` object for the inference engine:
class MultiTurnWorkflow(RolloutWorkflow):
# ... __init__ method above ...
async def arun_episode(self, engine: InferenceEngine, data):
async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
# Initialize result containers
seq, logprobs, loss_mask, versions = [], [], [], []
messages = data["messages"]
@@ -119,7 +119,7 @@ we'll apply a discount, add feedback to the conversation, and let the model try
class MultiTurnWorkflow(RolloutWorkflow):
# ... previous methods ...
async def arun_episode(self, engine: InferenceEngine, data):
async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
# ... initialization code ...
while reward == 0 and t < self.max_turns:
# Add feedback if the previous answer was incorrect
@@ -190,7 +190,7 @@ Finally, let's complete the implementation by collecting trajectories in the
class MultiTurnWorkflow(RolloutWorkflow):
# ... previous methods ...
async def arun_episode(self, engine: InferenceEngine, data):
async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
# ... episode logic above ...
while reward == 0 and t < self.max_turns:
@@ -417,8 +417,27 @@ Using your custom workflow is straightforward—just create it in your training
pass it to the `rollout_batch` or `prepare_batch` method:
```python
# in realhf/impl/agent/__init__.py
import realhf.impl.agent.math_multi_turn_agent
def main(args):
# ... setup code ...
# Create your custom workflow
workflow = MultiTurnWorkflow(
reward_fn=gsm8k_reward_fn,
gconfig=config.gconfig,
tokenizer=tokenizer,
turn_discount=0.9,
max_turns=5,
)
# Run training—no other changes needed!
data_generator = itertools.cycle(train_dataloader)
for global_step in range(max_steps):
with stats_tracker.record_timing("rollout"):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
# ... continue with training loop ...
```
Then update your experiment configuration in

View File

@@ -171,16 +171,11 @@ def main(args):
)
# Main training loop
data_generator = itertools.cycle(dataloader)
for global_step in range(max_steps):
# Generate training data
with stats_tracker.record_timing("rollout"):
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
batch = batch.to(actor.device)

View File

@@ -68,6 +68,7 @@ ref:
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
disable_dropout: true
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 10240
@@ -81,9 +82,9 @@ sglang:
random_seed: ${seed}
skip_tokenizer_init: true
dtype: ${actor.dtype}
max_running_requests: null
max_running_requests: null
context_length: 32768
mem_fraction_static: 0.9
mem_fraction_static: 0.8
# datasets
train_dataset:
@@ -94,7 +95,7 @@ train_dataset:
path: openai/gsm8k
type: rl
valid_dataset:
valid_dataset:
batch_size: 256
shuffle: true
pin_memory: true
@@ -131,4 +132,5 @@ stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled

View File

@@ -1,3 +1,4 @@
import itertools
import os
import sys
@@ -129,7 +130,7 @@ def main(args):
max_steps = total_epochs * steps_per_epoch
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
data_generator = iter(train_dataloader)
data_generator = itertools.cycle(train_dataloader)
for global_step in range(max_steps):
epoch = global_step // steps_per_epoch
step = global_step % steps_per_epoch
@@ -138,12 +139,7 @@ def main(args):
if config.async_training:
batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
else:
try:
data = next(data_generator)
except StopIteration:
data_generator = iter(train_dataloader)
data = next(data_generator)
batch = rollout.rollout_batch(data, workflow=workflow)
batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
batch = batch.to(actor.device)
# Create barrier to synchronize all rollout processes.

View File

@@ -6,7 +6,7 @@ pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.
pip install "sglang[all]==0.4.9.post2"
pip install megatron-core==0.11.0 nvidia-ml-py
pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
pip install "flash-attn<=2.7.3" --no-build-isolation
pip install "flash-attn<=2.8.2" --no-build-isolation
# Package used for calculating math reward
pip install -e evaluation/latex2sympy

View File

@@ -56,6 +56,8 @@ dependencies = [
"torchdata",
"autoflake",
"tensordict",
"pybase64",
"msgspec",
# Monitoring and logging
"wandb",

View File

@@ -74,3 +74,6 @@ torchdata
autoflake
tensordict
deepspeed>=0.17.2
pybase64
msgspec
transformers==4.53.1