mirror of https://github.com/inclusionAI/AReaL
[lite] [fix] Fix a performance issue and several minor issues before release (#203)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
  Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/353
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* add gradient checkpointing
* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/354
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities
  Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* fix destroy process group
* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset
  Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/358
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* fix loss mask
* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub
  Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/368
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation
  Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/371
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher
  Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/370
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results.
  Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/392
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* support fsdp engine and sglang remote engine
* minor fix
* refactor trainer
* add close
* rm mb_spec
* qwen2 grpo works
* async works
* slurm launcher not tested
* fix arg parse
* sglang server wrapper
* slurm run
* ready for boba
* debug
* 32k run
* refactor train engine
* fix update weight error
* match train
* format
* seems to work
* PullRequest: 408 [Feature] Bump SGLang version to v0.4.9.post2
  Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/408
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* bump arealite to sglang 0.4.9.post2
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`
* PullRequest: 422 [lite] Fix tests and scripts after updating sgl to 0.4.9
  Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/422
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* bump arealite to sglang 0.4.9.post2
* PullRequest: 423 [lite] Remove the boba example for github release.
  Merge branch fw/remove-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/423
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* update readme
* PullRequest: 431 [Fix] Fix environment of lite
  Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/431
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* change requirements
* PullRequest: 440 [FIX] fix update weight from disk
  Merge branch sxj/lite-fix-disk-update of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/440
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* PullRequest: 442 [lite] Refactor `RemoteSGLangEngine` into two parts: `RemoteSGLangEngine` and `WorkflowExecutor`.
  Merge branch mzy/workflow-executor of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/442
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* refactor workflow executor
* fix tests and eval
* revert workflow executor into remote sglang engine
* PullRequest: 456 [lite] [Bug] Use `ProcessPoolExecutor` to calculate reward to avoid rollout slowdown
  Merge branch mzy/lite/fix-reward of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/456?tab=comment
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* fix reward
* PullRequest: 460 [lite][fix] add a warning when reward computation times out
  Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/460
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 465 [lite][fix] Fix issues raised by tsao
  Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/465
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
This commit is contained in:
parent 7fb6a80e48
commit e507ce281c
@@ -195,7 +195,10 @@ class TrainEngineConfig:
+    gradient_checkpointing: bool = field(
+        default=True, metadata={"help": "Enable gradient checkpointing"}
+    )
-    dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
+    dtype: str = field(default="bfloat16", metadata={"help": "Parameter dtype."})
     grad_reduce_dtype: str = field(
         default="float32", metadata={"help": "Gradient reduce dtype."}
     )
     optimizer: Optional[OptimizerConfig] = field(
         default=None, metadata={"help": "Optimizer configuration"}
     )
@@ -262,7 +265,7 @@ class PPOActorConfig(TrainEngineConfig):
         default=1.0, metadata={"help": "Lambda parameter for GAE"}
     )
     adv_norm: bool = field(
-        default=True, metadata={"help": "Enable advantage normalization"}
+        default=True, metadata={"help": "Enable advantage normalization globally"}
     )

     # KL Control
@@ -275,7 +278,9 @@ class PPOActorConfig(TrainEngineConfig):
     )
     use_decoupled_loss: bool = field(
         default=False,
-        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
+        metadata={
+            "help": "Use the decoupled loss. Implicitly enable recompute_logprob."
+        },
     )
     behav_imp_weight_cap: Optional[float] = field(
         default=None,
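The new help text reflects that `use_decoupled_loss` now implies `recompute_logprob` rather than requiring it. For orientation, here is a minimal sketch of how a decoupled PPO-style loss typically combines recomputed proximal log-probs with the behavior policy's log-probs; the names and exact form are assumptions, not the AReaL implementation:

```python
import torch

def decoupled_ppo_loss(logp, prox_logp, behav_logp, adv, eps=0.2, behav_cap=None):
    # The trust region is taken w.r.t. the recomputed proximal policy,
    # which is why use_decoupled_loss implies recompute_logprob.
    ratio = torch.exp(logp - prox_logp)
    surr = torch.min(ratio * adv, torch.clamp(ratio, 1 - eps, 1 + eps) * adv)
    # An importance weight corrects for staleness of the behavior policy
    # that generated the rollout; behav_imp_weight_cap would bound it.
    behav_imp = torch.exp(prox_logp - behav_logp)
    if behav_cap is not None:
        behav_imp = behav_imp.clamp(max=behav_cap)
    return -(surr * behav_imp).mean()
```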
@@ -329,7 +334,7 @@ class SGLangConfig:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    dtype: str = "float16"
+    dtype: str = "bfloat16"
     kv_cache_dtype: str = "auto"
     # logging
     log_level: str = "warning"
@@ -407,31 +412,8 @@ class SGLangConfig:
             dist_init_addr=dist_init_addr,
             **args,
         )
-        sglang_version = pkg_version.get_version("sglang")
-        if sglang_version:
-            version_less_than_0_4_4 = (
-                pkg_version.compare_versions(sglang_version, "0.4.4") < 0
-            )
-            version_less_than_0_4_3 = (
-                pkg_version.compare_versions(sglang_version, "0.4.3") < 0
-            )
-        elif pkg_version.is_available("sglang"):
-            version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
-            version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
-        else:
-            raise ValueError(
-                "A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
-            )
-        if version_less_than_0_4_4:
-            args.pop("log_requests_level")
-        if version_less_than_0_4_3:
-            args.pop("enable_nccl_nvls")
-            args.pop("triton_attention_num_kv_splits")
-            args.pop("cuda_graph_bs")
-            args.pop("enable_memory_saver")
-            args.pop("allow_auto_truncate")
-            args.pop("file_storage_path")
-
+        if not pkg_version.is_version_greater_or_equal("sglang", "0.4.9.post2"):
+            raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.")
         return args
@@ -466,7 +448,7 @@ class InferenceEngineConfig:
         default="round_robin",
         metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
     )
-    setup_timeout: float = field(default=90.0)
+    setup_timeout: float = field(default=120.0)
     request_timeout: float = field(
         default=3600, metadata={"help": "Timeout for HTTP requests."}
     )
@@ -174,6 +174,7 @@ class InferenceEngine(abc.ABC):
         self,
         dataloader: StatefulDataLoader,
         workflow: "RolloutWorkflow",
+        should_accept: Callable | None = None,
     ):
         """Asynchronously submit and wait until a full batch is ready."""
         raise NotImplementedError()
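The new `should_accept` hook lets callers filter trajectories before they count toward the batch. A hypothetical filter (the `rewards` key name is an assumption, not taken from the diff):

```python
# Reject degenerate groups whose rewards are all identical, which would
# yield zero advantage under GRPO-style grouping (illustrative only).
def non_degenerate(traj) -> bool:
    rewards = traj["rewards"]
    return bool((rewards != rewards[0]).any())

batch = engine.prepare_batch(dataloader, workflow, should_accept=non_degenerate)
```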
@@ -1,10 +1,30 @@
-from typing import TYPE_CHECKING, Any, Dict
+import asyncio
+import itertools
+import queue
+import threading
+import time
+import traceback
+from typing import TYPE_CHECKING, Any, Callable, Dict, List
+
+import torch.distributed as dist
+import uvloop
+from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from arealite.api.cli_args import InferenceEngineConfig
+from arealite.api.engine_api import InferenceEngine
+from arealite.api.io_struct import RolloutStat
+from arealite.utils.data import concat_padded_tensors
+from realhf.base import logging

 if TYPE_CHECKING:
     from arealite.api.engine_api import InferenceEngine

+logger = logging.getLogger("arealite.workflow_api")
+
+
+ROLLOUT_POLL_WAIT_TIME = 0.05
+

 class RolloutWorkflow:
@@ -16,3 +36,226 @@ class RolloutWorkflow:
         See concrete example implementations under the `arealite/workflow` directory.
         """
         raise NotImplementedError()
+
+
+class WorkflowExecutor:
+
+    def __init__(
+        self,
+        config: InferenceEngineConfig,
+        inference_engine: "InferenceEngine",
+    ):
+        config.max_concurrent_rollouts = (
+            config.max_concurrent_rollouts or config.consumer_batch_size
+        )
+        self.config = config
+        self.exiting = threading.Event()
+        self.paused = threading.Event()
+        self.lock = threading.Lock()
+
+        self.inference_engine = inference_engine
+
+        qsize = config.queue_size or config.max_concurrent_rollouts * 16
+        self.input_queue = queue.Queue(maxsize=qsize)
+        self.output_queue = queue.Queue(maxsize=qsize)
+        self.result_cache: List[TensorDict] = []
+
+        self.rollout_stat = RolloutStat()
+
+    def initialize(self):
+        self.rollout_tasks: Dict[str, asyncio.Task] = {}
+        self.rollout_thread = threading.Thread(target=self._rollout_thread)
+        self.rollout_thread.start()
+
+    def destroy(self):
+        self.exiting.set()
+        self.rollout_thread.join()
+
+    def get_capacity(self):
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+        else:
+            world_size = 1
+
+        with self.lock:
+            max_concurrent_rollouts = max(
+                1, self.config.max_concurrent_rollouts // world_size
+            )
+            capacity = max_concurrent_rollouts - len(self.rollout_tasks)
+            # Staleness control
+            version = self.inference_engine.get_version()
+            ofp = self.config.max_head_offpolicyness
+            sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
+            consumer_bs = max(1, self.config.consumer_batch_size // world_size)
+            capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
+        return capacity
+
+    def _rollout_thread(self):
+        """Thread that runs the rollout loop."""
+        try:
+            uvloop.run(self._rollout_thread_async())
+        except Exception:
+            traceback.print_exc()
+
+    async def _rollout_thread_async(self):
+        rollout_tasks = self.rollout_tasks
+        rid = 0
+        try:
+            while not self.exiting.is_set():
+                # Check capacity
+                capacity = self.get_capacity()
+                # Create new rollout task
+                self.lock.acquire()
+                while (
+                    capacity > 0
+                    and not self.paused.is_set()
+                    and self.input_queue.qsize() > 0
+                ):
+                    data, workflow = self.input_queue.get_nowait()
+                    logger.debug(f"Get data from puller: {data}")
+                    task = asyncio.create_task(
+                        workflow.arun_episode(self.inference_engine, data),
+                        name=str(rid),
+                    )
+                    rollout_tasks[str(rid)] = task
+                    self.rollout_stat.submitted += 1
+                    self.rollout_stat.running += 1
+                    if self.config.enable_rollout_tracing:
+                        logger.info(
+                            f"Submit rollout rid {rid}. "
+                            f"Submit: {self.rollout_stat.submitted}, "
+                            f"running: {self.rollout_stat.running}, "
+                            f"accepted: {self.rollout_stat.accepted}."
+                        )
+                    capacity -= 1
+                    rid += 1
+                tasks = list(rollout_tasks.values())
+                self.lock.release()
+
+                # Wait for rollout completion
+                done = []
+                if tasks:
+                    done, _ = await asyncio.wait(
+                        tasks,
+                        timeout=ROLLOUT_POLL_WAIT_TIME,
+                        return_when=asyncio.FIRST_COMPLETED,
+                    )
+                # Collect done results
+                for task in done:
+                    traj = await task
+                    traj: TensorDict
+                    task_rid = task.get_name()
+                    with self.lock:
+                        rollout_tasks.pop(task_rid)
+                        self.rollout_stat.accepted += 1
+                        self.rollout_stat.running -= 1
+                        if self.config.enable_rollout_tracing:
+                            logger.info(
+                                f"Finish rollout {task_rid}. "
+                                f"Submit: {self.rollout_stat.submitted}, "
+                                f"running: {self.rollout_stat.running}, "
+                                f"accepted: {self.rollout_stat.accepted}."
+                            )
+                    try:
+                        self.output_queue.put_nowait(traj)
+                    except queue.Full:
+                        raise RuntimeError(
+                            "Output queue full. Please increase queue_size."
+                        )
+
+                await asyncio.sleep(1)
+        except Exception:
+            traceback.print_exc()
+        finally:
+            # Cancel remaining tasks
+            with self.lock:
+                for task in rollout_tasks.values():
+                    if not task.done():
+                        task.cancel()
+                        try:
+                            await task
+                        except asyncio.CancelledError:
+                            pass
+
+    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
+        try:
+            self.input_queue.put_nowait((data, workflow))
+        except queue.Full:
+            raise RuntimeError("Input queue full. Please increase queue_size.")
+
+    def wait(
+        self,
+        count: int,
+        timeout: float | None = None,
+        should_accept: Callable | None = None,
+    ) -> TensorDict:
+        tik = time.perf_counter()
+        accepted = len(self.result_cache)
+        timeout = timeout or float(7 * 24 * 3600)
+        while (
+            accepted < count
+            and not self.exiting.is_set()
+            and time.perf_counter() - tik < timeout
+        ):
+            try:
+                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
+                if should_accept is None or should_accept(result):
+                    self.result_cache.append(result)
+                    accepted += 1
+                else:
+                    with self.lock:
+                        self.rollout_stat.accepted -= 1
+            except queue.Empty:
+                pass
+        if self.exiting.is_set():
+            raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
+        if accepted < count:
+            raise TimeoutError(
+                f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
+            )
+        results, self.result_cache = (
+            self.result_cache[:count],
+            self.result_cache[count:],
+        )
+        return concat_padded_tensors(results)
+
+    def rollout_batch(
+        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
+    ) -> TensorDict:
+        """Submit a batch of requests to the inference engine and wait for the results."""
+        for item in data:
+            self.submit(item, workflow)
+        return self.wait(count=len(data))
+
+    def prepare_batch(
+        self,
+        dataloader: StatefulDataLoader,
+        workflow: "RolloutWorkflow",
+        should_accept: Callable | None = None,
+    ):
+        if not hasattr(self, "data_generator"):
+            self.data_generator = itertools.cycle(dataloader)
+        assert dataloader.batch_size is not None
+        while True:
+            # Submit at least two batches to allow maximum overlap
+            if (
+                self.get_capacity() + dataloader.batch_size > 0
+                and self.input_queue.qsize() + dataloader.batch_size
+                < self.input_queue.maxsize
+            ):
+                data = next(self.data_generator)
+                for item in data:
+                    self.submit(item, workflow=workflow)
+            try:
+                return self.wait(
+                    dataloader.batch_size, timeout=1, should_accept=should_accept
+                )
+            except TimeoutError:
+                pass
+
+    def pause(self):
+        self.paused.set()
+
+    def resume(self):
+        self.paused.clear()
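Taken together, the executor owns the input/output queues, the rollout thread, and the staleness control. A minimal usage sketch (assuming an already-configured `config`, `engine`, `workflow`, and an iterable `items`; this is not part of the diff):

```python
executor = WorkflowExecutor(config=config, inference_engine=engine)
executor.initialize()                        # starts the background rollout thread
try:
    for item in items:                       # each item: a dict of prompt data
        executor.submit(item, workflow)      # non-blocking enqueue
    batch = executor.wait(count=len(items))  # padded TensorDict of trajectories
finally:
    executor.destroy()                       # signal exit and join the thread
```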
@@ -55,7 +55,7 @@ class FSDPEngine(BaseHFEngine):
         # Simple auto wrap policy
         self.mixed_precision_policy = MixedPrecisionPolicy(
             param_dtype=getattr(torch, self.config.dtype),
-            reduce_dtype=torch.float32,
+            reduce_dtype=getattr(torch, self.config.grad_reduce_dtype),
             cast_forward_inputs=True,
         )
         self.device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
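With this change both FSDP mixed-precision dtypes come from the config instead of hard-coding fp32 reduction. An illustrative mapping, assuming the FSDP2 `MixedPrecisionPolicy` API of recent PyTorch:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

# dtype="bfloat16" keeps parameters in bf16 for compute, while
# grad_reduce_dtype="float32" still accumulates gradient reductions in fp32.
policy = MixedPrecisionPolicy(
    param_dtype=getattr(torch, "bfloat16"),
    reduce_dtype=getattr(torch, "float32"),
    cast_forward_inputs=True,
)
```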
@@ -128,7 +128,7 @@ class PPOActor:
         advantages = torch.stack(advantages_reversed[::-1], dim=1)

         # Optionally perform advantage normalization.
-        if self.adv_norm:
+        if self.adv_norm or self.group_adv_norm:
             if self.group_adv_norm:
                 adv_list = []
                 for i in range(0, bs, self.group_size):
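A sketch of the group-wise branch this enables. The assumed semantics: normalize within each chunk of `group_size` rollouts that share a prompt, rather than across the whole batch; the epsilon is illustrative:

```python
import torch

def normalize_grouped(advantages: torch.Tensor, group_size: int) -> torch.Tensor:
    # Per-group whitening of advantages along the batch dimension.
    chunks = []
    for i in range(0, advantages.shape[0], group_size):
        chunk = advantages[i : i + group_size]
        chunks.append((chunk - chunk.mean()) / (chunk.std() + 1e-5))
    return torch.cat(chunks, dim=0)
```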
@@ -2,17 +2,13 @@ import asyncio
 import os
 import random
 import shutil
 import threading
 import time
-import traceback
 from concurrent.futures import Future, ProcessPoolExecutor
 from datetime import datetime
-from queue import Empty, Full, Queue
-from typing import TYPE_CHECKING, Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List

 import aiohttp
 import requests
 import torch.distributed as dist
-import uvloop
 from tensordict import TensorDict
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -28,25 +24,19 @@ from arealite.api.io_struct import (
     VLMResponse,
     WeightUpdateMeta,
 )
-from arealite.utils.data import concat_padded_tensors
+from arealite.api.workflow_api import RolloutWorkflow, WorkflowExecutor
 from arealite.utils.http import arequest_with_retry, get_default_connector
 from realhf.base import logging, name_resolve, names

-if TYPE_CHECKING:
-    from arealite.api.workflow_api import RolloutWorkflow
 logger = logging.getLogger(__name__)


-ROLLOUT_POLL_WAIT_TIME = 0.05
 RID_CACHE_SIZE = 128


 class RemoteSGLangEngine(InferenceEngine):

     def __init__(self, config: InferenceEngineConfig):
-        config.max_concurrent_rollouts = (
-            config.max_concurrent_rollouts or config.consumer_batch_size
-        )
         self.config = config

         self.rid_to_address = {}
@@ -54,29 +44,20 @@ class RemoteSGLangEngine(InferenceEngine):
         self.rid_queue = []

         self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
+        self.session = None

         if not self.addresses:
             raise RuntimeError("No configured SGLang servers.")
-        logger.info("Waiting for server ready...")
-        for addr in self.addresses:
-            self._wait_for_server(addr)
-        logger.info("Servers are all ready!")

         self.server_idx = random.randint(0, len(self.addresses) - 1)

-        qsize = config.queue_size or config.max_concurrent_rollouts * 16
-        self.input_queue = Queue(maxsize=qsize)
-        self.output_queue = Queue(maxsize=qsize)
-        self.result_cache = []
-
-        self.exiting = threading.Event()
-        self.paused = threading.Event()
         self.lock = threading.Lock()

-        self.rollout_stat = RolloutStat()
         self.distributed_weight_update_initialized = False

         self._version = 0

+        self.workflow_executor = WorkflowExecutor(
+            config=config,
+            inference_engine=self,
+        )
+
     def _wait_for_server(self, address):
         base_url = f"http://{address}"
         tik = time.time()
@@ -94,127 +75,24 @@ class RemoteSGLangEngine(InferenceEngine):
         except requests.exceptions.RequestException as e:
             return False

-    def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
-        self.rollout_tasks: Dict[str, asyncio.Task] = {}
+    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None = None):
+        logger.info("Waiting for server ready...")
+        for addr_ in self.addresses:
+            self._wait_for_server(addr_)
+        logger.info("Servers are all ready!")
         self.executor = ProcessPoolExecutor(max_workers=1)
-        self.rollout_thread = threading.Thread(target=self._rollout_thread)
-        self.rollout_thread.start()
+        self.workflow_executor.initialize()

     def destroy(self):
         self.executor.shutdown()
-        self.exiting.set()
-        self.rollout_thread.join()

     def set_version(self, version):
         with self.lock:
             self._version = version

     def get_version(self):
         with self.lock:
             return self._version

-    def _rollout_thread(self):
-        """Thread that runs the rollout loop."""
-        try:
-            uvloop.run(self._rollout_thread_async())
-        except Exception:
-            traceback.print_exc()
-
-    async def _rollout_thread_async(self):
-        rollout_tasks = self.rollout_tasks
-        rid = 0
-
-        # NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
-        self.session = aiohttp.ClientSession(
-            timeout=aiohttp.ClientTimeout(
-                total=self.config.request_timeout,
-                sock_connect=self.config.request_timeout,
-                connect=self.config.request_timeout,
-            ),
-            read_bufsize=1024 * 1024 * 10,
-            connector=get_default_connector(),
-        )
-
-        try:
-            while not self.exiting.is_set():
-                # Check capacity
-                capacity = self.get_capacity()
-                # Create new rollout task
-                while (
-                    capacity > 0
-                    and not self.paused.is_set()
-                    and self.input_queue.qsize() > 0
-                ):
-                    data, workflow = self.input_queue.get_nowait()
-                    logger.debug(f"Get data from puller: {data}")
-                    task = asyncio.create_task(
-                        workflow.arun_episode(self, data), name=str(rid)
-                    )
-                    with self.lock:
-                        rollout_tasks[str(rid)] = task
-                        self.rollout_stat.submitted += 1
-                        self.rollout_stat.running += 1
-                        if self.config.enable_rollout_tracing:
-                            logger.info(
-                                f"Submit rollout rid {rid}. "
-                                f"Submit: {self.rollout_stat.submitted}, "
-                                f"running: {self.rollout_stat.running}, "
-                                f"accepted: {self.rollout_stat.accepted}."
-                            )
-                    capacity -= 1
-                    rid += 1
-                # Wait for rollout completion
-                with self.lock:
-                    tasks = list(rollout_tasks.values())
-                done = []
-                if tasks:
-                    done, _ = await asyncio.wait(
-                        tasks,
-                        timeout=ROLLOUT_POLL_WAIT_TIME,
-                        return_when=asyncio.FIRST_COMPLETED,
-                    )
-                # Collect done results
-                for task in done:
-                    traj = await task
-                    traj: TensorDict
-                    task_rid = task.get_name()
-                    with self.lock:
-                        rollout_tasks.pop(task_rid)
-                        self.rollout_stat.accepted += 1
-
-                    try:
-                        self.output_queue.put_nowait(traj)
-                    except Full:
-                        raise RuntimeError(
-                            "Output queue full. Please increase queue_size."
-                        )
-
-                    with self.lock:
-                        self.rollout_stat.running -= 1
-                        if self.config.enable_rollout_tracing:
-                            logger.info(
-                                f"Finish rollout {task_rid}. "
-                                f"Submit: {self.rollout_stat.submitted}, "
-                                f"running: {self.rollout_stat.running}, "
-                                f"accepted: {self.rollout_stat.accepted}."
-                            )
-                await asyncio.sleep(1)
-        except Exception:
-            traceback.print_exc()
-        finally:
-            # Cancel remaining tasks
-            with self.lock:
-                for task in rollout_tasks.values():
-                    if not task.done():
-                        task.cancel()
-                        try:
-                            await task
-                        except asyncio.CancelledError:
-                            pass
-
     def choose_server(self) -> str:
         with self.lock:
             if self.config.schedule_policy == "round_robin":
                 server = self.addresses[self.server_idx]
                 self.server_idx = (self.server_idx + 1) % len(self.addresses)
@@ -225,6 +103,18 @@
         self, req: LLMRequest | VLMRequest
     ) -> LLMResponse | VLMResponse:
         """Async version of generate using aiohttp."""
+        if self.session is None:
+            # NOTE: Lazily initialize aiohttp.ClientSession since it needs to be initialized
+            # inside asyncio loop in WorkflowExecutor
+            self.session = aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(
+                    total=self.config.request_timeout,
+                    sock_connect=self.config.request_timeout,
+                    connect=self.config.request_timeout,
+                ),
+                read_bufsize=1024 * 1024 * 10,
+                connector=get_default_connector(),
+            )
         # Prepare request payload
         gconfig = req.gconfig
         stop_token_ids = gconfig.stop_token_ids
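The lazy construction matters because `aiohttp.ClientSession` binds to the event loop it is created under; building it eagerly in `__init__` on the main thread and then using it from the uvloop run inside `WorkflowExecutor`'s thread fails. A stripped-down sketch of the pattern:

```python
import aiohttp

class LazyClient:
    def __init__(self):
        self.session = None  # do NOT create the session here

    async def fetch(self, url: str) -> str:
        if self.session is None:
            # First use happens inside the loop that will own the session.
            self.session = aiohttp.ClientSession()
        async with self.session.get(url) as resp:
            return await resp.text()
```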
@@ -396,30 +286,8 @@
         fut.add_done_callback(callback)
         return fut

-    def get_capacity(self):
-        if dist.is_initialized():
-            world_size = dist.get_world_size()
-        else:
-            world_size = 1
-
-        max_concurrent_rollouts = max(
-            1, self.config.max_concurrent_rollouts // world_size
-        )
-        capacity = max_concurrent_rollouts - len(self.rollout_tasks)
-        # Staleness control
-        version = self.get_version()
-        ofp = self.config.max_head_offpolicyness
-        with self.lock:
-            sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
-        consumer_bs = max(1, self.config.consumer_batch_size // world_size)
-        capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
-        return capacity
-
-    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
-        try:
-            self.input_queue.put_nowait((data, workflow))
-        except Full:
-            raise RuntimeError("Input queue full. Please increase queue_size.")
+    def submit(self, data: Dict[str, Any], workflow: RolloutWorkflow) -> None:
+        return self.workflow_executor.submit(data, workflow)

     def wait(
         self,
@@ -427,77 +295,31 @@
         timeout: float | None = None,
         should_accept: Callable | None = None,
     ) -> TensorDict:
-        tik = time.perf_counter()
-        accepted = len(self.result_cache)
-        timeout = timeout or float(7 * 24 * 3600)
-        while (
-            accepted < count
-            and not self.exiting.is_set()
-            and time.perf_counter() - tik < timeout
-        ):
-            try:
-                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
-                if should_accept is None or should_accept(result):
-                    self.result_cache.append(result)
-                    accepted += 1
-                else:
-                    with self.lock:
-                        self.rollout_stat.accepted -= 1
-            except Empty:
-                pass
-        if self.exiting.is_set():
-            raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
-        if accepted < count:
-            raise TimeoutError(
-                f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
-            )
-        results, self.result_cache = (
-            self.result_cache[:count],
-            self.result_cache[count:],
-        )
-        return concat_padded_tensors(results)
+        return self.workflow_executor.wait(
+            count,
+            timeout=timeout,
+            should_accept=should_accept,
+        )

     def rollout_batch(
         self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
     ) -> TensorDict:
         """Submit a batch of requests to the inference engine and wait for the results."""
-        for item in data:
-            self.submit(item, workflow)
-        return self.wait(count=len(data))
+        return self.workflow_executor.rollout_batch(data, workflow)

     def prepare_batch(
         self,
         dataloader: StatefulDataLoader,
-        workflow: "RolloutWorkflow",
+        workflow: RolloutWorkflow,
     ):
-        if not hasattr(self, "data_generator"):
-            self.data_generator = iter(dataloader)
-        assert dataloader.batch_size is not None
-        while True:
-            # Submit at least two batches to allow maximum overlap
-            if (
-                self.get_capacity() + dataloader.batch_size > 0
-                and self.input_queue.qsize() + dataloader.batch_size
-                < self.input_queue.maxsize
-            ):
-                try:
-                    data = next(self.data_generator)
-                except StopIteration:
-                    self.data_generator = iter(dataloader)
-                    data = next(self.data_generator)
-                for item in data:
-                    self.submit(item, workflow=workflow)
-            try:
-                return self.wait(dataloader.batch_size, timeout=1)
-            except TimeoutError:
-                pass
+        return self.workflow_executor.prepare_batch(dataloader, workflow)

     def pause(self):
-        self.paused.set()
+        """Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
+        return self.workflow_executor.pause()

     def resume(self):
-        self.paused.clear()
+        """Resume request submission for async rollout."""
+        return self.workflow_executor.resume()

     def update_weights_from_disk(
@@ -94,6 +94,7 @@ class SGLangEngine(InferenceEngine):
             asyncio.run(self._rollout_thread_async())
         except Exception as e:
             traceback.print_exc()
+            raise e

     async def _rollout_thread_async(self):
         data = None
@@ -324,7 +324,6 @@ def ray_main():
     )
     allocation_mode = config.allocation_mode
     allocation_mode = AllocationMode.from_str(allocation_mode)
-    sglang_cmds = []
     sglang_addrs = []
     n_sglang_nodes = 0
     if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
@@ -61,6 +61,7 @@ async def arequest_with_retry(
             ctx = _session.delete(url, timeout=timeo)
         else:
             raise ValueError(f"Unsupported HTTP method: {method}")

         async with ctx as response:
             if verbose:
+                logger.info("http requests return")
@@ -31,7 +31,7 @@ class MultiTurnWorkflow(RolloutWorkflow):
         messages = data["messages"]
         # Run multi-turn rollout until correct
         t = reward = 0
-        discount = 0
+        discount = 1
         rid = uuid.uuid4().hex
         while reward == 0 and t < self.max_turns:
             # Amend a prompt if the previous answer is incorrect
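The old initialization zeroed every multi-turn reward: the running discount multiplies the final reward, so it must start at 1. A toy illustration (the exact update rule inside the workflow is assumed):

```python
turn_discount = 0.9
discount = 1.0            # the fix: start at 1, not 0
reward = 0.0
for t in range(5):        # suppose the model gets it right at t == 3
    if t == 3:
        reward = 1.0
        break
    discount *= turn_discount
final_reward = reward * discount  # 0.9 ** 3 here; always 0 with discount = 0
```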
@@ -1,6 +1,8 @@
 import asyncio
+import functools
 import os
 import uuid
+from concurrent.futures import ProcessPoolExecutor

 import colorama
 import torch
@@ -12,6 +14,11 @@ from arealite.api.engine_api import InferenceEngine
 from arealite.api.io_struct import LLMRequest
 from arealite.api.workflow_api import RolloutWorkflow
 from arealite.utils.data import concat_padded_tensors
+from realhf.base import logging
+
+logger = logging.getLogger("RLVR workflow")
+
+REWARD_TIMEOUT_SECONDS = 15


 class RLVRWorkflow(RolloutWorkflow):
@@ -28,6 +35,7 @@ class RLVRWorkflow(RolloutWorkflow):
         self.tokenizer = tokenizer
         self.enable_thinking = enable_thinking
         self.dump_dir = dump_dir
+        self.rw_executor = ProcessPoolExecutor(max_workers=4)
         if self.dump_dir is not None and not os.path.exists(self.dump_dir):
             os.makedirs(self.dump_dir, exist_ok=True)
@@ -54,6 +62,7 @@
         seqlens = []

         results = []
+        loop = asyncio.get_event_loop()
         for resp in resps:
             seq = resp.input_tokens + resp.output_tokens
             logprobs = [0.0] * resp.input_len + resp.output_logprobs
@@ -65,13 +74,26 @@
             prompt_strs.append(prompt_str)
             completions_strs.append(completions_str)
             seqlens.append(len(seq))
-            reward = self.reward_fn(
-                prompt=prompt_str,
-                completions=completions_str,
-                prompt_ids=resp.input_tokens,
-                completion_ids=resp.output_tokens,
-                **data,
-            )
+            try:
+                reward = await asyncio.wait_for(
+                    loop.run_in_executor(
+                        self.rw_executor,
+                        functools.partial(
+                            self.reward_fn,
+                            prompt_str,
+                            completions_str,
+                            resp.input_tokens,
+                            resp.output_tokens,
+                            **data,
+                        ),
+                    ),
+                    timeout=REWARD_TIMEOUT_SECONDS,
+                )
+            except asyncio.TimeoutError:
+                logger.warning(
+                    f"Computing reward timeout after {REWARD_TIMEOUT_SECONDS}s. Set reward to 0."
+                )
+                reward = 0
             rewards.append(reward)
             res = dict(
                 # unsqueeze to add an additional batch dimension
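The general pattern, independent of this file: run a blocking and possibly hanging reward function in a worker process and bound it with `asyncio.wait_for`, so one bad sample cannot stall the rollout loop. A sketch with assumed names:

```python
import asyncio
import functools
from concurrent.futures import ProcessPoolExecutor

async def reward_with_timeout(reward_fn, *args, pool: ProcessPoolExecutor,
                              timeout: float = 15.0, default: float = 0.0):
    # Offload to a process so CPU-heavy or hanging reward code cannot
    # block the event loop; bound its runtime with a timeout.
    loop = asyncio.get_running_loop()
    try:
        return await asyncio.wait_for(
            loop.run_in_executor(pool, functools.partial(reward_fn, *args)),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        return default  # warn and fall back instead of hanging
```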
@@ -233,7 +233,7 @@ def prepare_batch(
     workflow: "RolloutWorkflow",
 ):
     if not hasattr(self, "data_generator"):
-        self.data_generator = iter(dataloader)
+        self.data_generator = itertools.cycle(dataloader)
     assert dataloader.batch_size is not None
     while True:
         # Submit at least two batches to allow maximum overlap
@@ -242,10 +242,6 @@ def prepare_batch(
             and self.input_queue.qsize() + dataloader.batch_size
             < self.input_queue.maxsize
         ):
-            try:
-                data = next(self.data_generator)
-            except StopIteration:
-                self.data_generator = iter(dataloader)
-                data = next(self.data_generator)
+            data = next(self.data_generator)
             for item in data:
                 # submit data into input_queue
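`itertools.cycle` removes the manual re-seeding on `StopIteration`; it never raises at epoch boundaries. A quick self-contained check:

```python
import itertools

batches = ["b0", "b1", "b2"]              # stands in for a DataLoader
gen = itertools.cycle(batches)
assert [next(gen) for _ in range(5)] == ["b0", "b1", "b2", "b0", "b1"]
# Note: cycle replays the elements it saw on the first pass, so a loader
# that reshuffles each epoch will repeat the first epoch's order.
```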
@@ -264,18 +260,13 @@ rollout = RemoteSGLangEngine(config.rollout)
 rollout.initialize()
 eval_rollout = ...

-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     # rollout batched training data for current step
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 ```

 If you want to use rollout workflows with custom reward functions or agentic tool
@@ -375,17 +366,12 @@ Now a complete GRPO training step in AReaLite is done! The core logic of our example
 training script can be summarized as:

 ```python
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

     logp = actor.compute_logp(batch)
     batch["prox_logp"] = logp
@@ -75,7 +75,7 @@ and converting it into an `LLMRequest` object for the inference engine:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... __init__ method above ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # Initialize result containers
         seq, logprobs, loss_mask, versions = [], [], [], []
         messages = data["messages"]
@@ -119,7 +119,7 @@ we'll apply a discount, add feedback to the conversation, and let the model try
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... initialization code ...
         while reward == 0 and t < self.max_turns:
             # Add feedback if the previous answer was incorrect
@@ -190,7 +190,7 @@ Finally, let's complete the implementation by collecting trajectories in the
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...

-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... episode logic above ...

         while reward == 0 and t < self.max_turns:
@@ -417,8 +417,27 @@ Using your custom workflow is straightforward—just create it in your training
 pass it to the `rollout_batch` or `prepare_batch` method:

 ```python
-# in realhf/impl/agent/__init__.py
-import realhf.impl.agent.math_multi_turn_agent
+def main(args):
+    # ... setup code ...
+
+    # Create your custom workflow
+    workflow = MultiTurnWorkflow(
+        reward_fn=gsm8k_reward_fn,
+        gconfig=config.gconfig,
+        tokenizer=tokenizer,
+        turn_discount=0.9,
+        max_turns=5,
+    )
+
+    # Run training—no other changes needed!
+    data_generator = itertools.cycle(train_dataloader)
+    for global_step in range(max_steps):
+        with stats_tracker.record_timing("rollout"):
+            if config.async_training:
+                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
+            else:
+                batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
+    # ... continue with training loop ...
 ```

 Then update your experiment configuration in
@@ -171,16 +171,11 @@ def main(args):
     )

     # Main training loop
+    data_generator = itertools.cycle(dataloader)
     for global_step in range(max_steps):
         # Generate training data
         with stats_tracker.record_timing("rollout"):
-            try:
-                data = next(data_generator)
-            except StopIteration:
-                data_generator = iter(train_dataloader)
-                data = next(data_generator)
-
-            batch = rollout.rollout_batch(data, workflow=workflow)
+            batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

         batch = batch.to(actor.device)
@@ -68,6 +68,7 @@ ref:
   trial_name: ${trial_name}
   path: ${actor.path}
   init_from_scratch: false
+  disable_dropout: true
   dtype: ${actor.dtype}
   mb_spec:
     max_tokens_per_mb: 10240
@@ -83,7 +84,7 @@ sglang:
   dtype: ${actor.dtype}
   max_running_requests: null
   context_length: 32768
-  mem_fraction_static: 0.9
+  mem_fraction_static: 0.8

 # datasets
 train_dataset:
@@ -131,4 +132,5 @@ stats_logger:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
   fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
@@ -1,3 +1,4 @@
+import itertools
 import os
 import sys
@@ -129,7 +130,7 @@ def main(args):
     max_steps = total_epochs * steps_per_epoch

     logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
-    data_generator = iter(train_dataloader)
+    data_generator = itertools.cycle(train_dataloader)
     for global_step in range(max_steps):
         epoch = global_step // steps_per_epoch
         step = global_step % steps_per_epoch
@@ -138,12 +139,7 @@ def main(args):
         if config.async_training:
             batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
         else:
-            try:
-                data = next(data_generator)
-            except StopIteration:
-                data_generator = iter(train_dataloader)
-                data = next(data_generator)
-            batch = rollout.rollout_batch(data, workflow=workflow)
+            batch = rollout.rollout_batch(next(data_generator), workflow=workflow)

         batch = batch.to(actor.device)
         # Create barrier to synchronize all rollout processes.
@@ -6,7 +6,7 @@ pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.2"
 pip install "sglang[all]==0.4.9.post2"
 pip install megatron-core==0.11.0 nvidia-ml-py
 pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
-pip install "flash-attn<=2.7.3" --no-build-isolation
+pip install "flash-attn<=2.8.2" --no-build-isolation

 # Package used for calculating math reward
 pip install -e evaluation/latex2sympy
@@ -56,6 +56,8 @@ dependencies = [
     "torchdata",
     "autoflake",
     "tensordict",
+    "pybase64",
+    "msgspec",

     # Monitoring and logging
     "wandb",
@@ -74,3 +74,6 @@ torchdata
 autoflake
 tensordict
 deepspeed>=0.17.2
+pybase64
+msgspec
+transformers==4.53.1