[lite] [fix] Fix a performance issue and several minor issues before release (#203)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing
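
For context, enabling gradient checkpointing on a HuggingFace-style model is a one-liner before FSDP wrapping; a minimal sketch (model name and surrounding code are illustrative, mirroring the `gradient_checkpointing` flag in `TrainEngineConfig` further down):

```python
# Hedged sketch, not the engine's exact code: trade compute for memory by
# recomputing activations during backward. Model name is illustrative.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-1.5B-Instruct", torch_dtype=torch.bfloat16
)
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```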

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .
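
For reference, a hypothetical GSM8k reward function with the call signature the RLVR workflow in this commit uses (`prompt`, `completions`, `prompt_ids`, `completion_ids`, plus dataset fields via `**data`); real answer parsing is stricter:

```python
# Hypothetical sketch of a GSM8k-style binary reward; simplified parsing.
import re

def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, **data):
    # GSM8k reference answers end with "#### <answer>".
    ref = str(data.get("answer", "")).split("####")[-1].strip().replace(",", "")
    nums = re.findall(r"-?\d+(?:\.\d+)?", completions.replace(",", ""))
    return 1.0 if nums and nums[-1] == ref else 0.0
```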

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher

Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/370

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* .
* .
* .
* fix

* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results.

Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/392

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* support fsdp engine and sglang remote engine
* minor fix
* .
* refactor trainer
* add close
* rm mb_spec
* .
* fix
* .
* qwen2 grpo works
* fix
* fix
* async works
* fix
* slurm launcher not tested
* fix arg parse
* .
* sglang server wrapper
* .
* .
* slurm run
* ready for boba
* debug
* 32k run
* .
* .
* fix
* .
* .
* .
* .
* .
* fix
* .
* fix
* .
* .
* .
* .
* fix
* .
* .
* .
* .
* .
* .
* .
* refactor train engine
* refactor train engine
* .
* fix update weight error
* .
* .
* match train
* format
* .
* fix
* seems to work
* .
* .
* .
* .

* .

* PullRequest: 408 [Feature] Bump SGLang version to v0.4.9.post2

Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/408

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* bump arealite to sglang 0.4.9.post2
* .
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`

* PullRequest: 422 [lite] Fix tests and scripts after updating sgl to 0.4.9

Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/422

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* bump arealite to sglang 0.4.9.post2
* .
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`
* .

* PullRequest: 423 [lite] Remove the boba example for the GitHub release.

Merge branch fw/remove-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/423

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* update readme

* PullRequest: 431 [Fix] Fix environment of lite

Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/431

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* change requirements
* .
* .
* .

* PullRequest: 440 [FIX] fix update weight from disk

Merge branch sxj/lite-fix-disk-update of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/440

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* [FIX] fix update weight from disk

* PullRequest: 442 [lite] Refactor `RemoteSGLangEngine` into two parts: `RemoteSGLangEngine` and `WorkflowExecutor`.

Merge branch mzy/workflow-executor of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/442

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* refactor workflow executor
* .
* fix tests and eval
* .
* .
* revert workflow executor into remote sglang engine
* .
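
The resulting call pattern, as an illustrative sketch (not verbatim repo code): the engine keeps HTTP generation and weight updates, and delegates queueing, staleness control, and the rollout thread to its `WorkflowExecutor` (see the diff below).

```python
# Illustrative usage after the split; config and data names are assumptions.
engine = RemoteSGLangEngine(config.rollout)
engine.initialize(addr=None)          # also starts engine.workflow_executor
for item in data:                     # dicts yielded by the dataloader
    engine.submit(item, workflow)     # delegates to workflow_executor.submit
batch = engine.wait(count=len(data))  # delegates to workflow_executor.wait
engine.destroy()
```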

* PullRequest: 456 [lite] [Bug] Use `ProcessPoolExecutor` to calculate reward to avoid slowing down rollout

Merge branch mzy/lite/fix-reward of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/456?tab=comment

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix reward
* .
* .
* .
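
The underlying idea, as a self-contained sketch (names are illustrative; the repo's version appears in the RLVR workflow diff below): offload the blocking reward call to worker processes so the asyncio rollout loop never stalls.

```python
# Hedged sketch: reward_fn must be a top-level (picklable) function.
import asyncio
import functools
from concurrent.futures import ProcessPoolExecutor

def reward_fn(prompt: str, completion: str) -> float:
    return float("42" in completion)  # placeholder for expensive verification

async def score(pool, prompt, completion):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        pool, functools.partial(reward_fn, prompt, completion)
    )

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=4) as pool:
        print(asyncio.run(score(pool, "What is 6*7?", "The answer is 42")))
```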

* PullRequest: 460 [lite][fix] add a warning when reward computation times out

Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/460

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* add a warning when reward computation times out
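
Sketch of the guard this adds (the timeout constant matches the diff below; the wrapper name is illustrative): bound the reward computation with `asyncio.wait_for` and fall back to a zero reward on timeout.

```python
import asyncio
import logging

logger = logging.getLogger("rlvr")
REWARD_TIMEOUT_SECONDS = 15  # value used in the diff below

async def reward_with_timeout(reward_future):
    try:
        return await asyncio.wait_for(reward_future, timeout=REWARD_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        # Degrade gracefully instead of blocking the rollout loop forever.
        logger.warning(
            f"Computing reward timeout after {REWARD_TIMEOUT_SECONDS}s. Set reward to 0."
        )
        return 0
```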

* PullRequest: 465 [lite][fix] Fix issues raised by tsao

Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/465

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
commit e507ce281c (parent 7fb6a80e48)
Author: Wei Fu
Date: 2025-07-31 19:29:55 +08:00 (committed by GitHub)
19 changed files with 384 additions and 310 deletions

View File

@@ -195,7 +195,10 @@ class TrainEngineConfig:
     gradient_checkpointing: bool = field(
         default=True, metadata={"help": "Enable gradient checkpointing"}
     )
-    dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
+    dtype: str = field(default="bfloat16", metadata={"help": "Parameter dtype."})
+    grad_reduce_dtype: str = field(
+        default="float32", metadata={"help": "Gradient reduce dtype."}
+    )
     optimizer: Optional[OptimizerConfig] = field(
         default=None, metadata={"help": "Optimizer configuration"}
     )
@@ -262,7 +265,7 @@ class PPOActorConfig(TrainEngineConfig):
         default=1.0, metadata={"help": "Lambda parameter for GAE"}
     )
     adv_norm: bool = field(
-        default=True, metadata={"help": "Enable advantage normalization"}
+        default=True, metadata={"help": "Enable advantage normalization globally"}
     )
     # KL Control
@@ -275,7 +278,9 @@ class PPOActorConfig(TrainEngineConfig):
     )
     use_decoupled_loss: bool = field(
         default=False,
-        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
+        metadata={
+            "help": "Use the decoupled loss. Implicitly enable recompute_logprob."
+        },
     )
     behav_imp_weight_cap: Optional[float] = field(
         default=None,
@@ -329,7 +334,7 @@ class SGLangConfig:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    dtype: str = "float16"
+    dtype: str = "bfloat16"
     kv_cache_dtype: str = "auto"
     # logging
     log_level: str = "warning"
@@ -407,31 +412,8 @@ class SGLangConfig:
             dist_init_addr=dist_init_addr,
             **args,
         )
-        sglang_version = pkg_version.get_version("sglang")
-        if sglang_version:
-            version_less_than_0_4_4 = (
-                pkg_version.compare_versions(sglang_version, "0.4.4") < 0
-            )
-            version_less_than_0_4_3 = (
-                pkg_version.compare_versions(sglang_version, "0.4.3") < 0
-            )
-        elif pkg_version.is_available("sglang"):
-            version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
-            version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
-        else:
-            raise ValueError(
-                "A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
-            )
-        if version_less_than_0_4_4:
-            args.pop("log_requests_level")
-        if version_less_than_0_4_3:
-            args.pop("enable_nccl_nvls")
-            args.pop("triton_attention_num_kv_splits")
-            args.pop("cuda_graph_bs")
-            args.pop("enable_memory_saver")
-            args.pop("allow_auto_truncate")
-            args.pop("file_storage_path")
+        if not pkg_version.is_version_greater_or_equal("sglang", "0.4.9.post2"):
+            raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.")
         return args
@@ -466,7 +448,7 @@ class InferenceEngineConfig:
         default="round_robin",
         metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
     )
-    setup_timeout: float = field(default=90.0)
+    setup_timeout: float = field(default=120.0)
     request_timeout: float = field(
         default=3600, metadata={"help": "Timeout for HTTP requests."}
     )

View File

@@ -174,6 +174,7 @@ class InferenceEngine(abc.ABC):
         self,
         dataloader: StatefulDataLoader,
         workflow: "RolloutWorkflow",
+        should_accept: Callable | None = None,
     ):
         """Asynchronously submit and wait until a full batch is ready."""
         raise NotImplementedError()
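
The new `should_accept` hook lets callers filter finished trajectories before they count toward the batch. A hypothetical predicate (the `rewards` key is an assumption, not a documented field):

```python
# Hypothetical filter: drop groups whose rewards are all identical, since they
# produce zero advantage under GRPO.
def should_accept(traj) -> bool:
    rewards = traj["rewards"]
    return bool((rewards != rewards[0]).any())

batch = engine.prepare_batch(dataloader, workflow, should_accept=should_accept)
```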

View File

@@ -1,10 +1,30 @@
-from typing import TYPE_CHECKING, Any, Dict
+import asyncio
+import itertools
+import queue
+import threading
+import time
+import traceback
+from typing import TYPE_CHECKING, Any, Callable, Dict, List
 
+import torch.distributed as dist
+import uvloop
 from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader
 
+from arealite.api.cli_args import InferenceEngineConfig
+from arealite.api.engine_api import InferenceEngine
+from arealite.api.io_struct import RolloutStat
+from arealite.utils.data import concat_padded_tensors
+from realhf.base import logging
 
 if TYPE_CHECKING:
     from arealite.api.engine_api import InferenceEngine
 
+logger = logging.getLogger("arealite.workflow_api")
+
+ROLLOUT_POLL_WAIT_TIME = 0.05
+
 
 class RolloutWorkflow:
@@ -16,3 +36,226 @@ class RolloutWorkflow:
         See concrete example implementations under the `arealite/workflow` directory.
         """
         raise NotImplementedError()
+
+
+class WorkflowExecutor:
+
+    def __init__(
+        self,
+        config: InferenceEngineConfig,
+        inference_engine: "InferenceEngine",
+    ):
+        config.max_concurrent_rollouts = (
+            config.max_concurrent_rollouts or config.consumer_batch_size
+        )
+        self.config = config
+        self.exiting = threading.Event()
+        self.paused = threading.Event()
+        self.lock = threading.Lock()
+
+        self.inference_engine = inference_engine
+
+        qsize = config.queue_size or config.max_concurrent_rollouts * 16
+        self.input_queue = queue.Queue(maxsize=qsize)
+        self.output_queue = queue.Queue(maxsize=qsize)
+        self.result_cache: List[TensorDict] = []
+
+        self.rollout_stat = RolloutStat()
+
+    def initialize(self):
+        self.rollout_tasks: Dict[str, asyncio.Task] = {}
+        self.rollout_thread = threading.Thread(target=self._rollout_thread)
+        self.rollout_thread.start()
+
+    def destroy(self):
+        self.exiting.set()
+        self.rollout_thread.join()
+
+    def get_capacity(self):
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+        else:
+            world_size = 1
+        with self.lock:
+            max_concurrent_rollouts = max(
+                1, self.config.max_concurrent_rollouts // world_size
+            )
+            capacity = max_concurrent_rollouts - len(self.rollout_tasks)
+            # Staleness control
+            version = self.inference_engine.get_version()
+            ofp = self.config.max_head_offpolicyness
+            sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
+            consumer_bs = max(1, self.config.consumer_batch_size // world_size)
+            capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
+        return capacity
+
+    def _rollout_thread(self):
+        """Thread that runs the rollout loop."""
+        try:
+            uvloop.run(self._rollout_thread_async())
+        except Exception:
+            traceback.print_exc()
+
+    async def _rollout_thread_async(self):
+        rollout_tasks = self.rollout_tasks
+        rid = 0
+        try:
+            while not self.exiting.is_set():
+                # Check capacity
+                capacity = self.get_capacity()
+                # Create new rollout task
+                self.lock.acquire()
+                while (
+                    capacity > 0
+                    and not self.paused.is_set()
+                    and self.input_queue.qsize() > 0
+                ):
+                    data, workflow = self.input_queue.get_nowait()
+                    logger.debug(f"Get data from puller: {data}")
+                    task = asyncio.create_task(
+                        workflow.arun_episode(self.inference_engine, data),
+                        name=str(rid),
+                    )
+                    rollout_tasks[str(rid)] = task
+                    self.rollout_stat.submitted += 1
+                    self.rollout_stat.running += 1
+                    if self.config.enable_rollout_tracing:
+                        logger.info(
+                            f"Submit rollout rid {rid}. "
+                            f"Submit: {self.rollout_stat.submitted}, "
+                            f"running: {self.rollout_stat.running}, "
+                            f"accepted: {self.rollout_stat.accepted}."
+                        )
+                    capacity -= 1
+                    rid += 1
+                tasks = list(rollout_tasks.values())
+                self.lock.release()
+
+                # Wait for rollout completion
+                done = []
+                if tasks:
+                    done, _ = await asyncio.wait(
+                        tasks,
+                        timeout=ROLLOUT_POLL_WAIT_TIME,
+                        return_when=asyncio.FIRST_COMPLETED,
+                    )
+                # Collect done results
+                for task in done:
+                    traj = await task
+                    traj: TensorDict
+                    task_rid = task.get_name()
+                    with self.lock:
+                        rollout_tasks.pop(task_rid)
+                        self.rollout_stat.accepted += 1
+                        self.rollout_stat.running -= 1
+                        if self.config.enable_rollout_tracing:
+                            logger.info(
+                                f"Finish rollout {task_rid}. "
+                                f"Submit: {self.rollout_stat.submitted}, "
+                                f"running: {self.rollout_stat.running}, "
+                                f"accepted: {self.rollout_stat.accepted}."
+                            )
+                    try:
+                        self.output_queue.put_nowait(traj)
+                    except queue.Full:
+                        raise RuntimeError(
+                            "Output queue full. Please increase queue_size."
+                        )
+                await asyncio.sleep(1)
+        except Exception:
+            traceback.print_exc()
+        finally:
+            # Cancel remaining tasks
+            with self.lock:
+                for task in rollout_tasks.values():
+                    if not task.done():
+                        task.cancel()
+                        try:
+                            await task
+                        except asyncio.CancelledError:
+                            pass
+
+    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
+        try:
+            self.input_queue.put_nowait((data, workflow))
+        except queue.Full:
+            raise RuntimeError("Input queue full. Please increase queue_size.")
+
+    def wait(
+        self,
+        count: int,
+        timeout: float | None = None,
+        should_accept: Callable | None = None,
+    ) -> TensorDict:
+        tik = time.perf_counter()
+        accepted = len(self.result_cache)
+        timeout = timeout or float(7 * 24 * 3600)
+        while (
+            accepted < count
+            and not self.exiting.is_set()
+            and time.perf_counter() - tik < timeout
+        ):
+            try:
+                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
+                if should_accept is None or should_accept(result):
+                    self.result_cache.append(result)
+                    accepted += 1
+                else:
+                    with self.lock:
+                        self.rollout_stat.accepted -= 1
+            except queue.Empty:
+                pass
+        if self.exiting.is_set():
+            raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
+        if accepted < count:
+            raise TimeoutError(
+                f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
+            )
+        results, self.result_cache = (
+            self.result_cache[:count],
+            self.result_cache[count:],
+        )
+        return concat_padded_tensors(results)
+
+    def rollout_batch(
+        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
+    ) -> TensorDict:
+        """Submit a batch of requests to the inference engine and wait for the results."""
+        for item in data:
+            self.submit(item, workflow)
+        return self.wait(count=len(data))
+
+    def prepare_batch(
+        self,
+        dataloader: StatefulDataLoader,
+        workflow: "RolloutWorkflow",
+        should_accept: Callable | None = None,
+    ):
+        if not hasattr(self, "data_generator"):
+            self.data_generator = itertools.cycle(dataloader)
+        assert dataloader.batch_size is not None
+        while True:
+            # Submit at least two batches to allow maximum overlap
+            if (
+                self.get_capacity() + dataloader.batch_size > 0
+                and self.input_queue.qsize() + dataloader.batch_size
+                < self.input_queue.maxsize
+            ):
+                data = next(self.data_generator)
+                for item in data:
+                    self.submit(item, workflow=workflow)
+            try:
+                return self.wait(
+                    dataloader.batch_size, timeout=1, should_accept=should_accept
+                )
+            except TimeoutError:
+                pass
+
+    def pause(self):
+        self.paused.set()
+
+    def resume(self):
+        self.paused.clear()

View File

@@ -55,7 +55,7 @@ class FSDPEngine(BaseHFEngine):
         # Simple auto wrap policy
         self.mixed_precision_policy = MixedPrecisionPolicy(
             param_dtype=getattr(torch, self.config.dtype),
-            reduce_dtype=torch.float32,
+            reduce_dtype=getattr(torch, self.config.grad_reduce_dtype),
             cast_forward_inputs=True,
         )
         self.device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
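
Spelled out, the two config fields map onto FSDP2's mixed-precision policy roughly as follows (a sketch with this commit's new defaults: bf16 parameters/compute, fp32 gradient all-reduce for numerical stability):

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

# Sketch: values shown are the new defaults from the cli_args diff above.
policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # config.dtype ("bfloat16")
    reduce_dtype=torch.float32,   # config.grad_reduce_dtype ("float32")
    cast_forward_inputs=True,
)
```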

View File

@@ -128,7 +128,7 @@ class PPOActor:
         advantages = torch.stack(advantages_reversed[::-1], dim=1)

         # Optionally perform advantage normalization.
-        if self.adv_norm:
+        if self.adv_norm or self.group_adv_norm:
             if self.group_adv_norm:
                 adv_list = []
                 for i in range(0, bs, self.group_size):

View File

@@ -2,17 +2,13 @@ import asyncio
 import os
 import random
 import shutil
-import threading
 import time
-import traceback
 from concurrent.futures import Future, ProcessPoolExecutor
 from datetime import datetime
-from queue import Empty, Full, Queue
-from typing import TYPE_CHECKING, Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List

 import aiohttp
 import requests
-import torch.distributed as dist
 import uvloop
 from tensordict import TensorDict
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -28,25 +24,19 @@ from arealite.api.io_struct import (
     VLMResponse,
     WeightUpdateMeta,
 )
-from arealite.utils.data import concat_padded_tensors
+from arealite.api.workflow_api import RolloutWorkflow, WorkflowExecutor
 from arealite.utils.http import arequest_with_retry, get_default_connector
 from realhf.base import logging, name_resolve, names

-if TYPE_CHECKING:
-    from arealite.api.workflow_api import RolloutWorkflow
-
 logger = logging.getLogger(__name__)

-ROLLOUT_POLL_WAIT_TIME = 0.05
 RID_CACHE_SIZE = 128


 class RemoteSGLangEngine(InferenceEngine):

     def __init__(self, config: InferenceEngineConfig):
-        config.max_concurrent_rollouts = (
-            config.max_concurrent_rollouts or config.consumer_batch_size
-        )
         self.config = config
         self.rid_to_address = {}
@@ -54,29 +44,20 @@ class RemoteSGLangEngine(InferenceEngine):
         self.rid_queue = []

         self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
+        self.session = None
         if not self.addresses:
             raise RuntimeError("No configured SGLang servers.")
-        logger.info("Waiting for server ready...")
-        for addr in self.addresses:
-            self._wait_for_server(addr)
-        logger.info("Servers are all ready!")

         self.server_idx = random.randint(0, len(self.addresses) - 1)
-        qsize = config.queue_size or config.max_concurrent_rollouts * 16
-        self.input_queue = Queue(maxsize=qsize)
-        self.output_queue = Queue(maxsize=qsize)
-        self.result_cache = []
-
-        self.exiting = threading.Event()
-        self.paused = threading.Event()
-        self.lock = threading.Lock()
-
-        self.rollout_stat = RolloutStat()

         self.distributed_weight_update_initialized = False
         self._version = 0
+        self.workflow_executor = WorkflowExecutor(
+            config=config,
+            inference_engine=self,
+        )

     def _wait_for_server(self, address):
         base_url = f"http://{address}"
         tik = time.time()
@@ -94,137 +75,46 @@ class RemoteSGLangEngine(InferenceEngine):
             except requests.exceptions.RequestException as e:
                 return False

-    def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
-        self.rollout_tasks: Dict[str, asyncio.Task] = {}
+    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None = None):
+        logger.info("Waiting for server ready...")
+        for addr_ in self.addresses:
+            self._wait_for_server(addr_)
+        logger.info("Servers are all ready!")
         self.executor = ProcessPoolExecutor(max_workers=1)
-        self.rollout_thread = threading.Thread(target=self._rollout_thread)
-        self.rollout_thread.start()
+        self.workflow_executor.initialize()

     def destroy(self):
         self.executor.shutdown()
-        self.exiting.set()
-        self.rollout_thread.join()

     def set_version(self, version):
-        with self.lock:
-            self._version = version
+        self._version = version

     def get_version(self):
-        with self.lock:
-            return self._version
-
-    def _rollout_thread(self):
-        """Thread that runs the rollout loop."""
-        try:
-            uvloop.run(self._rollout_thread_async())
-        except Exception:
-            traceback.print_exc()
-
-    async def _rollout_thread_async(self):
-        rollout_tasks = self.rollout_tasks
-        rid = 0
-        # NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
-        self.session = aiohttp.ClientSession(
-            timeout=aiohttp.ClientTimeout(
-                total=self.config.request_timeout,
-                sock_connect=self.config.request_timeout,
-                connect=self.config.request_timeout,
-            ),
-            read_bufsize=1024 * 1024 * 10,
-            connector=get_default_connector(),
-        )
-        try:
-            while not self.exiting.is_set():
-                # Check capacity
-                capacity = self.get_capacity()
-                # Create new rollout task
-                while (
-                    capacity > 0
-                    and not self.paused.is_set()
-                    and self.input_queue.qsize() > 0
-                ):
-                    data, workflow = self.input_queue.get_nowait()
-                    logger.debug(f"Get data from puller: {data}")
-                    task = asyncio.create_task(
-                        workflow.arun_episode(self, data), name=str(rid)
-                    )
-                    with self.lock:
-                        rollout_tasks[str(rid)] = task
-                        self.rollout_stat.submitted += 1
-                        self.rollout_stat.running += 1
-                        if self.config.enable_rollout_tracing:
-                            logger.info(
-                                f"Submit rollout rid {rid}. "
-                                f"Submit: {self.rollout_stat.submitted}, "
-                                f"running: {self.rollout_stat.running}, "
-                                f"accepted: {self.rollout_stat.accepted}."
-                            )
-                    capacity -= 1
-                    rid += 1
-                # Wait for rollout completion
-                with self.lock:
-                    tasks = list(rollout_tasks.values())
-                done = []
-                if tasks:
-                    done, _ = await asyncio.wait(
-                        tasks,
-                        timeout=ROLLOUT_POLL_WAIT_TIME,
-                        return_when=asyncio.FIRST_COMPLETED,
-                    )
-                # Collect done results
-                for task in done:
-                    traj = await task
-                    traj: TensorDict
-                    task_rid = task.get_name()
-                    with self.lock:
-                        rollout_tasks.pop(task_rid)
-                        self.rollout_stat.accepted += 1
-                    try:
-                        self.output_queue.put_nowait(traj)
-                    except Full:
-                        raise RuntimeError(
-                            "Output queue full. Please increase queue_size."
-                        )
-                    with self.lock:
-                        self.rollout_stat.running -= 1
-                        if self.config.enable_rollout_tracing:
-                            logger.info(
-                                f"Finish rollout {task_rid}. "
-                                f"Submit: {self.rollout_stat.submitted}, "
-                                f"running: {self.rollout_stat.running}, "
-                                f"accepted: {self.rollout_stat.accepted}."
-                            )
-                await asyncio.sleep(1)
-        except Exception:
-            traceback.print_exc()
-        finally:
-            # Cancel remaining tasks
-            with self.lock:
-                for task in rollout_tasks.values():
-                    if not task.done():
-                        task.cancel()
-                        try:
-                            await task
-                        except asyncio.CancelledError:
-                            pass
+        return self._version

     def choose_server(self) -> str:
-        with self.lock:
-            if self.config.schedule_policy == "round_robin":
-                server = self.addresses[self.server_idx]
-                self.server_idx = (self.server_idx + 1) % len(self.addresses)
-                return server
+        if self.config.schedule_policy == "round_robin":
+            server = self.addresses[self.server_idx]
+            self.server_idx = (self.server_idx + 1) % len(self.addresses)
+            return server
         raise NotImplementedError("Only round-robin scheduling is implemented.")

     async def agenerate(
         self, req: LLMRequest | VLMRequest
     ) -> LLMResponse | VLMResponse:
         """Async version of generate using aiohttp."""
+        if self.session is None:
+            # NOTE: Lazily initialize aiohttp.ClientSession since it needs to be initialized
+            # inside asyncio loop in WorkflowExecutor
+            self.session = aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(
+                    total=self.config.request_timeout,
+                    sock_connect=self.config.request_timeout,
+                    connect=self.config.request_timeout,
+                ),
+                read_bufsize=1024 * 1024 * 10,
+                connector=get_default_connector(),
+            )
         # Prepare request payload
         gconfig = req.gconfig
         stop_token_ids = gconfig.stop_token_ids
@@ -396,30 +286,8 @@ class RemoteSGLangEngine(InferenceEngine):
         fut.add_done_callback(callback)
         return fut

-    def get_capacity(self):
-        if dist.is_initialized():
-            world_size = dist.get_world_size()
-        else:
-            world_size = 1
-        max_concurrent_rollouts = max(
-            1, self.config.max_concurrent_rollouts // world_size
-        )
-        capacity = max_concurrent_rollouts - len(self.rollout_tasks)
-        # Staleness control
-        version = self.get_version()
-        ofp = self.config.max_head_offpolicyness
-        with self.lock:
-            sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
-        consumer_bs = max(1, self.config.consumer_batch_size // world_size)
-        capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
-        return capacity
-
-    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
-        try:
-            self.input_queue.put_nowait((data, workflow))
-        except Full:
-            raise RuntimeError("Input queue full. Please increase queue_size.")
+    def submit(self, data: Dict[str, Any], workflow: RolloutWorkflow) -> None:
+        return self.workflow_executor.submit(data, workflow)

     def wait(
         self,
@@ -427,77 +295,31 @@ class RemoteSGLangEngine(InferenceEngine):
         timeout: float | None = None,
         should_accept: Callable | None = None,
     ) -> TensorDict:
-        tik = time.perf_counter()
-        accepted = len(self.result_cache)
-        timeout = timeout or float(7 * 24 * 3600)
-        while (
-            accepted < count
-            and not self.exiting.is_set()
-            and time.perf_counter() - tik < timeout
-        ):
-            try:
-                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
-                if should_accept is None or should_accept(result):
-                    self.result_cache.append(result)
-                    accepted += 1
-                else:
-                    with self.lock:
-                        self.rollout_stat.accepted -= 1
-            except Empty:
-                pass
-        if self.exiting.is_set():
-            raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
-        if accepted < count:
-            raise TimeoutError(
-                f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
-            )
-        results, self.result_cache = (
-            self.result_cache[:count],
-            self.result_cache[count:],
-        )
-        return concat_padded_tensors(results)
+        return self.workflow_executor.wait(
+            count,
+            timeout=timeout,
+            should_accept=should_accept,
+        )

     def rollout_batch(
         self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
     ) -> TensorDict:
-        """Submit a batch of requests to the inference engine and wait for the results."""
-        for item in data:
-            self.submit(item, workflow)
-        return self.wait(count=len(data))
+        return self.workflow_executor.rollout_batch(data, workflow)

     def prepare_batch(
         self,
         dataloader: StatefulDataLoader,
-        workflow: "RolloutWorkflow",
+        workflow: RolloutWorkflow,
     ):
-        if not hasattr(self, "data_generator"):
-            self.data_generator = iter(dataloader)
-        assert dataloader.batch_size is not None
-        while True:
-            # Submit at least two batches to allow maximum overlap
-            if (
-                self.get_capacity() + dataloader.batch_size > 0
-                and self.input_queue.qsize() + dataloader.batch_size
-                < self.input_queue.maxsize
-            ):
-                try:
-                    data = next(self.data_generator)
-                except StopIteration:
-                    self.data_generator = iter(dataloader)
-                    data = next(self.data_generator)
-                for item in data:
-                    self.submit(item, workflow=workflow)
-            try:
-                return self.wait(dataloader.batch_size, timeout=1)
-            except TimeoutError:
-                pass
+        return self.workflow_executor.prepare_batch(dataloader, workflow)

     def pause(self):
-        self.paused.set()
+        """Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
+        return self.workflow_executor.pause()

     def resume(self):
-        self.paused.clear()
+        """Resume request submission for async rollout."""
+        return self.workflow_executor.resume()

     def update_weights_from_disk(

View File

@@ -94,6 +94,7 @@ class SGLangEngine(InferenceEngine):
             asyncio.run(self._rollout_thread_async())
         except Exception as e:
             traceback.print_exc()
+            raise e

     async def _rollout_thread_async(self):
         data = None

View File

@@ -324,7 +324,6 @@ def ray_main():
     )
     allocation_mode = config.allocation_mode
     allocation_mode = AllocationMode.from_str(allocation_mode)
-    sglang_cmds = []
     sglang_addrs = []
     n_sglang_nodes = 0
     if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:

View File

@@ -61,6 +61,7 @@ async def arequest_with_retry(
                 ctx = _session.delete(url, timeout=timeo)
             else:
                 raise ValueError(f"Unsupported HTTP method: {method}")
+
             async with ctx as response:
                 if verbose:
                     logger.info("http requests return")

View File

@@ -31,7 +31,7 @@ class MultiTurnWorkflow(RolloutWorkflow):
         messages = data["messages"]
         # Run multi-turn rollout until correct
         t = reward = 0
-        discount = 0
+        discount = 1
         rid = uuid.uuid4().hex
         while reward == 0 and t < self.max_turns:
             # Amend a prompt if the previous answer is incorrect
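
Why the one-character fix matters: the per-turn discount is multiplied into the final reward, so initializing `discount = 0` zeroed every multi-turn trajectory's reward. Worked numbers (assuming the `turn_discount=0.9` used in the docs further down):

```python
turn_discount = 0.9
discount = 1            # the fix; with discount = 0 the product is always 0
for _ in range(2):      # two incorrect turns before the correct answer
    discount *= turn_discount
final_reward = 1.0 * discount   # 0.81 rather than 0.0
```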

View File

@@ -1,6 +1,8 @@
 import asyncio
+import functools
 import os
 import uuid
+from concurrent.futures import ProcessPoolExecutor

 import colorama
 import torch
@@ -12,6 +14,11 @@ from arealite.api.engine_api import InferenceEngine
 from arealite.api.io_struct import LLMRequest
 from arealite.api.workflow_api import RolloutWorkflow
 from arealite.utils.data import concat_padded_tensors
+from realhf.base import logging
+
+logger = logging.getLogger("RLVR workflow")
+
+REWARD_TIMEOUT_SECONDS = 15


 class RLVRWorkflow(RolloutWorkflow):
@@ -28,6 +35,7 @@ class RLVRWorkflow(RolloutWorkflow):
         self.tokenizer = tokenizer
         self.enable_thinking = enable_thinking
         self.dump_dir = dump_dir
+        self.rw_executor = ProcessPoolExecutor(max_workers=4)
         if self.dump_dir is not None and not os.path.exists(self.dump_dir):
             os.makedirs(self.dump_dir, exist_ok=True)
@@ -54,6 +62,7 @@ class RLVRWorkflow(RolloutWorkflow):
         seqlens = []

         results = []
+        loop = asyncio.get_event_loop()
         for resp in resps:
             seq = resp.input_tokens + resp.output_tokens
             logprobs = [0.0] * resp.input_len + resp.output_logprobs
@@ -65,13 +74,26 @@ class RLVRWorkflow(RolloutWorkflow):
             prompt_strs.append(prompt_str)
             completions_strs.append(completions_str)
             seqlens.append(len(seq))
-            reward = self.reward_fn(
-                prompt=prompt_str,
-                completions=completions_str,
-                prompt_ids=resp.input_tokens,
-                completion_ids=resp.output_tokens,
-                **data,
-            )
+            try:
+                reward = await asyncio.wait_for(
+                    loop.run_in_executor(
+                        self.rw_executor,
+                        functools.partial(
+                            self.reward_fn,
+                            prompt_str,
+                            completions_str,
+                            resp.input_tokens,
+                            resp.output_tokens,
+                            **data,
+                        ),
+                    ),
+                    timeout=REWARD_TIMEOUT_SECONDS,
+                )
+            except asyncio.TimeoutError:
+                logger.warning(
+                    f"Computing reward timeout after {REWARD_TIMEOUT_SECONDS}s. Set reward to 0."
+                )
+                reward = 0
             rewards.append(reward)
             res = dict(
                 # unsqueeze to add an additional batch dimension

View File

@@ -233,7 +233,7 @@ def prepare_batch(
     workflow: "RolloutWorkflow",
 ):
     if not hasattr(self, "data_generator"):
-        self.data_generator = iter(dataloader)
+        self.data_generator = itertools.cycle(dataloader)
     assert dataloader.batch_size is not None
     while True:
         # Submit at least two batches to allow maximum overlap
@@ -242,11 +242,7 @@ def prepare_batch(
         and self.input_queue.qsize() + dataloader.batch_size
         < self.input_queue.maxsize
     ):
-        try:
-            data = next(self.data_generator)
-        except StopIteration:
-            self.data_generator = iter(dataloader)
-            data = next(self.data_generator)
+        data = next(self.data_generator)
         for item in data:
             # submit data into input_queue
             self.submit(item, workflow=workflow)
@@ -264,18 +260,13 @@ rollout = RemoteSGLangEngine(config.rollout)
 rollout.initialize()
 eval_rollout = ...

-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     # rollout batched training data for current step
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 ```

 If you want to use rollout workflows with custom reward functions or agentic tool
@@ -375,17 +366,12 @@ Now a complete GRPO training step in AReaLite is done! The core logic of our exa
 training script can be summarized as:

 ```python
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

     logp = actor.compute_logp(batch)
     batch["prox_logp"] = logp
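
The dataloader change repeated throughout these docs, in isolation: `itertools.cycle` restarts the iterator automatically, replacing the manual `StopIteration` handling. Note that `cycle` caches the first pass, so any per-epoch reshuffling by the dataloader is not re-applied on later passes.

```python
import itertools

dataloader = [[1, 2], [3, 4], [5, 6]]        # stands in for a StatefulDataLoader
data_generator = itertools.cycle(dataloader)
batches = [next(data_generator) for _ in range(4)]
assert batches == [[1, 2], [3, 4], [5, 6], [1, 2]]
```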

View File

@@ -75,7 +75,7 @@ and converting it into an `LLMRequest` object for the inference engine:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... __init__ method above ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # Initialize result containers
         seq, logprobs, loss_mask, versions = [], [], [], []
         messages = data["messages"]
@@ -119,7 +119,7 @@ we'll apply a discount, add feedback to the conversation, and let the model try
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... initialization code ...
         while reward == 0 and t < self.max_turns:
             # Add feedback if the previous answer was incorrect
@@ -190,7 +190,7 @@ Finally, let's complete the implementation by collecting trajectories in the
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... episode logic above ...

         while reward == 0 and t < self.max_turns:
@@ -417,8 +417,27 @@ Using your custom workflow is straightforward—just create it in your training
 pass it to the `rollout_batch` or `prepare_batch` method:

 ```python
-# in realhf/impl/agent/__init__.py
-import realhf.impl.agent.math_multi_turn_agent
+def main(args):
+    # ... setup code ...
+
+    # Create your custom workflow
+    workflow = MultiTurnWorkflow(
+        reward_fn=gsm8k_reward_fn,
+        gconfig=config.gconfig,
+        tokenizer=tokenizer,
+        turn_discount=0.9,
+        max_turns=5,
+    )
+
+    # Run training—no other changes needed!
+    data_generator = itertools.cycle(train_dataloader)
+    for global_step in range(max_steps):
+        with stats_tracker.record_timing("rollout"):
+            if config.async_training:
+                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
+            else:
+                batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
+        # ... continue with training loop ...
 ```

 Then update your experiment configuration in

View File

@@ -171,16 +171,11 @@ def main(args):
     )

     # Main training loop
+    data_generator = itertools.cycle(dataloader)
     for global_step in range(max_steps):
         # Generate training data
         with stats_tracker.record_timing("rollout"):
-            try:
-                data = next(data_generator)
-            except StopIteration:
-                data_generator = iter(train_dataloader)
-                data = next(data_generator)
-            batch = rollout.rollout_batch(data, workflow=workflow)
+            batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

         batch = batch.to(actor.device)

View File

@@ -68,6 +68,7 @@ ref:
   trial_name: ${trial_name}
   path: ${actor.path}
   init_from_scratch: false
+  disable_dropout: true
   dtype: ${actor.dtype}
 mb_spec:
   max_tokens_per_mb: 10240
@@ -81,9 +82,9 @@ sglang:
   random_seed: ${seed}
   skip_tokenizer_init: true
   dtype: ${actor.dtype}
   max_running_requests: null
   context_length: 32768
-  mem_fraction_static: 0.9
+  mem_fraction_static: 0.8

 # datasets
 train_dataset:
@@ -94,7 +95,7 @@ train_dataset:
   path: openai/gsm8k
   type: rl
 valid_dataset:
   batch_size: 256
   shuffle: true
   pin_memory: true
@@ -131,4 +132,5 @@ stats_logger:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
   fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled

View File

@@ -1,3 +1,4 @@
+import itertools
 import os
 import sys
@@ -129,7 +130,7 @@ def main(args):
     max_steps = total_epochs * steps_per_epoch
     logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")

-    data_generator = iter(train_dataloader)
+    data_generator = itertools.cycle(train_dataloader)
     for global_step in range(max_steps):
         epoch = global_step // steps_per_epoch
         step = global_step % steps_per_epoch
@@ -138,12 +139,7 @@ def main(args):
             if config.async_training:
                 batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
             else:
-                try:
-                    data = next(data_generator)
-                except StopIteration:
-                    data_generator = iter(train_dataloader)
-                    data = next(data_generator)
-                batch = rollout.rollout_batch(data, workflow=workflow)
+                batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
         batch = batch.to(actor.device)

         # Create barrier to synchronize all rollout processes.

View File

@@ -6,7 +6,7 @@ pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.
 pip install "sglang[all]==0.4.9.post2"
 pip install megatron-core==0.11.0 nvidia-ml-py
 pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
-pip install "flash-attn<=2.7.3" --no-build-isolation
+pip install "flash-attn<=2.8.2" --no-build-isolation

 # Package used for calculating math reward
 pip install -e evaluation/latex2sympy

View File

@@ -56,6 +56,8 @@ dependencies = [
     "torchdata",
     "autoflake",
     "tensordict",
+    "pybase64",
+    "msgspec",
     # Monitoring and logging
     "wandb",

View File

@@ -74,3 +74,6 @@ torchdata
 autoflake
 tensordict
 deepspeed>=0.17.2
+pybase64
+msgspec
+transformers==4.53.1