mirror of https://github.com/inclusionAI/AReaL
[lite] [fix] Fix a performance issue and several minor issues before release (#203)
* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
  Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/353
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
  - add gradient checkpointing
* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/354
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration
  Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities
  Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - fix destroy process group
* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset
  Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/358
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - fix loss mask
* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub
  Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/368
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation
  Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/371
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher
  Merge branch mzy/lite/launcher of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/370
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* PullRequest: 392 [lite] Fix several bugs regarding RL learning and add an example to reproduce boba-math results.
  Merge branch fw/lite-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/392
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - support fsdp engine and sglang remote engine; refactor trainer; add close; rm mb_spec
  - qwen2 grpo works; async works; fix arg parse
  - sglang server wrapper; slurm launcher; slurm run; ready for boba; 32k run
  - refactor train engine; fix update weight error; match train; format
* PullRequest: 408 [Feature] Bump SGLang version to v0.4.9.post2
  Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/408
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - bump arealite to sglang 0.4.9.post2
* PullRequest: 412 [lite] Minor refactor on `UpdateWeightMeta`
* PullRequest: 422 [lite] Fix tests and scripts after updating sgl to 0.4.9
  Merge branch fw/sgl049 of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/422
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* PullRequest: 423 [lite] Remove the boba example for github release.
  Merge branch fw/remove-boba of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/423
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - update readme
* PullRequest: 431 [Fix] Fix environment of lite
  Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/431
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - change requirements
* PullRequest: 440 [FIX] fix update weight from disk
  Merge branch sxj/lite-fix-disk-update of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/440
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* PullRequest: 442 [lite] Refactor `RemoteSGLangEngine` into two parts: `RemoteSGLangEngine` and `WorkflowExecutor`.
  Merge branch mzy/workflow-executor of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/442
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
  - refactor workflow executor; fix tests and eval
  - revert workflow executor into remote sglang engine
* PullRequest: 456 [lite] [Bug] Use `ProcessPoolExecutor` to calculate reward to avoid rollout slow down
  Merge branch mzy/lite/fix-reward of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/456?tab=comment
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
  - fix reward
* PullRequest: 460 [lite][fix] add a warning when reward computation timeout
  Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/460
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
  - add a warning when reward computation timeout
* PullRequest: 465 [lite][fix] Fix issues raised by tsao
  Merge branch fw/lite-fix of git@code.alipay.com:inclusionAI/AReaL.git into lite
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/465
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
parent 7fb6a80e48
commit e507ce281c
@@ -195,7 +195,10 @@ class TrainEngineConfig:
     gradient_checkpointing: bool = field(
         default=True, metadata={"help": "Enable gradient checkpointing"}
     )
-    dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
+    dtype: str = field(default="bfloat16", metadata={"help": "Parameter dtype."})
+    grad_reduce_dtype: str = field(
+        default="float32", metadata={"help": "Gradient reduce dtype."}
+    )
     optimizer: Optional[OptimizerConfig] = field(
         default=None, metadata={"help": "Optimizer configuration"}
     )
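The string-valued `dtype` and `grad_reduce_dtype` fields above are resolved to real `torch` dtypes with `getattr(torch, ...)`, as the FSDPEngine hunk further down shows. A minimal sketch of that mapping; the `EngineDtypes` dataclass is a stand-in for illustration, not the real `TrainEngineConfig`:

```python
# Sketch: resolving string dtype fields to torch dtypes via getattr.
# `EngineDtypes` is a made-up stand-in, not the actual config class.
from dataclasses import dataclass

import torch


@dataclass
class EngineDtypes:
    dtype: str = "bfloat16"
    grad_reduce_dtype: str = "float32"


cfg = EngineDtypes()
param_dtype = getattr(torch, cfg.dtype)               # torch.bfloat16
reduce_dtype = getattr(torch, cfg.grad_reduce_dtype)  # torch.float32
print(param_dtype, reduce_dtype)
```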
@@ -262,7 +265,7 @@ class PPOActorConfig(TrainEngineConfig):
         default=1.0, metadata={"help": "Lambda parameter for GAE"}
     )
     adv_norm: bool = field(
-        default=True, metadata={"help": "Enable advantage normalization"}
+        default=True, metadata={"help": "Enable advantage normalization globally"}
     )
 
     # KL Control
@@ -275,7 +278,9 @@ class PPOActorConfig(TrainEngineConfig):
     )
     use_decoupled_loss: bool = field(
         default=False,
-        metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
+        metadata={
+            "help": "Use the decoupled loss. Implicitly enable recompute_logprob."
+        },
     )
     behav_imp_weight_cap: Optional[float] = field(
         default=None,
@@ -329,7 +334,7 @@ class SGLangConfig:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    dtype: str = "float16"
+    dtype: str = "bfloat16"
     kv_cache_dtype: str = "auto"
     # logging
     log_level: str = "warning"
@@ -407,31 +412,8 @@ class SGLangConfig:
         dist_init_addr=dist_init_addr,
         **args,
     )
-    sglang_version = pkg_version.get_version("sglang")
-    if sglang_version:
-        version_less_than_0_4_4 = (
-            pkg_version.compare_versions(sglang_version, "0.4.4") < 0
-        )
-        version_less_than_0_4_3 = (
-            pkg_version.compare_versions(sglang_version, "0.4.3") < 0
-        )
-    elif pkg_version.is_available("sglang"):
-        version_less_than_0_4_4 = pkg_version.is_version_less("sglang", "0.4.4")
-        version_less_than_0_4_3 = pkg_version.is_version_less("sglang", "0.4.3")
-    else:
-        raise ValueError(
-            "A installed SGLang package or a specific SGLang version should be provided to build SGLang server cmd."
-        )
-    if version_less_than_0_4_4:
-        args.pop("log_requests_level")
-    if version_less_than_0_4_3:
-        args.pop("enable_nccl_nvls")
-        args.pop("triton_attention_num_kv_splits")
-        args.pop("cuda_graph_bs")
-        args.pop("enable_memory_saver")
-        args.pop("allow_auto_truncate")
-        args.pop("file_storage_path")
-
+    if not pkg_version.is_version_greater_or_equal("sglang", "0.4.9.post2"):
+        raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.")
     return args
 
 
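The new check collapses the old per-version argument pruning into a single version floor. A sketch of an equivalent guard built on the standard `packaging` library; the AReaL `pkg_version` helpers are assumed to behave roughly like this:

```python
# Sketch of a minimal version-floor check; assumes pkg_version's helper
# semantics match packaging.version comparison over installed metadata.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def is_version_greater_or_equal(pkg: str, floor: str) -> bool:
    try:
        return parse(version(pkg)) >= parse(floor)
    except PackageNotFoundError:
        return False


if not is_version_greater_or_equal("sglang", "0.4.9.post2"):
    raise RuntimeError("Needs sglang>=0.4.9.post2 to run the code.")
```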
@@ -466,7 +448,7 @@ class InferenceEngineConfig:
         default="round_robin",
         metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
     )
-    setup_timeout: float = field(default=90.0)
+    setup_timeout: float = field(default=120.0)
     request_timeout: float = field(
         default=3600, metadata={"help": "Timeout for HTTP requests."}
     )
@@ -174,6 +174,7 @@ class InferenceEngine(abc.ABC):
         self,
         dataloader: StatefulDataLoader,
         workflow: "RolloutWorkflow",
+        should_accept: Callable | None = None,
     ):
         """Asynchronously submit and wait until a full batch is ready."""
         raise NotImplementedError()
@@ -1,10 +1,30 @@
-from typing import TYPE_CHECKING, Any, Dict
+import asyncio
+import itertools
+import queue
+import threading
+import time
+import traceback
+from typing import TYPE_CHECKING, Any, Callable, Dict, List
 
+import torch.distributed as dist
+import uvloop
 from tensordict import TensorDict
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from arealite.api.cli_args import InferenceEngineConfig
+from arealite.api.engine_api import InferenceEngine
+from arealite.api.io_struct import RolloutStat
+from arealite.utils.data import concat_padded_tensors
+from realhf.base import logging
 
 if TYPE_CHECKING:
     from arealite.api.engine_api import InferenceEngine
 
+logger = logging.getLogger("arealite.workflow_api")
+
+
+ROLLOUT_POLL_WAIT_TIME = 0.05
+
+
 class RolloutWorkflow:
 
@@ -16,3 +36,226 @@ class RolloutWorkflow:
         See concrete example implementations under the `arealite/workflow` directory.
         """
         raise NotImplementedError()
+
+
+class WorkflowExecutor:
+
+    def __init__(
+        self,
+        config: InferenceEngineConfig,
+        inference_engine: "InferenceEngine",
+    ):
+        config.max_concurrent_rollouts = (
+            config.max_concurrent_rollouts or config.consumer_batch_size
+        )
+        self.config = config
+        self.exiting = threading.Event()
+        self.paused = threading.Event()
+        self.lock = threading.Lock()
+
+        self.inference_engine = inference_engine
+
+        qsize = config.queue_size or config.max_concurrent_rollouts * 16
+        self.input_queue = queue.Queue(maxsize=qsize)
+        self.output_queue = queue.Queue(maxsize=qsize)
+        self.result_cache: List[TensorDict] = []
+
+        self.rollout_stat = RolloutStat()
+
+    def initialize(self):
+        self.rollout_tasks: Dict[str, asyncio.Task] = {}
+        self.rollout_thread = threading.Thread(target=self._rollout_thread)
+        self.rollout_thread.start()
+
+    def destroy(self):
+        self.exiting.set()
+        self.rollout_thread.join()
+
+    def get_capacity(self):
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+        else:
+            world_size = 1
+
+        with self.lock:
+            max_concurrent_rollouts = max(
+                1, self.config.max_concurrent_rollouts // world_size
+            )
+            capacity = max_concurrent_rollouts - len(self.rollout_tasks)
+            # Staleness control
+            version = self.inference_engine.get_version()
+            ofp = self.config.max_head_offpolicyness
+            sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
+            consumer_bs = max(1, self.config.consumer_batch_size // world_size)
+            capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
+        return capacity
+
+    def _rollout_thread(self):
+        """Thread that runs the rollout loop."""
+        try:
+            uvloop.run(self._rollout_thread_async())
+        except Exception:
+            traceback.print_exc()
+
+    async def _rollout_thread_async(self):
+        rollout_tasks = self.rollout_tasks
+        rid = 0
+        try:
+            while not self.exiting.is_set():
+                # Check capacity
+                capacity = self.get_capacity()
+                # Create new rollout task
+                self.lock.acquire()
+                while (
+                    capacity > 0
+                    and not self.paused.is_set()
+                    and self.input_queue.qsize() > 0
+                ):
+                    data, workflow = self.input_queue.get_nowait()
+                    logger.debug(f"Get data from puller: {data}")
+                    task = asyncio.create_task(
+                        workflow.arun_episode(self.inference_engine, data),
+                        name=str(rid),
+                    )
+                    rollout_tasks[str(rid)] = task
+                    self.rollout_stat.submitted += 1
+                    self.rollout_stat.running += 1
+                    if self.config.enable_rollout_tracing:
+                        logger.info(
+                            f"Submit rollout rid {rid}. "
+                            f"Submit: {self.rollout_stat.submitted}, "
+                            f"running: {self.rollout_stat.running}, "
+                            f"accepted: {self.rollout_stat.accepted}."
+                        )
+                    capacity -= 1
+                    rid += 1
+                tasks = list(rollout_tasks.values())
+                self.lock.release()
+
+                # Wait for rollout completion
+                done = []
+                if tasks:
+                    done, _ = await asyncio.wait(
+                        tasks,
+                        timeout=ROLLOUT_POLL_WAIT_TIME,
+                        return_when=asyncio.FIRST_COMPLETED,
+                    )
+                # Collect done results
+                for task in done:
+                    traj = await task
+                    traj: TensorDict
+                    task_rid = task.get_name()
+                    with self.lock:
+                        rollout_tasks.pop(task_rid)
+                        self.rollout_stat.accepted += 1
+                        self.rollout_stat.running -= 1
+                        if self.config.enable_rollout_tracing:
+                            logger.info(
+                                f"Finish rollout {task_rid}. "
+                                f"Submit: {self.rollout_stat.submitted}, "
+                                f"running: {self.rollout_stat.running}, "
+                                f"accepted: {self.rollout_stat.accepted}."
+                            )
+                    try:
+                        self.output_queue.put_nowait(traj)
+                    except queue.Full:
+                        raise RuntimeError(
+                            "Output queue full. Please increase queue_size."
+                        )
+
+                await asyncio.sleep(1)
+        except Exception:
+            traceback.print_exc()
+        finally:
+            # Cancel remaining tasks
+            with self.lock:
+                for task in rollout_tasks.values():
+                    if not task.done():
+                        task.cancel()
+                        try:
+                            await task
+                        except asyncio.CancelledError:
+                            pass
+
+    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
+        try:
+            self.input_queue.put_nowait((data, workflow))
+        except queue.Full:
+            raise RuntimeError("Input queue full. Please increase queue_size.")
+
+    def wait(
+        self,
+        count: int,
+        timeout: float | None = None,
+        should_accept: Callable | None = None,
+    ) -> TensorDict:
+        tik = time.perf_counter()
+        accepted = len(self.result_cache)
+        timeout = timeout or float(7 * 24 * 3600)
+        while (
+            accepted < count
+            and not self.exiting.is_set()
+            and time.perf_counter() - tik < timeout
+        ):
+            try:
+                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
+                if should_accept is None or should_accept(result):
+                    self.result_cache.append(result)
+                    accepted += 1
+                else:
+                    with self.lock:
+                        self.rollout_stat.accepted -= 1
+            except queue.Empty:
+                pass
+        if self.exiting.is_set():
+            raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
+        if accepted < count:
+            raise TimeoutError(
+                f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
+            )
+        results, self.result_cache = (
+            self.result_cache[:count],
+            self.result_cache[count:],
+        )
+        return concat_padded_tensors(results)
+
+    def rollout_batch(
+        self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
+    ) -> TensorDict:
+        """Submit a batch of requests to the inference engine and wait for the results."""
+        for item in data:
+            self.submit(item, workflow)
+        return self.wait(count=len(data))
+
+    def prepare_batch(
+        self,
+        dataloader: StatefulDataLoader,
+        workflow: "RolloutWorkflow",
+        should_accept: Callable | None = None,
+    ):
+        if not hasattr(self, "data_generator"):
+            self.data_generator = itertools.cycle(dataloader)
+        assert dataloader.batch_size is not None
+        while True:
+            # Submit at least two batches to allow maximum overlap
+            if (
+                self.get_capacity() + dataloader.batch_size > 0
+                and self.input_queue.qsize() + dataloader.batch_size
+                < self.input_queue.maxsize
+            ):
+                data = next(self.data_generator)
+                for item in data:
+                    self.submit(item, workflow=workflow)
+            try:
+                return self.wait(
+                    dataloader.batch_size, timeout=1, should_accept=should_accept
+                )
+            except TimeoutError:
+                pass
+
+    def pause(self):
+        self.paused.set()
+
+    def resume(self):
+        self.paused.clear()
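The staleness control in `get_capacity` above bounds how many samples may ever be started: with the current weight `version`, allowed off-policyness `ofp`, and per-rank batch size `consumer_bs`, at most `(ofp + version + 1) * consumer_bs` samples may be in flight or accepted. A worked example with assumed numbers:

```python
# Worked example of the staleness-control formula in WorkflowExecutor.get_capacity.
# All numbers below are assumptions chosen for illustration.
max_concurrent_rollouts = 64   # per-rank concurrency budget
running_tasks = 10             # len(self.rollout_tasks)
version = 3                    # current weight version on the inference engine
ofp = 2                        # config.max_head_offpolicyness
consumer_bs = 32               # per-rank consumer batch size
accepted, running = 150, 10    # rollout_stat counters

capacity = max_concurrent_rollouts - running_tasks                         # 54
staleness_cap = (ofp + version + 1) * consumer_bs - (accepted + running)   # 6*32 - 160 = 32
capacity = min(capacity, staleness_cap)                                    # 32
print(capacity)
```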
@@ -55,7 +55,7 @@ class FSDPEngine(BaseHFEngine):
         # Simple auto wrap policy
         self.mixed_precision_policy = MixedPrecisionPolicy(
             param_dtype=getattr(torch, self.config.dtype),
-            reduce_dtype=torch.float32,
+            reduce_dtype=getattr(torch, self.config.grad_reduce_dtype),
             cast_forward_inputs=True,
         )
         self.device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
@@ -128,7 +128,7 @@ class PPOActor:
         advantages = torch.stack(advantages_reversed[::-1], dim=1)
 
         # Optionally perform advantage normalization.
-        if self.adv_norm:
+        if self.adv_norm or self.group_adv_norm:
             if self.group_adv_norm:
                 adv_list = []
                 for i in range(0, bs, self.group_size):
@@ -2,17 +2,13 @@ import asyncio
 import os
 import random
 import shutil
-import threading
 import time
-import traceback
 from concurrent.futures import Future, ProcessPoolExecutor
 from datetime import datetime
-from queue import Empty, Full, Queue
-from typing import TYPE_CHECKING, Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List
 
 import aiohttp
 import requests
-import torch.distributed as dist
 import uvloop
 from tensordict import TensorDict
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -28,25 +24,19 @@ from arealite.api.io_struct import (
     VLMResponse,
     WeightUpdateMeta,
 )
-from arealite.utils.data import concat_padded_tensors
+from arealite.api.workflow_api import RolloutWorkflow, WorkflowExecutor
 from arealite.utils.http import arequest_with_retry, get_default_connector
 from realhf.base import logging, name_resolve, names
 
-if TYPE_CHECKING:
-    from arealite.api.workflow_api import RolloutWorkflow
-
 logger = logging.getLogger(__name__)
 
-ROLLOUT_POLL_WAIT_TIME = 0.05
 RID_CACHE_SIZE = 128
 
 
 class RemoteSGLangEngine(InferenceEngine):
 
     def __init__(self, config: InferenceEngineConfig):
-        config.max_concurrent_rollouts = (
-            config.max_concurrent_rollouts or config.consumer_batch_size
-        )
         self.config = config
 
         self.rid_to_address = {}
@@ -54,29 +44,20 @@ class RemoteSGLangEngine(InferenceEngine):
         self.rid_queue = []
 
         self.addresses = os.getenv("AREAL_LLM_SERVER_ADDRS").split(",")
+        self.session = None
 
         if not self.addresses:
             raise RuntimeError("No configured SGLang servers.")
-        logger.info("Waiting for server ready...")
-        for addr in self.addresses:
-            self._wait_for_server(addr)
-        logger.info("Servers are all ready!")
 
         self.server_idx = random.randint(0, len(self.addresses) - 1)
 
-        qsize = config.queue_size or config.max_concurrent_rollouts * 16
-        self.input_queue = Queue(maxsize=qsize)
-        self.output_queue = Queue(maxsize=qsize)
-        self.result_cache = []
-
-        self.exiting = threading.Event()
-        self.paused = threading.Event()
-        self.lock = threading.Lock()
-
-        self.rollout_stat = RolloutStat()
         self.distributed_weight_update_initialized = False
         self._version = 0
+        self.workflow_executor = WorkflowExecutor(
+            config=config,
+            inference_engine=self,
+        )
 
     def _wait_for_server(self, address):
         base_url = f"http://{address}"
         tik = time.time()
@@ -94,137 +75,46 @@ class RemoteSGLangEngine(InferenceEngine):
         except requests.exceptions.RequestException as e:
             return False
 
-    def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
-        self.rollout_tasks: Dict[str, asyncio.Task] = {}
+    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None = None):
+        logger.info("Waiting for server ready...")
+        for addr_ in self.addresses:
+            self._wait_for_server(addr_)
+        logger.info("Servers are all ready!")
         self.executor = ProcessPoolExecutor(max_workers=1)
-        self.rollout_thread = threading.Thread(target=self._rollout_thread)
-        self.rollout_thread.start()
+        self.workflow_executor.initialize()
 
     def destroy(self):
         self.executor.shutdown()
-        self.exiting.set()
-        self.rollout_thread.join()
 
     def set_version(self, version):
-        with self.lock:
-            self._version = version
+        self._version = version
 
     def get_version(self):
-        with self.lock:
-            return self._version
+        return self._version
 
-    def _rollout_thread(self):
-        """Thread that runs the rollout loop."""
-        try:
-            uvloop.run(self._rollout_thread_async())
-        except Exception:
-            traceback.print_exc()
-
-    async def _rollout_thread_async(self):
-        rollout_tasks = self.rollout_tasks
-        rid = 0
-
-        # NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
-        self.session = aiohttp.ClientSession(
-            timeout=aiohttp.ClientTimeout(
-                total=self.config.request_timeout,
-                sock_connect=self.config.request_timeout,
-                connect=self.config.request_timeout,
-            ),
-            read_bufsize=1024 * 1024 * 10,
-            connector=get_default_connector(),
-        )
-
-        try:
-            while not self.exiting.is_set():
-                # Check capacity
-                capacity = self.get_capacity()
-                # Create new rollout task
-                while (
-                    capacity > 0
-                    and not self.paused.is_set()
-                    and self.input_queue.qsize() > 0
-                ):
-                    data, workflow = self.input_queue.get_nowait()
-                    logger.debug(f"Get data from puller: {data}")
-                    task = asyncio.create_task(
-                        workflow.arun_episode(self, data), name=str(rid)
-                    )
-                    with self.lock:
-                        rollout_tasks[str(rid)] = task
-                        self.rollout_stat.submitted += 1
-                        self.rollout_stat.running += 1
-                        if self.config.enable_rollout_tracing:
-                            logger.info(
-                                f"Submit rollout rid {rid}. "
-                                f"Submit: {self.rollout_stat.submitted}, "
-                                f"running: {self.rollout_stat.running}, "
-                                f"accepted: {self.rollout_stat.accepted}."
-                            )
-                    capacity -= 1
-                    rid += 1
-                # Wait for rollout completion
-                with self.lock:
-                    tasks = list(rollout_tasks.values())
-                done = []
-                if tasks:
-                    done, _ = await asyncio.wait(
-                        tasks,
-                        timeout=ROLLOUT_POLL_WAIT_TIME,
-                        return_when=asyncio.FIRST_COMPLETED,
-                    )
-                # Collect done results
-                for task in done:
-                    traj = await task
-                    traj: TensorDict
-                    task_rid = task.get_name()
-                    with self.lock:
-                        rollout_tasks.pop(task_rid)
-                        self.rollout_stat.accepted += 1
-
-                    try:
-                        self.output_queue.put_nowait(traj)
-                    except Full:
-                        raise RuntimeError(
-                            "Output queue full. Please increase queue_size."
-                        )
-
-                    with self.lock:
-                        self.rollout_stat.running -= 1
-                        if self.config.enable_rollout_tracing:
-                            logger.info(
-                                f"Finish rollout {task_rid}. "
-                                f"Submit: {self.rollout_stat.submitted}, "
-                                f"running: {self.rollout_stat.running}, "
-                                f"accepted: {self.rollout_stat.accepted}."
-                            )
-                await asyncio.sleep(1)
-        except Exception:
-            traceback.print_exc()
-        finally:
-            # Cancel remaining tasks
-            with self.lock:
-                for task in rollout_tasks.values():
-                    if not task.done():
-                        task.cancel()
-                        try:
-                            await task
-                        except asyncio.CancelledError:
-                            pass
-
     def choose_server(self) -> str:
-        with self.lock:
-            if self.config.schedule_policy == "round_robin":
-                server = self.addresses[self.server_idx]
-                self.server_idx = (self.server_idx + 1) % len(self.addresses)
-                return server
+        if self.config.schedule_policy == "round_robin":
+            server = self.addresses[self.server_idx]
+            self.server_idx = (self.server_idx + 1) % len(self.addresses)
+            return server
         raise NotImplementedError("Only round-robin scheduling is implemented.")
 
     async def agenerate(
         self, req: LLMRequest | VLMRequest
     ) -> LLMResponse | VLMResponse:
         """Async version of generate using aiohttp."""
+        if self.session is None:
+            # NOTE: Lazily initialize aiohttp.ClientSession since it needs to be initialized
+            # inside asyncio loop in WorkflowExecutor
+            self.session = aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(
+                    total=self.config.request_timeout,
+                    sock_connect=self.config.request_timeout,
+                    connect=self.config.request_timeout,
+                ),
+                read_bufsize=1024 * 1024 * 10,
+                connector=get_default_connector(),
+            )
         # Prepare request payload
         gconfig = req.gconfig
         stop_token_ids = gconfig.stop_token_ids
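The session moves out of the rollout-thread setup and into a lazy check inside `agenerate` because `aiohttp.ClientSession` binds to the event loop that is running when it is constructed; building it in a thread with no running loop breaks once `WorkflowExecutor` owns the loop. A self-contained sketch of the lazy-initialization pattern, with a made-up `Client` class:

```python
# Minimal sketch of lazy ClientSession creation inside the running loop.
import asyncio

import aiohttp


class Client:
    def __init__(self, timeout_s: float = 3600.0):
        self.timeout_s = timeout_s
        self.session = None  # created on first use, inside the loop

    async def get_status(self, url: str) -> int:
        if self.session is None:
            # Safe here: a loop is guaranteed to be running.
            self.session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.timeout_s)
            )
        async with self.session.get(url) as resp:
            return resp.status


async def main():
    client = Client()
    print(await client.get_status("https://example.com"))
    await client.session.close()


asyncio.run(main())
```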
@@ -396,30 +286,8 @@ class RemoteSGLangEngine(InferenceEngine):
         fut.add_done_callback(callback)
         return fut
 
-    def get_capacity(self):
-        if dist.is_initialized():
-            world_size = dist.get_world_size()
-        else:
-            world_size = 1
-
-        max_concurrent_rollouts = max(
-            1, self.config.max_concurrent_rollouts // world_size
-        )
-        capacity = max_concurrent_rollouts - len(self.rollout_tasks)
-        # Staleness control
-        version = self.get_version()
-        ofp = self.config.max_head_offpolicyness
-        with self.lock:
-            sample_cnt = self.rollout_stat.accepted + self.rollout_stat.running
-        consumer_bs = max(1, self.config.consumer_batch_size // world_size)
-        capacity = min(capacity, (ofp + version + 1) * consumer_bs - sample_cnt)
-        return capacity
-
-    def submit(self, data: Dict[str, Any], workflow: "RolloutWorkflow") -> None:
-        try:
-            self.input_queue.put_nowait((data, workflow))
-        except Full:
-            raise RuntimeError("Input queue full. Please increase queue_size.")
+    def submit(self, data: Dict[str, Any], workflow: RolloutWorkflow) -> None:
+        return self.workflow_executor.submit(data, workflow)
 
     def wait(
         self,
@@ -427,77 +295,31 @@ class RemoteSGLangEngine(InferenceEngine):
         timeout: float | None = None,
         should_accept: Callable | None = None,
     ) -> TensorDict:
-        tik = time.perf_counter()
-        accepted = len(self.result_cache)
-        timeout = timeout or float(7 * 24 * 3600)
-        while (
-            accepted < count
-            and not self.exiting.is_set()
-            and time.perf_counter() - tik < timeout
-        ):
-            try:
-                result = self.output_queue.get(timeout=ROLLOUT_POLL_WAIT_TIME)
-                if should_accept is None or should_accept(result):
-                    self.result_cache.append(result)
-                    accepted += 1
-                else:
-                    with self.lock:
-                        self.rollout_stat.accepted -= 1
-            except Empty:
-                pass
-        if self.exiting.is_set():
-            raise RuntimeError("Rollout engine is exiting, cannot wait for results.")
-        if accepted < count:
-            raise TimeoutError(
-                f"Timed out waiting for {count} rollouts, " f"only received {accepted}."
-            )
-        results, self.result_cache = (
-            self.result_cache[:count],
-            self.result_cache[count:],
+        return self.workflow_executor.wait(
+            count,
+            timeout=timeout,
+            should_accept=should_accept,
         )
-        return concat_padded_tensors(results)
 
     def rollout_batch(
         self, data: List[Dict[str, Any]], workflow: "RolloutWorkflow"
     ) -> TensorDict:
-        """Submit a batch of requests to the inference engine and wait for the results."""
-        for item in data:
-            self.submit(item, workflow)
-        return self.wait(count=len(data))
+        return self.workflow_executor.rollout_batch(data, workflow)
 
     def prepare_batch(
         self,
         dataloader: StatefulDataLoader,
-        workflow: "RolloutWorkflow",
+        workflow: RolloutWorkflow,
     ):
-        if not hasattr(self, "data_generator"):
-            self.data_generator = iter(dataloader)
-        assert dataloader.batch_size is not None
-        while True:
-            # Submit at least two batches to allow maximum overlap
-            if (
-                self.get_capacity() + dataloader.batch_size > 0
-                and self.input_queue.qsize() + dataloader.batch_size
-                < self.input_queue.maxsize
-            ):
-                try:
-                    data = next(self.data_generator)
-                except StopIteration:
-                    self.data_generator = iter(dataloader)
-                    data = next(self.data_generator)
-                for item in data:
-                    self.submit(item, workflow=workflow)
-            try:
-                return self.wait(dataloader.batch_size, timeout=1)
-            except TimeoutError:
-                pass
+        return self.workflow_executor.prepare_batch(dataloader, workflow)
 
     def pause(self):
-        self.paused.set()
+        """Pause request submission for async rollout. Used during evaluation to prevent data over generation."""
+        return self.workflow_executor.pause()
 
     def resume(self):
-        self.paused.clear()
+        """Resume request submission for async rollout."""
+        return self.workflow_executor.resume()
 
     def update_weights_from_disk(
@@ -94,6 +94,7 @@ class SGLangEngine(InferenceEngine):
             asyncio.run(self._rollout_thread_async())
         except Exception as e:
             traceback.print_exc()
+            raise e
 
     async def _rollout_thread_async(self):
         data = None
@@ -324,7 +324,6 @@ def ray_main():
     )
     allocation_mode = config.allocation_mode
     allocation_mode = AllocationMode.from_str(allocation_mode)
-    sglang_cmds = []
    sglang_addrs = []
    n_sglang_nodes = 0
    if allocation_mode.type_ == AllocationType.DECOUPLED_SGLANG:
@@ -61,6 +61,7 @@ async def arequest_with_retry(
                 ctx = _session.delete(url, timeout=timeo)
             else:
                 raise ValueError(f"Unsupported HTTP method: {method}")
+
             async with ctx as response:
                 if verbose:
                     logger.info("http requests return")
@@ -31,7 +31,7 @@ class MultiTurnWorkflow(RolloutWorkflow):
         messages = data["messages"]
         # Run multi-turn rollout until correct
         t = reward = 0
-        discount = 0
+        discount = 1
         rid = uuid.uuid4().hex
         while reward == 0 and t < self.max_turns:
             # Amend a prompt if the previous answer is incorrect
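The `discount = 0` initialization was a real bug: the multi-turn workflow folds a per-turn discount into the final reward, so a starting value of 0 wipes out every reward no matter when the model succeeds. A toy illustration of the difference; the exact place where the workflow applies the multiplication may differ, the point is the initial value:

```python
# Toy illustration: discounting a reward over turns (turn_discount assumed 0.9).
turn_discount = 0.9


def final_reward(raw_reward: float, turns_taken: int, init: float) -> float:
    discount = init
    for _ in range(turns_taken - 1):  # discount once per extra turn
        discount *= turn_discount
    return raw_reward * discount


print(final_reward(1.0, 3, init=1))  # 0.81 -- later success is worth less
print(final_reward(1.0, 3, init=0))  # 0.0  -- the old bug: all rewards vanish
```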
@@ -1,6 +1,8 @@
 import asyncio
+import functools
 import os
 import uuid
+from concurrent.futures import ProcessPoolExecutor
 
 import colorama
 import torch
@@ -12,6 +14,11 @@ from arealite.api.engine_api import InferenceEngine
 from arealite.api.io_struct import LLMRequest
 from arealite.api.workflow_api import RolloutWorkflow
 from arealite.utils.data import concat_padded_tensors
+from realhf.base import logging
+
+logger = logging.getLogger("RLVR workflow")
+
+REWARD_TIMEOUT_SECONDS = 15
 
 
 class RLVRWorkflow(RolloutWorkflow):
@@ -28,6 +35,7 @@ class RLVRWorkflow(RolloutWorkflow):
         self.tokenizer = tokenizer
         self.enable_thinking = enable_thinking
         self.dump_dir = dump_dir
+        self.rw_executor = ProcessPoolExecutor(max_workers=4)
         if self.dump_dir is not None and not os.path.exists(self.dump_dir):
             os.makedirs(self.dump_dir, exist_ok=True)
@@ -54,6 +62,7 @@ class RLVRWorkflow(RolloutWorkflow):
         seqlens = []
 
         results = []
+        loop = asyncio.get_event_loop()
         for resp in resps:
             seq = resp.input_tokens + resp.output_tokens
             logprobs = [0.0] * resp.input_len + resp.output_logprobs
@@ -65,13 +74,26 @@ class RLVRWorkflow(RolloutWorkflow):
             prompt_strs.append(prompt_str)
             completions_strs.append(completions_str)
             seqlens.append(len(seq))
-            reward = self.reward_fn(
-                prompt=prompt_str,
-                completions=completions_str,
-                prompt_ids=resp.input_tokens,
-                completion_ids=resp.output_tokens,
-                **data,
-            )
+            try:
+                reward = await asyncio.wait_for(
+                    loop.run_in_executor(
+                        self.rw_executor,
+                        functools.partial(
+                            self.reward_fn,
+                            prompt_str,
+                            completions_str,
+                            resp.input_tokens,
+                            resp.output_tokens,
+                            **data,
+                        ),
+                    ),
+                    timeout=REWARD_TIMEOUT_SECONDS,
+                )
+            except asyncio.TimeoutError:
+                logger.warning(
+                    f"Computing reward timeout after {REWARD_TIMEOUT_SECONDS}s. Set reward to 0."
+                )
+                reward = 0
             rewards.append(reward)
             res = dict(
                 # unsqueeze to add an additional batch dimension
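Running the reward function in a `ProcessPoolExecutor` and bounding it with `asyncio.wait_for` keeps a slow or hanging reward computation from stalling the rollout loop, with the reward falling back to 0 on timeout. A self-contained sketch of the same pattern; `slow_reward_fn` is made up:

```python
import asyncio
import functools
import time
from concurrent.futures import ProcessPoolExecutor

REWARD_TIMEOUT_SECONDS = 15


def slow_reward_fn(prompt: str, completion: str) -> float:
    time.sleep(1)  # stand-in for expensive math verification
    return 1.0


async def main():
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=4) as pool:
        try:
            reward = await asyncio.wait_for(
                loop.run_in_executor(
                    pool, functools.partial(slow_reward_fn, "1+1=", "2")
                ),
                timeout=REWARD_TIMEOUT_SECONDS,
            )
        except asyncio.TimeoutError:
            reward = 0  # same fallback the workflow uses
    print(reward)


if __name__ == "__main__":
    asyncio.run(main())
```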
@@ -233,7 +233,7 @@ def prepare_batch(
    workflow: "RolloutWorkflow",
 ):
     if not hasattr(self, "data_generator"):
-        self.data_generator = iter(dataloader)
+        self.data_generator = itertools.cycle(dataloader)
     assert dataloader.batch_size is not None
     while True:
         # Submit at least two batches to allow maximum overlap
@@ -242,11 +242,7 @@ def prepare_batch(
             and self.input_queue.qsize() + dataloader.batch_size
             < self.input_queue.maxsize
         ):
-            try:
-                data = next(self.data_generator)
-            except StopIteration:
-                self.data_generator = iter(dataloader)
-                data = next(self.data_generator)
+            data = next(self.data_generator)
             for item in data:
                 # submit data into input_queue
                 self.submit(item, workflow=workflow)
@@ -264,18 +260,13 @@ rollout = RemoteSGLangEngine(config.rollout)
 rollout.initialize()
 eval_rollout = ...
 
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     # rollout batched training data for current step
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 ```
 
 If you want to use rollout workflows with custom reward functions or agentic tool
@@ -375,17 +366,12 @@ Now a complete GRPO training step in AReaLite is done! The core logic of our example
 training script can be summarized as:
 
 ```python
-data_generator = iter(train_dataloader)
+data_generator = itertools.cycle(train_dataloader)
 for global_step in range(max_steps):
     if config.async_training:
         batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
     else:
-        try:
-            data = next(data_generator)
-        except StopIteration:
-            data_generator = iter(train_dataloader)
-            data = next(data_generator)
-        batch = rollout.rollout_batch(data, workflow=workflow)
+        batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 
     logp = actor.compute_logp(batch)
     batch["prox_logp"] = logp
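The recurring `iter(...)` plus `except StopIteration` dance is replaced throughout by `itertools.cycle`, which restarts exhausted iterables transparently. The two idioms side by side; note that `cycle` replays items cached from its first pass, so any per-epoch reshuffling inside the dataloader is not re-triggered, a nuance worth knowing when swapping the idiom in:

```python
import itertools

dataloader = [[1, 2], [3, 4], [5, 6]]  # stand-in for a StatefulDataLoader

# Old idiom: manual restart on exhaustion.
gen = iter(dataloader)
for _ in range(5):
    try:
        batch = next(gen)
    except StopIteration:
        gen = iter(dataloader)
        batch = next(gen)

# New idiom: cycle restarts transparently.
# Caveat: cycle replays items saved from the first pass, so reshuffling
# that happens when re-iterating the raw loader does not occur.
gen = itertools.cycle(dataloader)
for _ in range(5):
    batch = next(gen)
```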
@@ -75,7 +75,7 @@ and converting it into an `LLMRequest` object for the inference engine:
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... __init__ method above ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # Initialize result containers
         seq, logprobs, loss_mask, versions = [], [], [], []
         messages = data["messages"]
@@ -119,7 +119,7 @@ we'll apply a discount, add feedback to the conversation, and let the model try
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... initialization code ...
         while reward == 0 and t < self.max_turns:
             # Add feedback if the previous answer was incorrect
@@ -190,7 +190,7 @@ Finally, let's complete the implementation by collecting trajectories in the
 class MultiTurnWorkflow(RolloutWorkflow):
     # ... previous methods ...
 
-    async def arun_episode(self, engine: InferenceEngine, data):
+    async def arun_episode(self, engine: InferenceEngine, data) -> TensorDict:
         # ... episode logic above ...
 
         while reward == 0 and t < self.max_turns:
@@ -417,8 +417,27 @@ Using your custom workflow is straightforward—just create it in your training script and
 pass it to the `rollout_batch` or `prepare_batch` method:
 
 ```python
-# in realhf/impl/agent/__init__.py
-import realhf.impl.agent.math_multi_turn_agent
+def main(args):
+    # ... setup code ...
+
+    # Create your custom workflow
+    workflow = MultiTurnWorkflow(
+        reward_fn=gsm8k_reward_fn,
+        gconfig=config.gconfig,
+        tokenizer=tokenizer,
+        turn_discount=0.9,
+        max_turns=5,
+    )
+
+    # Run training—no other changes needed!
+    data_generator = itertools.cycle(train_dataloader)
+    for global_step in range(max_steps):
+        with stats_tracker.record_timing("rollout"):
+            if config.async_training:
+                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
+            else:
+                batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
+        # ... continue with training loop ...
 ```
 
 Then update your experiment configuration in
@@ -171,16 +171,11 @@ def main(args):
     )
 
     # Main training loop
+    data_generator = itertools.cycle(dataloader)
     for global_step in range(max_steps):
         # Generate training data
         with stats_tracker.record_timing("rollout"):
-            try:
-                data = next(data_generator)
-            except StopIteration:
-                data_generator = iter(train_dataloader)
-                data = next(data_generator)
-
-            batch = rollout.rollout_batch(data, workflow=workflow)
+            batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 
         batch = batch.to(actor.device)
@@ -68,6 +68,7 @@ ref:
   trial_name: ${trial_name}
   path: ${actor.path}
   init_from_scratch: false
+  disable_dropout: true
   dtype: ${actor.dtype}
   mb_spec:
     max_tokens_per_mb: 10240
@@ -81,9 +82,9 @@ sglang:
   random_seed: ${seed}
   skip_tokenizer_init: true
   dtype: ${actor.dtype}
   max_running_requests: null
   context_length: 32768
-  mem_fraction_static: 0.9
+  mem_fraction_static: 0.8
 
 # datasets
 train_dataset:
@@ -94,7 +95,7 @@ train_dataset:
   path: openai/gsm8k
   type: rl
 
 valid_dataset:
   batch_size: 256
   shuffle: true
   pin_memory: true
@@ -131,4 +132,5 @@ stats_logger:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
   fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
@@ -1,3 +1,4 @@
+import itertools
 import os
 import sys
 
@@ -129,7 +130,7 @@ def main(args):
     max_steps = total_epochs * steps_per_epoch
 
     logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
-    data_generator = iter(train_dataloader)
+    data_generator = itertools.cycle(train_dataloader)
     for global_step in range(max_steps):
         epoch = global_step // steps_per_epoch
         step = global_step % steps_per_epoch
@@ -138,12 +139,7 @@ def main(args):
             if config.async_training:
                 batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
             else:
-                try:
-                    data = next(data_generator)
-                except StopIteration:
-                    data_generator = iter(train_dataloader)
-                    data = next(data_generator)
-                batch = rollout.rollout_batch(data, workflow=workflow)
+                batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
 
         batch = batch.to(actor.device)
         # Create barrier to synchronize all rollout processes.
@@ -6,7 +6,7 @@ pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 "deepspeed>=0.17.
 pip install "sglang[all]==0.4.9.post2"
 pip install megatron-core==0.11.0 nvidia-ml-py
 pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
-pip install "flash-attn<=2.7.3" --no-build-isolation
+pip install "flash-attn<=2.8.2" --no-build-isolation
 
 # Package used for calculating math reward
 pip install -e evaluation/latex2sympy
@@ -56,6 +56,8 @@ dependencies = [
     "torchdata",
     "autoflake",
     "tensordict",
+    "pybase64",
+    "msgspec",
 
     # Monitoring and logging
     "wandb",
@@ -74,3 +74,6 @@ torchdata
 autoflake
 tensordict
 deepspeed>=0.17.2
+pybase64
+msgspec
+transformers==4.53.1