0721_merge7

朱晗 2025-07-21 18:10:24 +08:00
parent c29561498e
commit f451dbd692
43 changed files with 0 additions and 2369 deletions

The following files changed mode from Executable to Normal with no content changes (0 lines changed each):

- .github/ISSUE_TEMPLATE/bug.md (vendored)
- .github/ISSUE_TEMPLATE/doc.md (vendored)
- .github/ISSUE_TEMPLATE/feature.md (vendored)
- .github/ISSUE_TEMPLATE/refactor.md (vendored)
- .github/workflows/deploy-docs.yml (vendored)
- .github/workflows/format-check.yml (vendored)
- .github/workflows/installation-validation.yml (vendored)
- .github/workflows/test-arealite.yml (vendored)
- Dockerfile
- LEGAL.md
- LICENSE
- MANIFEST.in
- Makefile
- README.md

@@ -1,148 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import abc
import functools
from dataclasses import dataclass
from typing import Any, Callable, Optional, SupportsFloat

from gymnasium import Env
from gymnasium.core import ActType, ObsType
from gymnasium.utils import seeding

from arealite.api.cli_args import (
    GenerationHyperparameters,
    RolloutCollectorConfig,
    TrainingArgs,
)
from arealite.api.io_struct import AgentInferInput, AgentInferOutput, Trajectory
from arealite.api.llm_client_api import LLMClient


class Agent(abc.ABC):
    def __init__(self, args: TrainingArgs):
        self.args = args

    async def aact(self, inp: AgentInferInput) -> AgentInferOutput:
        """Async version of act. Given an observation, return an action and data used for RL training."""
        raise NotImplementedError()

    async def areset(self) -> None:
        """Async version of reset. Resets the agent's memory."""
        raise NotImplementedError()


# Re-export the gymnasium environment class
class Environment(abc.ABC, Env):
    def __init__(self, args: TrainingArgs):
        self.args = args

    @abc.abstractmethod
    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        raise NotImplementedError()

    @abc.abstractmethod
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:  # type: ignore
        # Initialize the RNG if the seed is manually passed
        if seed is not None:
            self._np_random, self._np_random_seed = seeding.np_random(seed)


class RolloutCollector(abc.ABC):
    def __init__(
        self,
        args: TrainingArgs,
        config: RolloutCollectorConfig,
        agent: Agent | None = None,
        env: Environment | None = None,
        reward_func: Callable | None = None,
    ):
        self.args = args
        self.config = config
        # Used in agentic scenarios
        self.agent = agent
        self.env = env
        # Used in RLVR
        self.reward_func = reward_func

    async def arun_episode(
        self,
        llm_client: LLMClient,
        gconfig: GenerationHyperparameters,
        env_option: Optional[Any] = None,
        seed: Optional[int] = None,
    ) -> Trajectory:
        """Async version of run_episode. Run a single episode and return the trajectory."""
        raise NotImplementedError()


@dataclass
class RolloutCollectorFactory:
    args: TrainingArgs

    def make_collector(self, config: RolloutCollectorConfig) -> RolloutCollector:
        if config.type == "rlvr":
            from arealite.impl.rlvr.rlvr_collector import RlvrCollector

            rlvr_config = config.rlvr
            assert rlvr_config is not None
            if rlvr_config.reward_type == "areal-math":
                from arealite.impl.rlvr.rewards.areal_math import math_reward

                reward_fn = functools.partial(
                    math_reward, dataset_path=rlvr_config.solution_path
                )
            elif rlvr_config.reward_type == "areal-code":
                from arealite.impl.rlvr.rewards.areal_code import code_reward

                reward_fn = functools.partial(
                    code_reward, dataset_path=rlvr_config.solution_path
                )
            elif rlvr_config.reward_type == "gsm8k":
                from arealite.impl.rlvr.rewards.gsm8k import (
                    gsm8k_reward_fn as reward_fn,
                )
            elif rlvr_config.reward_type == "clevr_count_70k":
                from arealite.impl.rlvr.rewards.clevr_count_70k import (
                    clevr_count_70k_reward_fn as reward_fn,
                )
            else:
                raise NotImplementedError(
                    f"Unknown reward type: {rlvr_config.reward_type}"
                )
            return RlvrCollector(
                self.args,
                config=config,
                reward_fn=reward_fn,
            )
        if config.type == "math_code_single_step":
            from arealite.impl.agentic.math_code_single_step import (
                MathCodeAgent,
                MathCodeSingleStepCollector,
                MathCodeSingleStepEnv,
            )

            agent = MathCodeAgent(self.args)
            env = MathCodeSingleStepEnv(
                self.args,
                solution_path=config.math_code_single_step.solution_path,
            )
            return MathCodeSingleStepCollector(
                self.args,
                config=config,
                agent=agent,
                env=env,
            )
        raise NotImplementedError(f"Unknown agent type: {config.type}")

@@ -1,156 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import abc
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import torch.distributed as dist
from datasets import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import TrainerConfig, TrainingArgs
from realhf.base import constants
if TYPE_CHECKING:
from arealite.system.rollout_controller import RolloutController
# 4. use huggingface.trainerstate
# TODO: how to do checkpointing?
# follow the signature of transformers.Trainer if possible
class Trainer(abc.ABC):
def __init__(
self,
args: TrainingArgs,
trainer_config: TrainerConfig,
train_dataset: Dataset,
valid_dataset: Optional[Dataset] = None,
rollout_controller: Optional["RolloutController"] = None,
):
self.args = args
self.trainer_config = trainer_config
self.train_dataset = train_dataset
self.valid_dataset = valid_dataset
self.rollout_controller = rollout_controller
self.train_dataloader = None
self.valid_dataloader = None
def create_train_dataloader(self):
cfg = self.args.train_dataset
if dist.is_initialized():
batch_size = cfg.batch_size // dist.get_world_size()
else:
batch_size = cfg.batch_size
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=batch_size,
shuffle=cfg.shuffle,
pin_memory=cfg.pin_memory,
num_workers=cfg.num_workers,
drop_last=True,
collate_fn=lambda x: x
)
def create_valid_dataloader(self):
if self.args.valid_dataset is None:
return
cfg = self.args.valid_dataset
if dist.is_initialized():
batch_size = cfg.batch_size // dist.get_world_size()
else:
batch_size = cfg.batch_size
self.valid_dataloader = StatefulDataLoader(
dataset=self.valid_dataset,
batch_size=batch_size,
shuffle=cfg.shuffle,
pin_memory=cfg.pin_memory,
num_workers=cfg.num_workers,
drop_last=True,
collate_fn=lambda x: x
)
@property
def local_train_batch_size(self):
if not dist.is_initialized():
return self.args.train_dataset.batch_size
return self.args.train_dataset.batch_size // dist.get_world_size()
# TODO: check HF trainer signature
def train(self, resume_from_checkpoint: Optional[Union[str, bool]] = None):
raise NotImplementedError()
def get_save_checkpoint_path(
self, epoch: int, step: int, globalstep: int, name: str = "model"
):
path = os.path.join(
constants.get_save_path(self.args),
name,
f"epoch{epoch}epochstep{step}globalstep{globalstep}",
)
os.makedirs(path, exist_ok=True)
return path
@dataclass
class TrainerFactory:
args: TrainingArgs
def make_trainer(
self,
config: TrainerConfig,
train_dataset: Dataset,
valid_dataset: Optional[Dataset] = None,
rollout_controller: Optional["RolloutController"] = None,
) -> Trainer:
if config.type == "grpo":
from arealite.impl.trainer.grpo import SpmdGRPOTrainer
return SpmdGRPOTrainer(
self.args,
config,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
rollout_controller=rollout_controller,
)
elif config.type == "sft":
from arealite.impl.trainer.sft import SFTTrainer
return SFTTrainer(
self.args,
config,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
rollout_controller=rollout_controller,
)
elif config.type == "vl_sft":
from arealite.impl.trainer.vl_sft import VL_SFTTrainer
return VL_SFTTrainer(
self.args,
config,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
rollout_controller=rollout_controller,
)
elif config.type == "vl_grpo":
from arealite.impl.trainer.vl_grpo import VL_SpmdGRPOTrainer
return VL_SpmdGRPOTrainer(
self.args,
config,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
rollout_controller=rollout_controller,
)
else:
raise NotImplementedError(f"Unknown trainer type: {config.type}")

@@ -1,20 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import sys
from arealite.api.cli_args import TrainingArgs, prepare_training_args
from arealite.api.llm_server_api import LLMServerFactory
from realhf.base import seeding
def main():
    """Main entry point for launching the LLM server."""
    cfg: TrainingArgs = prepare_training_args(sys.argv[1:])[0]
    seeding.set_random_seed(cfg.seed, "llm_server")
    server = LLMServerFactory(cfg).make_server(cfg.rollout.llm_service)
    server.start()


if __name__ == "__main__":
    main()

@@ -1,58 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import os
import sys
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from arealite.api.cli_args import TrainingArgs, prepare_training_args
from arealite.api.dataset_api import DatasetFactory
from arealite.api.rollout_api import RolloutCollectorFactory
from arealite.api.trainer_api import TrainerFactory
from arealite.system.rollout_controller import RolloutController
from realhf.base import seeding
@record
def main():
    """Main entry point for launching the trainer."""
    cfg: TrainingArgs = prepare_training_args(sys.argv[1:])[0]
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    seeding.set_random_seed(cfg.seed, f"trainer{rank}")

    # Initialize the global pytorch distributed communication group.
    dist.init_process_group("nccl")

    # Load and split dataset
    dataset_factory = DatasetFactory(cfg)
    train_dataset = dataset_factory.make_dataset(cfg.train_dataset, rank, world_size)
    valid_dataset = None
    if cfg.valid_dataset is not None:
        valid_dataset = dataset_factory.make_dataset(
            cfg.valid_dataset, rank, world_size
        )

    # Create rollout controller for online training and evaluation.
    rollout_controller = None
    if cfg.rollout is not None:
        rollout_factory = RolloutCollectorFactory(cfg)
        collector = rollout_factory.make_collector(cfg.rollout.collector)
        rollout_controller = RolloutController(cfg, cfg.rollout, collector=collector)

    # If trainer is given, run RL or offline training.
    if cfg.trainer is not None:
        trainer_factory = TrainerFactory(cfg)
        trainer = trainer_factory.make_trainer(
            cfg.trainer,
            train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            rollout_controller=rollout_controller,
        )
        trainer.train()


if __name__ == "__main__":
    main()

@@ -1,171 +0,0 @@
# Basic experiment info
experiment_name: gsm8k-test
trial_name: my-trial-3
seed: 1
mode: local
wandb:
mode: disabled
entity: null
project: null
name: null
job_type: null
group: null
notes: null
tags: null
config: null
tensorboard:
path: null
exp_ctrl:
total_train_epochs: 5
save_freq_epochs: 1
save_freq_steps: null
save_freq_secs: null
ckpt_freq_epochs: null
ckpt_freq_steps: null
ckpt_freq_secs: 600
eval_freq_epochs: null
eval_freq_steps: null
eval_freq_secs: null
benchmark_steps: null
benchmark_n_seqs: null
# whether to allow persistent servers
shutdown_server_on_exit: true
# Allocation and parallelism
allocation_mode: sglang.d4p1t1+d4p1t1
n_nodes: 1
n_gpus_per_node: 8
# Cluster configuration
ray_temp_path: /tmp/ray
cluster:
cluster_name: local
fileroot: /tmp/arealite/
n_nodes: 1
n_gpus_per_node: 8
name_resolve:
type: nfs
nfs_record_root: /tmp/arealite/name_resolve/
# Datasets
train_dataset:
path: json
name: null
split: train
data_files: /storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl
batch_size: 32
shuffle: True
preprocessor:
type: areal
valid_dataset: null
# Rollout config
rollout:
collector:
type: rlvr
rlvr:
reward_type: areal-math
solution_path: /storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl
num_workers: 1
max_concurrent_rollouts: null
max_head_offpolicyness: 0
filter_reward_lb: -10000
filter_reward_ub: 10000
server_backend: sglang
model_path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
gconfig:
n_samples: 16
max_new_tokens: 512
min_new_tokens: 0
top_p: 1.0
top_k: 1000000
temperature: 1.0
llm_client:
schedule_policy: round_robin
request_timeout: 3600
request_retries: 3
llm_service:
served_model_name: null
health_check_interval: 5
startup_timeout: 300
max_unhealth_count: 3
graceful_shutdown_on_unhealthy: true
sglang:
dtype: "bfloat16"
enable_mixed_chunk: false
enable_torch_compile: false
torch_compile_max_bs: 32
cuda_graph_max_bs: null
cuda_graph_bs: null
triton_attention_reduce_in_fp32: false
triton_attention_num_kv_splits: 8
num_continuous_decode_steps: 1
attention_backend: "flashinfer"
sampling_backend: null
context_length: 32768
mem_fraction_static: 0.9
max_running_requests: null
chunked_prefill_size: -1
max_prefill_tokens: 32768
schedule_policy: "lpm"
schedule_conservativeness: 1.0
cpu_offload_gb: 0
kv_cache_dtype: "auto"
log_level: "warning"
log_level_http: "warning"
log_requests: false
log_requests_level: 0
show_time_cost: false
enable_metrics: true
decode_log_interval: 1
# Trainer
trainer:
type: grpo
grpo:
async_training: true
actor:
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
init_from_scratch: false
gradient_checkpointing: false
bf16: true
optimizer:
type: adam
lr: 1.0e-6
weight_decay: 0.05
beta1: 0.9
beta2: 0.999
eps: 1.0e-08
min_lr_ratio: 0.0
lr_scheduler_type: constant
warmup_steps_proportion: 0.001
initial_loss_scale: 4294967296.0
min_loss_scale: 1.0
loss_scale_window: 5.0
hysteresis: 2
gradient_clipping: 1.0
backend:
type: fsdp
ref: null
mb_spec:
max_tokens_per_mb: 10240
# Algorithm
group_adv_norm: False
ppo_n_minibatches: 4
eps_clip: 0.2
c_clip: null
reward_scaling: 10.0
reward_bias: -0.5
max_reward_clip: 20.0
mask_no_eos_with_zero: false
discount: 1.0
gae_lambda: 1.0
adv_norm: true
kl_ctl: 0.0
recompute_logprob: true
use_decoupled_loss: true
behav_imp_weight_cap: null
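A YAML file like the one above is consumed by the CLI scripts in this commit through `prepare_training_args`. Below is a minimal sketch, assuming the config is saved to a file such as `gsm8k_grpo.yaml` (a hypothetical name); the `--config` flag matches the launch commands in the local launcher further down.

```python
# Sketch: how the launch scripts in this commit turn a YAML config into a typed
# TrainingArgs object. The file name is hypothetical; the call signature and the
# --config flag are taken from the CLI scripts elsewhere in this diff.
import sys

from arealite.api.cli_args import TrainingArgs, prepare_training_args

cfg, config_file = prepare_training_args(sys.argv[1:])  # e.g. invoked with --config gsm8k_grpo.yaml
assert isinstance(cfg, TrainingArgs)
print(cfg.experiment_name, cfg.allocation_mode, cfg.trainer.type)
```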

@@ -1,104 +0,0 @@
# Basic experiment info
experiment_name: test-sft
trial_name: test-trial
seed: 1
mode: ray
n_nodes: 1
n_gpus_per_node: 8
wandb:
mode: disabled
entity: null
project: null
name: null
job_type: null
group: null
notes: null
tags: null
config: null
tensorboard:
path: null
exp_ctrl:
total_train_epochs: 2
save_freq_epochs: 1
save_freq_steps: null
save_freq_secs: null
ckpt_freq_epochs: null
ckpt_freq_steps: null
ckpt_freq_secs: 600
eval_freq_epochs: 1
eval_freq_steps: null
eval_freq_secs: null
benchmark_steps: null
benchmark_n_seqs: null
ray_temp_path: /tmp/ray
cluster:
cluster_name: local
fileroot: /tmp/arealite/
n_nodes: 32
n_gpus_per_node: 8
name_resolve:
type: nfs
nfs_record_root: /tmp/arealite/nfs_record_root/
train_dataset:
path: openai/gsm8k
preprocessor:
type: gsm8k_sft
name: main
split: train
data_files: null
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
valid_dataset:
path: openai/gsm8k
preprocessor:
type: gsm8k_sft
name: main
split: test
data_files: null
batch_size: 128
shuffle: true
pin_memory: true
num_workers: 4
trainer:
type: sft
sft:
model:
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
init_from_scratch: false
gradient_checkpointing: false
bf16: false
optimizer:
type: adam
lr: 2.0e-05
weight_decay: 0.05
beta1: 0.9
beta2: 0.95
eps: 1.0e-05
min_lr_ratio: 0.0
lr_scheduler_type: constant
warmup_steps_proportion: 0.001
initial_loss_scale: 4294967296.0
min_loss_scale: 1.0
loss_scale_window: 5.0
hysteresis: 2
gradient_clipping: 1.0
backend:
type: fsdp
fsdp:
wrap_policy:
transformer_layer_cls_to_wrap: null
offload_params: false
mb_spec:
n_mbs: 1
max_tokens_per_mb: 10240
rollout: null

@@ -1,87 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import os
import sys
from arealite.api.cli_args import prepare_training_args
from arealite.api.io_struct import AllocationMode
from arealite.api.llm_server_api import LLMServiceRegistry
from realhf.base import constants, name_resolve, names
from realhf.scheduler.client import JobException, JobState
from realhf.scheduler.client import make as make_scheduler
def main():
cfg, config_file = prepare_training_args(sys.argv[1:])
if cfg.shutdown_server_on_exit:
name_resolve.clear_subtree(
names.trial_root(
experiment_name=cfg.experiment_name, trial_name=cfg.trial_name
)
)
# Launch inference and training jobs
alloc_mode = AllocationMode.from_str(cfg.allocation_mode)
assert cfg.mode == "local"
scheduler = make_scheduler(cfg)
BASE_ENVIRONS = constants.get_env_vars(cfg)
for k, v in BASE_ENVIRONS.items():
os.environ[k] = v
# discover existing servers
existing_servers = LLMServiceRegistry(
cfg.experiment_name, cfg.trial_name
).get_healthy_servers()
# Launch LLM servers.
if len(existing_servers) == 0:
n_gpus_per_instance = alloc_mode.gen_pp_size * alloc_mode.gen_tp_size
servers_to_launch = alloc_mode.gen_dp_size - len(existing_servers)
scheduler.submit_array(
worker_type="llm_server",
cmd=f"python3 arealite/cli/launch_server.py --config {str(config_file)}",
count=servers_to_launch,
cpu=cfg.cpu_per_inf_proc * n_gpus_per_instance,
gpu=n_gpus_per_instance,
mem=cfg.mem_per_inf_proc * n_gpus_per_instance,
env_vars=BASE_ENVIRONS,
container_image=cfg.cluster.gpu_infer_image,
)
# Launch trainers.
scheduler.submit(
worker_type="trainer",
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} arealite/cli/launch_trainer.py --config {str(config_file)}",
cpu=cfg.cpu_per_train_proc * alloc_mode.train_world_size,
gpu=alloc_mode.train_world_size,
mem=cfg.cpu_per_train_proc * cfg.mem_per_train_proc,
container_image=cfg.cluster.gpu_image,
nodelist=cfg.nodelist,
exclude=cfg.exclude,
env_vars=BASE_ENVIRONS,
hostfile=False,
multiprog=False,
)
# Waiting for the job.
try:
scheduler.wait(
check_status=(
JobState.CANCELLED,
JobState.FAILED,
JobState.NOT_FOUND,
JobState.COMPLETED,
),
remove_status=(),
)
except (KeyboardInterrupt, JobException, TimeoutError):
kill_signal = (
"SIGKILL" if cfg.mode == "slurm" else "SIGTERM"
) # use sigkill to terminate slurm jobs
if cfg.shutdown_server_on_exit:
scheduler.stop_all(kill_signal)
else:
scheduler.stop("trainer")
if __name__ == "__main__":
main()

@@ -1,195 +0,0 @@
import functools
from typing import Dict, Optional, Tuple
import torch
import torch.distributed
from realhf.base import pkg_version
def actor_loss_fn(
logprobs: torch.Tensor,
old_logprobs: torch.Tensor,
advantages: torch.Tensor,
eps_clip: float,
loss_mask: torch.Tensor,
c_clip: Optional[float] = None,
proximal_logprobs: Optional[torch.Tensor] = None,
behav_imp_weight_cap: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict]:
denorm_logprobs = (
proximal_logprobs if proximal_logprobs is not None else old_logprobs
)
loss_mask_count = loss_mask.count_nonzero() or 1
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * clipped_ratio
clip_mask = pg_loss1.detach() < pg_loss2.detach()
pg_loss = torch.max(pg_loss1, pg_loss2)
if c_clip is not None:
assert c_clip > 1.0, c_clip
pg_loss3 = torch.sign(advantages) * c_clip * advantages
dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
pg_loss = torch.min(pg_loss, pg_loss3)
else:
dual_clip_mask = torch.zeros_like(clip_mask)
if proximal_logprobs is not None:
behav_kl = proximal_logprobs - old_logprobs
behav_imp_weight = behav_kl.exp()
behav_mask = (
(behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
if behav_imp_weight_cap is not None
else loss_mask
)
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
pg_loss = pg_loss * behav_imp_weight
logging_loss = pg_loss.detach()
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
clip_mask.logical_and_(loss_mask)
dual_clip_mask.logical_and_(loss_mask)
stat = dict(
loss=logging_loss,
importance_weight=ratio.detach(),
approx_kl=(logprobs - denorm_logprobs).detach(),
clip_mask=clip_mask,
dual_clip_mask=dual_clip_mask,
)
if proximal_logprobs is not None:
stat["behave_imp_weight"] = behav_imp_weight
stat["behave_approx_kl"] = behav_kl
stat["behave_mask"] = behav_mask
return pg_loss, stat
def _huber_loss(x: torch.Tensor, y: torch.Tensor, delta: float):
diff = torch.abs(x - y)
return torch.where(diff < delta, 0.5 * diff**2, delta * (diff - 0.5 * delta))
def _mse_loss(x: torch.Tensor, y: torch.Tensor):
return 0.5 * (x - y) ** 2
def critic_loss_fn(
value: torch.Tensor,
old_value: torch.Tensor,
target_value: torch.Tensor,
value_eps_clip: float,
loss_mask: torch.Tensor,
loss_fn_type: str = "mse",
) -> Tuple[torch.Tensor, Dict]:
if loss_fn_type == "huber":
loss_fn = functools.partial(_huber_loss, delta=10.0)
elif loss_fn_type == "mse":
loss_fn = _mse_loss
else:
raise NotImplementedError(f"Unknown loss fn type: {loss_fn_type}")
value_loss_original = loss_fn(value, target_value)
value_clipped = old_value + (value - old_value).clamp(
-value_eps_clip, value_eps_clip
)
value_loss_clipped = loss_fn(value_clipped, target_value)
value_loss = torch.max(value_loss_original, value_loss_clipped)
with torch.no_grad():
clip_mask = value_loss_clipped.detach() > value_loss_original.detach()
clip_mask.logical_and_(loss_mask)
stat = dict(clip_mask=clip_mask, loss=value_loss.detach())
value_loss = torch.where(loss_mask, value_loss, 0).sum() / loss_mask.count_nonzero()
return value_loss, stat
@torch.no_grad()
def get_packed_rewards(
kl_ctl: float,
clip_reward_value: float,
log_probs: torch.Tensor,
ref_log_probs: torch.Tensor,
reward_score: torch.Tensor,
cu_seqlens: torch.Tensor,
seq_no_eos_mask: torch.Tensor,
mask_no_eos_with_zero: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
tot_rewards = -kl_ctl * (log_probs - ref_log_probs)
tot_rewards[cu_seqlens[1:] - 1] = 0
kl_rewards = tot_rewards.clone()
reward_score = reward_score.clip(-clip_reward_value, clip_reward_value)
indices = torch.clip(cu_seqlens[1:] - 2, min=0)
if mask_no_eos_with_zero:
tot_rewards[indices] += torch.where(seq_no_eos_mask, 0, reward_score)
else:
tot_rewards[indices] += reward_score
return kl_rewards, tot_rewards
def pygae1d_nolp_misalign(
rewards: torch.Tensor,
values: torch.Tensor,
cu_seqlens_: torch.Tensor,
bootstrap: torch.Tensor,
gamma: float,
lam: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
cu_seqlens = cu_seqlens_.clone()
cu_seqlens[1:] += torch.ones_like(cu_seqlens_[1:]).cumsum(0)
bs = cu_seqlens_.shape[0] - 1
assert values.shape[0] == rewards.shape[0] + bs
advantages_reversed = []
returns_reversed = []
for i in reversed(range(bs)):
v_offset = cu_seqlens[i]
r_offset, r_end = cu_seqlens_[i], cu_seqlens_[i + 1]
assert cu_seqlens[i + 1] - v_offset - 1 == r_end - r_offset
lastgaelam = 0
for t in reversed(range(r_end - r_offset)):
nextvalues = values[v_offset + t + 1]
if t == r_end - r_offset - 1:
nextvalues *= bootstrap[i]
delta = rewards[r_offset + t] + gamma * nextvalues - values[v_offset + t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
returns_reversed.append(lastgaelam + values[v_offset + t])
advantages = torch.stack(advantages_reversed[::-1])
returns = torch.stack(returns_reversed[::-1])
return advantages, returns
def cugae1d_nolp_misalign_func(
rewards: torch.Tensor,
values: torch.Tensor,
cu_seqlens: torch.Tensor,
truncate: torch.Tensor,
gamma: float,
lam: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
if pkg_version.is_available("cugae"):
from cugae import cugae1d_nolp_misalign_func as gae_1d_nolp_misalign
else:
from realhf._C.cugae import gae_1d_nolp_misalign
assert len(rewards.shape) == len(values.shape) == len(cu_seqlens.shape) == 1
assert cu_seqlens[0] == 0 and cu_seqlens[-1] == rewards.shape[0]
return gae_1d_nolp_misalign(rewards, values, cu_seqlens, truncate, gamma, lam)
@torch.no_grad()
def get_packed_advantages_and_returns(
gamma: float,
lam: float,
values: torch.Tensor,
rewards: torch.Tensor,
short1cu_seqlens: torch.Tensor,
seq_no_eos_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if rewards.get_device() == -1:
return pygae1d_nolp_misalign(
rewards, values, short1cu_seqlens, seq_no_eos_mask, gamma, lam
)
try:
return cugae1d_nolp_misalign_func(
rewards, values, short1cu_seqlens.int(), seq_no_eos_mask.bool(), gamma, lam
)
except ModuleNotFoundError:
return pygae1d_nolp_misalign(
rewards, values, short1cu_seqlens, seq_no_eos_mask, gamma, lam
)
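To make the loss signature concrete, here is a tiny self-contained call with toy packed tensors. The values and shapes are made up; only `actor_loss_fn` itself comes from the file above.

```python
import torch

# Toy packed batch of 4 tokens with the last token masked out (illustrative values only).
logprobs = torch.tensor([-0.9, -1.1, -0.7, -1.3])
old_logprobs = torch.tensor([-1.0, -1.0, -1.0, -1.0])
advantages = torch.tensor([0.5, -0.2, 1.0, 0.1])
loss_mask = torch.tensor([True, True, True, False])

loss, stat = actor_loss_fn(
    logprobs, old_logprobs, advantages, eps_clip=0.2, loss_mask=loss_mask
)
# `loss` is the masked mean of the clipped PPO surrogate; `stat` carries per-token
# diagnostics such as the importance ratio, approximate KL, and clip masks.
```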

Binary file not shown (deleted image, 2.7 MiB).

Binary file not shown (deleted image, 892 KiB).
File diff suppressed because one or more lines are too long

@@ -1,10 +0,0 @@
{"prompt": "<\uff5cUser\uff5c>\nBaron Munchausen told a story. \"There were a whole crowd of us. We reached a crossroads. Then half of our group turned left, a third turned right, and a fifth went straight.\" \"But wait, the Duke remarked, the sum of half, a third, and a fifth isn't equal to one, so you are lying!\" The Baron replied, \"I'm not lying, I'm rounding. For example, there are 17 people. I say that a third turned. Should one person split in your opinion? No, with rounding, six people turned. From whole numbers, the closest to the fraction $17 / 3$ is 6. And if I say that half of the 17 people turned, it means 8 or 9 people.\" It is known that Baron Munchausen never lies. What is the largest number of people that could have been in the crowd?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "00006d8f079c739f", "solutions": ["\\boxed{37}"]}
{"prompt": "<\uff5cUser\uff5c>What is the unit digit of the product\n\n$$\n(5+1)\\left(5^{3}+1\\right)\\left(5^{6}+1\\right)\\left(5^{12}+1\\right) ?\n$$\n\n(a) 0 \n(b) 1 \n(c) 2 \n(d) 5 \n(e) 6\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "000316109ea516b3", "solutions": ["\\boxed{e}"]}
{"prompt": "<\uff5cUser\uff5c>Given points \\( A(4,0) \\) and \\( B(2,2) \\) are inside the ellipse \\( \\frac{x^{2}}{25}+\\frac{y^{2}}{9}=1 \\), and \\( M \\) is a point on the ellipse, find the maximum value of \\( |MA| + |MB| \\).\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "000adcfa66ee4270", "solutions": ["\\boxed{10+2\\sqrt{10}}"]}
{"prompt": "<\uff5cUser\uff5c>There is a schoolbag containing 12 cards labeled $1, 1, 2, 2, \\cdots, 6, 6$. A person draws one card at a time without replacement. If a card is drawn that has the same number as a previously drawn card, both cards are discarded. The process ends when the person has 3 single cards in hand or all cards in the schoolbag have been drawn. Find the probability that all cards in the schoolbag are drawn.\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "001354647264e663", "solutions": ["\\boxed{\\frac{9}{385}}"]}
{"prompt": "<\uff5cUser\uff5c>For the sequence of numbers \\( n_{1}, n_{2}, n_{3}, \\ldots \\), the relation \\( n_{i} = 2 n_{i-1} + a \\) holds for all \\( i > 1 \\). If \\( n_{2} = 5 \\) and \\( n_{8} = 257 \\), what is \\( n_{5} \\)?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "0014142e5f3c28a7", "solutions": ["\\boxed{33}"]}
{"prompt": "<\uff5cUser\uff5c>Three players play tic-tac-toe together. In other words, the three players take turns placing an \"A\", \"B\", and \"C\", respectively, in one of the free spots of a \\(3 \\times 3\\) grid, and the first player to have three of their label in a row, column, or diagonal wins. How many possible final boards are there where the player who goes third wins the game? (Rotations and reflections are considered different boards, but the order of placement does not matter.)\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "0017c4e9f72d26eb", "solutions": ["\\boxed{148}"]}
{"prompt": "<\uff5cUser\uff5c>Let \\( a_{1}, a_{2}, \\cdots, a_{2014} \\) be a permutation of the positive integers \\( 1, 2, \\cdots, 2014 \\). Define\n\\[ S_{k} = a_{1} + a_{2} + \\cdots + a_{k} \\quad (k=1, 2, \\cdots, 2014). \\]\n\nWhat is the maximum number of odd numbers among \\( S_{1}, S_{2}, \\cdots, S_{2014} \\)?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "00231541a71983cd", "solutions": ["\\boxed{1511}"]}
{"prompt": "<\uff5cUser\uff5c>\nThe polynomial \\( G(x) \\) with real coefficients takes the value 2022 at exactly five distinct points \\( x_{1}<x_{2}<x_{3}<x_{4}<x_{5} \\). It is known that the graph of the function \\( y=G(x) \\) is symmetric with respect to the line \\( x=-7 \\).\n\n(a) Find \\( x_{1}+x_{3}+x_{5} \\).\n\n(b) What is the minimum degree that \\( G(x) \\) can have?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "002ba4c0d1ad1b54", "solutions": ["\\boxed{6}"]}
{"prompt": "<\uff5cUser\uff5c>The square of a natural number has 202 digits. The first 100 digits are 1, followed by 101 digits of 2. Determine the last digit and the number.\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "0041f14cc37aee13", "solutions": ["\\boxed{5}"]}
{"prompt": "<\uff5cUser\uff5c>The descriptors 'even', 'factors of 240', 'multiple of 3', 'odd', 'prime' and 'square' are to be placed in some order as row and column headings around a grid in positions \\(a, b, c, d, e,\\) and \\(f\\). The digits 1 through 9 are to be placed in the empty cells inside the grid so that each digit satisfies both the relevant row and column headings.\n(i) Show that it is possible to complete the grid.\n(ii) In how many different ways can the grid be completed?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "00514ff45cc98a48", "solutions": ["\\boxed{72}"]}

@@ -1,172 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import json
from datetime import datetime
from pathlib import Path
import pytest
from datasets import load_dataset
from arealite.api.cli_args import (
DatasetPreprocessor,
GenerationHyperparameters,
GSM8KPreprocessor,
MathCodeSingleStepConfig,
RLVRConfig,
RolloutCollectorConfig,
SGLangConfig,
TrainingArgs,
)
from arealite.api.io_struct import Trajectory
from arealite.api.llm_client_api import LLMClientFactory
from arealite.api.llm_server_api import LLMServerFactory
from arealite.api.rollout_api import RolloutCollectorFactory
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import name_resolve, seeding
EXPR_NAME = "test_rollout"
TRIAL_NAME = "test_rollout"
MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
@pytest.fixture(scope="module")
def tokenizer():
yield load_hf_tokenizer(MODEL_PATH)
@pytest.fixture(scope="module")
def args():
args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
args.rollout.model_path = MODEL_PATH
seeding.set_random_seed(args.seed, EXPR_NAME)
name_resolve.reconfigure(args.cluster.name_resolve)
yield args
name_resolve.reset()
@pytest.fixture(scope="module")
def sglang_server(args):
args.rollout.sglang = SGLangConfig()
server = LLMServerFactory(args).make_server(args.rollout.llm_service)
server._startup()
yield
server._graceful_exit(0)
@pytest.mark.parametrize("task", ["math", "code"])
@pytest.mark.asyncio
async def test_rlvr_rollout(args, sglang_server, tokenizer, task):
jsonl_file = Path(__file__).parent / "data" / f"rlvr_{task}_dataset.jsonl"
args.rollout.server_backend = "sglang"
args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16)
args.rollout.collector = RolloutCollectorConfig(
type="rlvr",
rlvr=RLVRConfig(reward_type=f"areal-{task}", solution_path=jsonl_file),
)
llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client)
collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector)
# Test the rollout collector with the provided JSONL data
with open(jsonl_file, "r") as f:
for i, l in enumerate(f.readlines()):
data = json.loads(l)
env_option = dict(
query_id=data["query_id"],
input_ids=tokenizer.encode(data["prompt"]),
prompt=data["prompt"],
)
res = await collector.arun_episode(
llm_client=llm_client,
gconfig=gconfig,
env_option=env_option,
)
assert isinstance(res, Trajectory)
assert isinstance(res.data, dict)
assert res.prompt == env_option
shape = res.data["input_ids"].shape
for k in ["prompt_mask", "logprobs", "versions"]:
assert res.data[k].shape == shape
assert res.stats.episode_length == 1
assert res.stats.total_reward in [0, 1], res.stats.total_reward
assert res.stats.start_time < datetime.now().timestamp()
@pytest.mark.asyncio
async def test_gsm8k_rollout(args, sglang_server, tokenizer):
args.rollout.server_backend = "sglang"
args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16)
args.rollout.collector = RolloutCollectorConfig(
type="rlvr", rlvr=RLVRConfig(reward_type="gsm8k")
)
collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector)
args.train_dataset.path = "openai/gsm8k"
args.train_dataset.name = "main"
args.train_dataset.split = "train"
args.train_dataset.preprocessor = DatasetPreprocessor(
"gsm8k_rl", gsm8k=GSM8KPreprocessor("strict")
)
from arealite.api.dataset_api import DatasetFactory
llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client)
dataset = (
DatasetFactory(args)
.make_dataset(args.train_dataset, rank=0, world_size=1)
.select(range(10))
)
for i in range(len(dataset)):
env_option = dataset[i]
res = await collector.arun_episode(
llm_client=llm_client,
gconfig=gconfig,
env_option=env_option,
)
assert isinstance(res, Trajectory)
assert isinstance(res.data, dict)
assert res.prompt == env_option
shape = res.data["input_ids"].shape
for k in ["prompt_mask", "logprobs", "versions"]:
assert res.data[k].shape == shape
assert res.stats.episode_length == 1
assert res.stats.total_reward in [0, 1], res.stats.total_reward
assert res.stats.start_time < datetime.now().timestamp()
@pytest.mark.parametrize("task", ["math", "code"])
@pytest.mark.asyncio
async def test_math_code_agentic_rollout(args, task, sglang_server, tokenizer):
jsonl_file = Path(__file__).parent / "data" / f"rlvr_{task}_dataset.jsonl"
args.rollout.server_backend = "sglang"
args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16)
args.rollout.collector = RolloutCollectorConfig(
type="math_code_single_step",
math_code_single_step=MathCodeSingleStepConfig(solution_path=jsonl_file),
)
collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector)
llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client)
# Test the rollout collector with the provided JSONL data
with open(jsonl_file, "r") as f:
for i, l in enumerate(f.readlines()):
data = json.loads(l)
env_option = dict(
query_id=data["query_id"],
input_ids=tokenizer.encode(data["prompt"]),
)
res = await collector.arun_episode(
llm_client=llm_client,
gconfig=gconfig,
env_option=env_option,
)
assert isinstance(res, Trajectory)
assert isinstance(res.data, dict)
assert res.prompt == env_option
shape = res.data["input_ids"].shape
for k in ["prompt_mask", "logprobs", "versions"]:
assert res.data[k].shape == shape
assert res.stats.episode_length == 1
assert res.stats.total_reward in [0, 1], res.stats.total_reward
assert res.stats.start_time < datetime.now().timestamp()

@@ -1,209 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import time
from copy import deepcopy
from pathlib import Path
import pytest
import torch.multiprocessing as mp
from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader
from arealite.api.cli_args import RLVRConfig, SGLangConfig, TrainingArgs
from arealite.api.io_struct import Trajectory
from arealite.api.llm_server_api import LLMServerFactory
from arealite.api.rollout_api import RolloutCollectorFactory
from arealite.system.rollout_controller import RolloutController
from arealite.tests.utils import mock_rollout_output
from arealite.utils import concat_padded_tensors
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import name_resolve, names, seeding
EXPR_NAME = "test_rollout_controller"
TRIAL_NAME = "test_rollout_controller"
MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
@pytest.fixture(scope="module")
def args():
args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
seeding.set_random_seed(args.seed, EXPR_NAME)
args.rollout.model_path = MODEL_PATH
args.rollout.llm_client.tokenizer_path = MODEL_PATH
args.train_dataset.batch_size = 2
args.rollout.collector.rlvr = RLVRConfig(
solution_path=str(Path(__file__).parent / "data" / f"rlvr_math_dataset.jsonl")
)
start_method = mp.get_start_method()
mp.set_start_method("fork", force=True)
name_resolve.reconfigure(args.cluster.name_resolve)
yield args
name_resolve.reset()
mp.set_start_method(start_method, force=True)
@pytest.fixture(scope="module")
def sglang_server(args):
args.rollout.sglang = SGLangConfig()
server = LLMServerFactory(args).make_server(args.rollout.llm_service)
server._startup()
yield
server._graceful_exit(0)
@pytest.fixture
def dataloader(args):
dataset = load_dataset(
"json",
split="train",
data_files=str(Path(__file__).parent / "data" / f"rlvr_math_dataset.jsonl"),
)
tokenizer = load_hf_tokenizer(MODEL_PATH)
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), batched=True)
yield StatefulDataLoader(
dataset,
batch_size=args.train_dataset.batch_size,
collate_fn=lambda x: x,
drop_last=True,
)
@pytest.mark.parametrize("num_workers", [1, 4])
@pytest.mark.parametrize("n_samples", [1, 2])
def test_generate_batch(args, sglang_server, dataloader, n_samples, num_workers):
args = deepcopy(args)
args.rollout.num_workers = num_workers
args.rollout.gconfig.n_samples = n_samples
args.rollout.gconfig.max_new_tokens = 16
rollout_factory = RolloutCollectorFactory(args)
collector = rollout_factory.make_collector(args.rollout.collector)
rollout_controller = RolloutController(args, args.rollout, collector=collector)
data = next(iter(dataloader))
batch_size = len(data)
result = rollout_controller.generate_batch(batch_size, env_options=data)
assert len(result) == batch_size * n_samples
assert all(isinstance(traj, Trajectory) for traj in result)
for traj in result:
shape = traj.data["input_ids"].shape
assert len(shape) == 2
for v in traj.data.values():
assert v.shape == shape or len(v.shape) == 1
data = concat_padded_tensors([traj.data for traj in result])
assert data["input_ids"].shape[0] == batch_size * n_samples
shape = data["input_ids"].shape
assert len(shape) == 2
for v in data.values():
assert v.shape == shape or len(v.shape) == 1
@pytest.mark.parametrize("batch_size", [1, 2, 3])
@pytest.mark.parametrize("n_samples", [1, 2, 4])
def test_mock_trajs(batch_size, n_samples):
# Test the consistency with mocked rollout output
result = mock_rollout_output(batch_size, n_samples)
assert len(result) == batch_size * n_samples
assert all(isinstance(traj, Trajectory) for traj in result)
for traj in result:
shape = traj.data["input_ids"].shape
assert len(shape) == 2
for v in traj.data.values():
assert v.shape == shape or len(v.shape) == 1
data = concat_padded_tensors([traj.data for traj in result])
assert data["input_ids"].shape[0] == batch_size * n_samples
shape = data["input_ids"].shape
assert len(shape) == 2
for v in data.values():
assert v.shape == shape or len(v.shape) == 1
@pytest.mark.parametrize("n_samples", [1, 4, 16])
@pytest.mark.parametrize("num_workers", [1, 2, 5])
def test_async_rollout(args, sglang_server, dataloader, n_samples, num_workers):
args = deepcopy(args)
args.rollout.gconfig.n_samples = n_samples
args.rollout.gconfig.max_new_tokens = 16
args.train_dataset.batch_size = 2
args.rollout.max_concurrent_rollouts = 16
args.rollout.num_workers = num_workers
rollout_factory = RolloutCollectorFactory(args)
collector = rollout_factory.make_collector(args.rollout.collector)
rollout_controller = RolloutController(args, args.rollout, collector=collector)
# start loop
rollout_controller.start_generate_loop()
assert hasattr(rollout_controller, "_collector_thread")
assert rollout_controller._collector_thread.is_alive()
# Submit data to workers
data = next(iter(dataloader))
rollout_controller.submit(data)
# wait for batch
batch_size = 2
result = rollout_controller.prepare_batch(batch_size)
assert len(result) == batch_size * n_samples
assert all(isinstance(traj, Trajectory) for traj in result)
for traj in result:
shape = traj.data["input_ids"].shape
assert len(shape) == 2
for v in traj.data.values():
assert v.shape == shape or len(v.shape) == 1
data = concat_padded_tensors([traj.data for traj in result])
assert data["input_ids"].shape[0] == batch_size * n_samples
shape = data["input_ids"].shape
assert len(shape) == 2
for v in data.values():
assert v.shape == shape or len(v.shape) == 1
# exit
rollout_controller.stop_generate_loop()
assert rollout_controller._exiting.is_set()
assert not rollout_controller._collector_thread.is_alive()
assert not rollout_controller._worker_processes
@pytest.mark.parametrize("ofp", [1, 2, 4, 16])
def test_async_staleness_control(args, sglang_server, dataloader, ofp):
args = deepcopy(args)
args.rollout.gconfig.n_samples = 2
args.rollout.gconfig.max_new_tokens = 4
args.rollout.max_head_offpolicyness = ofp
args.rollout.max_concurrent_rollouts = 100
rollout_factory = RolloutCollectorFactory(args)
collector = rollout_factory.make_collector(args.rollout.collector)
rollout_controller = RolloutController(args, args.rollout, collector=collector)
name = names.model_version(args.experiment_name, args.trial_name, "actor")
name_resolve.add(name, str(0), replace=True)
# start loop
rollout_controller.start_generate_loop()
batch_size = args.train_dataset.batch_size
gen = iter(dataloader)
rollout_controller.submit(next(gen))
rollout_controller.submit(next(gen))
# wait for some time
time.sleep(15)
assert len(rollout_controller._buffer) == min(
batch_size * 2, batch_size * (ofp + 1)
)
# Update model version
name = names.model_version(args.experiment_name, args.trial_name, "actor")
name_resolve.add(name, str(1), replace=True)
print("Updated model version", flush=True)
# submit again
rollout_controller.submit(next(gen))
rollout_controller.submit(next(gen))
# wait for some time
time.sleep(15)
assert len(rollout_controller._buffer) == min(
batch_size * 4, batch_size * (ofp + 2)
)
# exit
rollout_controller.stop_generate_loop()

@@ -1,108 +0,0 @@
"""Test script for FSDP Engine implementation."""
import os
from typing import Dict
import torch
from datasets import load_dataset
from arealite.api.cli_args import (
DatasetConfig,
DatasetPreprocessor,
EngineBackendConfig,
EngineConfig,
OptimizerConfig,
SFTTrainerConfig,
TrainerConfig,
TrainingArgs,
)
from arealite.api.dataset_api import DatasetFactory
from arealite.api.trainer_api import TrainerFactory
def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
"""Mock loss function for testing."""
return torch.mean(logits)
def mock_loss_weight_fn(logits: torch.Tensor, input_data: Dict) -> float:
"""Mock loss weight function for testing."""
return float(input_data["attention_mask"].sum())
def test_sft():
"""Test SFTTrainer"""
# environment variables for torch distributed
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "7777"
train_dataset = DatasetConfig(
path="openai/gsm8k",
preprocessor=DatasetPreprocessor("gsm8k_sft"),
name="main",
split="train",
batch_size=8,
shuffle=True,
pin_memory=True,
)
valid_dataset = DatasetConfig(
path="openai/gsm8k",
preprocessor=DatasetPreprocessor("gsm8k_sft"),
name="main",
split="test",
batch_size=8,
shuffle=False,
pin_memory=True,
)
engine_config = EngineConfig(
path="/storage/openpsi/models/Qwen__Qwen3-1.7B/",
gradient_checkpointing=False,
optimizer=OptimizerConfig(),
backend=EngineBackendConfig(type="hf"),
)
sft_config = SFTTrainerConfig(
model=engine_config,
)
train_config = TrainerConfig(
type="sft",
sft=sft_config,
)
args = TrainingArgs(
experiment_name="test-sft",
trial_name="test",
mode="local",
n_nodes=1,
n_gpus_per_node=1,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
trainer=train_config,
)
rollout_controller = None
dataset_factory = DatasetFactory(args)
train_dataset = dataset_factory.make_dataset(args.train_dataset, 0, 1)
train_dataset = train_dataset.select(range(100))
valid_dataset = None
if args.valid_dataset is not None:
valid_dataset = dataset_factory.make_dataset(args.valid_dataset, 0, 1)
valid_dataset = valid_dataset.select(range(100))
if args.trainer is not None:
trainer_factory = TrainerFactory(args)
trainer = trainer_factory.make_trainer(
args.trainer,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
rollout_controller=rollout_controller,
)
trainer.train()
print("All tests passed!")
test_sft()

@@ -1,112 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import os
import uuid
import pytest
from arealite.api.cli_args import (
EngineBackendConfig,
EngineConfig,
GenerationHyperparameters,
LLMClientConfig,
OptimizerConfig,
SGLangConfig,
TrainingArgs,
)
from arealite.api.engine_api import EngineFactory
from arealite.api.io_struct import FinetuneSpec, LLMRequest, LLMResponse
from arealite.api.llm_client_api import LLMClient
from arealite.api.llm_server_api import LLMServerFactory
from realhf.base import name_resolve, seeding
EXPR_NAME = "test_sglang_client"
TRIAL_NAME = "test_sglang_client"
MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
@pytest.fixture(scope="module")
def args():
args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
args.rollout.model_path = MODEL_PATH
seeding.set_random_seed(args.seed, EXPR_NAME)
name_resolve.reconfigure(args.cluster.name_resolve)
yield args
name_resolve.reset()
@pytest.fixture(scope="module")
def sglang_server(args):
args.rollout.sglang = SGLangConfig(mem_fraction_static=0.3)
server = LLMServerFactory(args).make_server(args.rollout.llm_service)
server._startup()
yield
server._graceful_exit(0)
@pytest.fixture(scope="module")
def sglang_client(args, sglang_server):
from arealite.system.sglang_client import SGLangClient
args.rollout.server_backend = "sglang"
llm_client = LLMClientConfig()
client = SGLangClient(args, client_config=llm_client)
yield client
@pytest.mark.asyncio
async def test_sglang_generate(sglang_client):
req = LLMRequest(
rid=str(uuid.uuid4()),
text="hello! how are you today",
gconfig=GenerationHyperparameters(max_new_tokens=16),
)
resp = await sglang_client.agenerate(req)
assert isinstance(resp, LLMResponse)
assert resp.input_tokens == req.input_ids
assert (
len(resp.output_logprobs)
== len(resp.output_tokens)
== len(resp.output_versions)
)
assert isinstance(resp.completion, str)
@pytest.mark.asyncio
async def test_sglang_update_weights_from_disk(sglang_client: LLMClient):
servers = sglang_client.get_healthy_servers()
assert len(servers) == 1
await sglang_client.aupdate_weights_from_disk(
server_info=servers[0], path=MODEL_PATH
)
@pytest.fixture(scope="module")
def engine(sglang_server):
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "7777"
engine_config = EngineConfig(
path=MODEL_PATH,
gradient_checkpointing=False,
optimizer=OptimizerConfig(),
backend=EngineBackendConfig(type="fsdp"),
)
mock_args = TrainingArgs(n_nodes=1, n_gpus_per_node=1)
engine_factory = EngineFactory(mock_args)
engine = engine_factory.make_engine(engine_config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.init_distributed(None, ft_spec)
print("✓ Engine created successfully")
yield engine
def test_sglang_update_weights_from_distributed(
engine, sglang_server, sglang_client: LLMClient
):
engine.update_weights_to(sglang_client)

File diff suppressed because one or more lines are too long

@@ -1,36 +0,0 @@
import random
import torch
from arealite.api.io_struct import Trajectory, TrajStats
def mock_rollout_output(bs, n_samples):
trajs = []
min_seqlen, max_seqlen = 8, 16
for _ in range(bs * n_samples):
input_len = random.randint(min_seqlen, max_seqlen)
prompt_len = random.randint(1, min_seqlen - 1)
input_ids = torch.randint(0, 100, (input_len,))
prompt_mask = torch.tensor([1] * prompt_len + [0] * (input_len - prompt_len))
logprobs = -torch.randn(input_len).abs()
versions = torch.zeros(input_len)
traj = Trajectory(
prompt=None,
data=dict(
input_ids=input_ids.unsqueeze(0),
prompt_mask=prompt_mask.unsqueeze(0),
logprobs=logprobs.unsqueeze(0),
versions=versions.unsqueeze(0),
rewards=torch.tensor([random.random()]),
),
stats=TrajStats(
start_time=0,
total_reward=0,
episode_length=1,
info={},
),
)
trajs.append(traj)
return trajs

@@ -1,317 +0,0 @@
Metadata-Version: 2.4
Name: realhf
Version: 0.3.0.dev0
Summary: ReaL: Efficient RLHF Training of Large Language Models with Parameter Reallocation
Keywords: distributed-systems,reinforcement-learning-from-human-feedback,large-language-models,llm-training
Classifier: Development Status :: 2 - Pre-Alpha
Classifier: Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.2
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>2.0.0
Requires-Dist: huggingface_hub
Requires-Dist: datasets
Requires-Dist: accelerate
Requires-Dist: transformers==4.51.1
Requires-Dist: numpy<2.0.0
Requires-Dist: scipy
Requires-Dist: pandas
Requires-Dist: matplotlib
Requires-Dist: seaborn
Requires-Dist: h5py
Requires-Dist: nltk
Requires-Dist: sentencepiece
Requires-Dist: einops
Requires-Dist: tqdm
Requires-Dist: rich
Requires-Dist: orjson>=3.10.16
Requires-Dist: pydantic
Requires-Dist: PyYAML
Requires-Dist: hydra-core==1.4.0.dev1
Requires-Dist: packaging
Requires-Dist: tabulate
Requires-Dist: gymnasium>=1.1.1
Requires-Dist: torchdata
Requires-Dist: autoflake
Requires-Dist: tensordict
Requires-Dist: wandb
Requires-Dist: tensorboardx
Requires-Dist: colorama
Requires-Dist: colorlog
Requires-Dist: psutil
Requires-Dist: pynvml
Requires-Dist: swanlab[dashboard]
Requires-Dist: ninja
Requires-Dist: numba
Requires-Dist: blosc
Requires-Dist: pybind11>=2.10.0
Requires-Dist: networkx==3.3
Requires-Dist: aiofiles
Requires-Dist: aiohttp>=3.11.10
Requires-Dist: httpx>=0.28.1
Requires-Dist: pyzmq
Requires-Dist: paramiko
Requires-Dist: etcd3
Requires-Dist: protobuf<3.21
Requires-Dist: ray
Requires-Dist: redis
Requires-Dist: fastapi>=0.115.12
Requires-Dist: uvicorn>=0.34.2
Requires-Dist: uvloop>=0.21.0
Requires-Dist: flask
Requires-Dist: build>=1.2.1
Requires-Dist: wheel>=0.43.0
Requires-Dist: setuptools<75.9,>=62.3.0
Requires-Dist: cookiecutter>2.1.1
Requires-Dist: distro-info>=1.0
Requires-Dist: python-debian>=0.1.49
Requires-Dist: func_timeout
Requires-Dist: regex
Requires-Dist: python_dateutil
Requires-Dist: word2number
Requires-Dist: Pebble
Requires-Dist: timeout-decorator
Requires-Dist: prettytable
Requires-Dist: pytest
Requires-Dist: ipython
Requires-Dist: jupyter-book
Requires-Dist: sphinx
Requires-Dist: sphinx-nefertiti
Requires-Dist: black==25.1.0
Requires-Dist: isort==5.13.2
Requires-Dist: clang-format==19.1.7
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: black==25.1.0; extra == "dev"
Requires-Dist: isort==5.13.2; extra == "dev"
Requires-Dist: clang-format==19.1.7; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx; extra == "docs"
Requires-Dist: sphinx-nefertiti; extra == "docs"
Requires-Dist: jupyter-book; extra == "docs"
Dynamic: license-file
<h1 align="center">
<em>AReaL</em>: Ant Reasoning Reinforcement Learning for LLMs
</h1>
<p align="center">
| <a href="https://arxiv.org/pdf/2505.24298"><b>Paper</b></a> | <a href="https://inclusionai.github.io/AReaL/"><b>Documentation</b></a> | <a href="https://deepwiki.com/inclusionAI/AReaL"><b>Ask DeepWiki</b></a> | <a href="https://huggingface.co/collections/inclusionAI/areal-boba-2-683f0e819ccb7bb2e1b2f2d5"><b>🤗 Models & Data</b></a> |
<a href="./assets/wechat_qrcode.png" target="_blank"><b>WeChat Group</b></a> |
</p>
<img align="right" alt="ReaL" src="/assets/logo.png" width="20%">
AReaL (Ant Reasoning RL) is an open-source **fully asynchronous reinforcement learning training system** for large reasoning models developed at **the RL Lab, Ant Research**. Built upon the open-source project [RealHF](https://github.com/openpsi-project/ReaLHF), we are fully committed to open-source by providing training details, data, and infrastructure required to reproduce results along with the model itself. AReaL aims to help everyone build their own AI agents easily and affordably. Our team loves milk tea because it's delicious, customizable, and affordable. We hope you enjoy our project just like how you enjoy real-world milk tea (cheers).
**AReaL Highlights**
+ 🔥 <span style="color: red; font-weight: bold;">**[NEW] Asynchronous RL:**</span> With algorithm-system co-design, AReaL supports fully asynchronous RL for **the fastest training**! Experimental support for multi-turn agentic RL is also provided.
+ 🛠️ **Open & Reproducible**: We continuously release _all code, datasets, and training recipes_ for RL training of LLMs.
+ 🚀 **Scalability**: AReaL can seamlessly adapt to different computational resource settings, ranging from a single node to 1K GPUs.
+ 🔪 **Cutting-Edge Performance:** AReaL can produce models with cutting-edge reasoning capabilities in math and coding. We are also actively working on agentic tasks.
## News
**[2025/06/03] (v0.3, boba²)** We release **boba²** (double-boba) for fully asynchronous RL training, which achieves a **2.77x speedup while obtaining on-par or even better training performance** compared to synchronous systems. Moreover, asynchronous RL makes it extremely easy to set up multi-turn agentic RL training! Check out [our v0.3 overview blog](/blog/AReaL_v0_3.md) and the [research paper](https://arxiv.org/pdf/2505.24298).
**[2025/03/31] (v0.2, boba)** Here comes our next milestone release - boba! Please call it A-ReaL-boba! This release includes much faster training with SGLang support and SOTA 7B and 32B models on math reasoning. Check our [v0.2 technical blog](/blog/AReaL_v0_2.md).
**[2025/02/24] (v0.1)** Our initial release includes reproducible results for 1.5B and 7B LRMs. Check our [v0.1 technical blog](/blog/AReaL_v0_1.md).
## Release Highlights
In our AReaL-boba² (A-ReaL-double-boba) release, we highlight the top 3 most important features:
+ A fully asynchronous RL training pipeline with **system and RL algorithm co-design**, achieving over 2.77x speedup without any performance drop. Check the [benchmark scripts and instructions here](https://github.com/inclusionAI/AReaL/tree/main/benchmark/verl_v0_3_0_post1_76084d3).
+ SOTA coding models, i.e., a 14B model with a **69.1 score on LCB-v5**. To reproduce, check the [configs](https://github.com/inclusionAI/AReaL/tree/main/examples/configs/v0.3-qwen3-code) and [instructions](https://inclusionai.github.io/AReaL/references/reproduce.html).
+ Experimental support for **multi-turn** agentic RL training. Check our [complete example](https://inclusionai.github.io/AReaL/customization/agent.html).
For the complete system design and more training details, please check [our v0.3 blog](/blog/AReaL_v0_3.md) and our [research paper](https://arxiv.org/pdf/2505.24298).
**Jump to the [quickstart section](https://github.com/inclusionAI/AReaL?tab=readme-ov-file#getting-started) if you want to quickly run an experiment and get your hands dirty!** 😈
### Overview of Asynchronous RL Training
During the synchronous RL training process, a generation step must wait until the longest sequence completes within the batch of LLM outputs. Due to the varying output lengths for LRMs, a synchronous RL system suffers from massive GPU idle time, leading to training inefficiency. Some recent works ([DeepCoder](https://pretty-radio-b75.notion.site/DeepCoder-A-Fully-Open-Source-14B-Coder-at-O3-mini-Level-1cf81902c14680b3bee5eb349a512a51), [Intellect](https://www.primeintellect.ai/blog/intellect-2)) propose overlapping a single training step with a single generation step to accelerate training. However, the largest bottleneck remains unchanged: the samples within a batch are still from the same model version, leading to waiting and GPU idle time.
![Synchronous vs One-step Overlap RL](/assets/sync_one_step_gen.png)
*Fig. 1. Left: Execution timeline of synchronous RL training. Right: Execution timeline of a one-step-overlap RL system.*
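To make the idle-time argument concrete, here is a toy back-of-the-envelope calculation (a hypothetical illustration, not AReaL code): under synchronous batching every decoding slot runs for as long as the longest sequence, whereas streaming generation only spends the tokens each sequence actually needs. The sequence lengths below are invented for illustration.

```python
# Toy illustration (not AReaL code): GPU idle time under synchronous batched generation.
# Simplifying assumption: decoding cost is proportional to the number of generated tokens.
output_lengths = [512, 800, 1200, 2100, 4096, 9000, 15000, 30000]  # tokens, hypothetical

useful_work = sum(output_lengths)                      # token-steps actually required
sync_cost = len(output_lengths) * max(output_lengths)  # every slot waits for the longest sequence
idle_fraction = 1 - useful_work / sync_cost

print(f"useful token-steps : {useful_work}")
print(f"synchronous cost   : {sync_cost}")
print(f"idle fraction      : {idle_fraction:.1%}")     # ~74% of decoding capacity wasted in this toy case
```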
AReaL adopts a fully asynchronous RL training framework that completely decouples generation from training. In AReaL, LLM generation runs in a streaming manner, with each rollout worker continuously producing outputs without waiting. Meanwhile, trainer workers perform parallel model updates upon receiving training batches.
![Asynchronous RL Training](/assets/async_timeline.png)
*Fig. 2. Execution timeline of our fully asynchronous RL system.*
AReaL follows a system-algorithm co-design principle: on the system side, AReaL efficiently synchronizes model parameters and carefully controls the staleness of each training sample; on the algorithm side, AReaL modifies the PPO objective to keep asynchronous RL training stable.
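As a rough illustration of the algorithm side, the snippet below sketches a decoupled PPO-style token loss in the spirit described above: the clipping ratio is computed against a proximal policy rather than the (possibly stale) behavior policy that generated the rollout, and an importance weight corrects for the gap between the two. This is a simplified sketch of the idea, not the exact implementation in this repository; tensor names and the clipping constant are illustrative.

```python
import torch

def decoupled_ppo_loss(
    logp_new: torch.Tensor,     # log-probs under the policy being trained
    logp_prox: torch.Tensor,    # log-probs under the proximal (recent) policy
    logp_behave: torch.Tensor,  # log-probs under the behavior policy that produced the rollout
    advantages: torch.Tensor,
    clip_eps: float = 0.2,
) -> torch.Tensor:
    """Simplified per-token decoupled PPO loss (illustrative sketch only)."""
    # Clip with respect to the proximal policy ...
    ratio = torch.exp(logp_new - logp_prox)
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    surrogate = torch.minimum(ratio * advantages, clipped * advantages)
    # ... and reweight by how far the behavior policy has drifted from the proximal one.
    behave_weight = torch.exp(logp_prox - logp_behave).detach()
    return -(behave_weight * surrogate).mean()
```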
We compare the scalability of **asynchronous RL** training with our AReaL-boba² system against **classical synchronous RL** training (using veRL, the fastest open-source synchronous system at the time; main branch as of 05/07/2025) across different model sizes and different numbers of H800 GPUs. AReaL demonstrates markedly better scaling of training throughput. This is also partially because AReaL decouples training from generation, which leads to much less GPU memory fragmentation.
![Scaling Comparison](/assets/async_scaling_vs_verl.png)
*Fig. 3. Scaling trend of asynchronous RL (based on AReaL-boba²) and classical synchronous RL (based on veRL) across model sizes. Dotted lines indicate ideal linear scaling.*
### SOTA Code Generation Model by AReaL-boba²
We use **Qwen3** as our base model. After asynchronous RL training, we achieve SOTA results on the LiveCodeBench, Codeforces, and CodeContests benchmarks.
| **Model (8B)** | **LiveCodeBench v5**<br/>**(2024.10-2025.2)** | **Codeforces** | **CodeContests** |
| :---: | :---: | :---: | :---: |
| Qwen3-8B | 58.8 | 1879/96.7% | 31.4 |
| DeepSeek-R1-0528-Qwen3-8B | 58.4 | 1945/97.3% | 31.0 |
| [🤗 AReaL-boba²-8B-Open](https://huggingface.co/inclusionAI/AReaL-boba-2-8B-subset) | 62.0 | 1933/97.2% | **41.4** |
| [🤗 AReaL-boba²-8B](https://huggingface.co/inclusionAI/AReaL-boba-2-8B) | **63.0** | **1962/97.5%** | 40.8 |
| **Model (14B)** | **LiveCodeBench v5**<br/>**(2024.10-2025.2)** | **Codeforces** | **CodeContests** |
| :---: | :---: | :---: | :---: |
| Qwen3-14B | 65.4 | 1978/97.7% | 38.3 |
| DeepCoder-14B-Preview | 60.6 | 1936/95.3% | 40.1 |
| [🤗 AReaL-boba²-14B-Open](https://huggingface.co/inclusionAI/AReaL-boba-2-14B-subset) | 67.3 | 1990/97.8% | **46.2** |
| [🤗 AReaL-boba²-14B](https://huggingface.co/inclusionAI/AReaL-boba-2-14B) | **69.1** | **2044/98.2%** | 46.1 |
| **Larger Models** | **LiveCodeBench v5**<br/>**(2024.10-2025.2)** | **Codeforces** | **CodeContests** |
| :---: | :---: | :---: | :---: |
| Qwen3-235B | 70.7 | 2056 | - |
| DeepSeek-R1 | 64.3 | 2029 | - |
| OpenAI-o3-mini (Medium) | 66.3 | 2036 | - |
*Table 1: Coding task performance comparison. AReaL-boba²-8B/14B-Open denote models trained only on open-source data. AReaL-boba²-8B/14B are trained with a small amount of additional internal data and achieve SOTA performance on LiveCodeBench, Codeforces, and CodeContests.*
We highlight the [tutorials](https://inclusionai.github.io/AReaL/customization/dataset.html) and [code walkthroughs](https://inclusionai.github.io/AReaL/developer/overview.html) covering the following key features of asynchronous training:
+ [Streaming generation and reward computation](https://inclusionai.github.io/AReaL/developer/rollout/rollout_worker.html)
+ [Interruptible rollout](https://inclusionai.github.io/AReaL/developer/rollout/gserver.html)
+ [Data staleness control with the rollout controller](https://inclusionai.github.io/AReaL/developer/rollout/gserver.html) (a minimal staleness-filter sketch follows this list)
+ [The adoption of decoupled PPO loss](https://inclusionai.github.io/AReaL/customization/algorithm.html)
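To give a flavor of the staleness control mentioned above, the sketch below shows how a rollout controller could bound the version gap between the weights that generated a sample and the weights currently being trained. Class and field names (`StalenessFilter`, `policy_version`, `max_staleness`) are hypothetical and do not correspond to the actual rollout-controller API; see the linked walkthrough for the real design.

```python
from dataclasses import dataclass, field

@dataclass
class StalenessFilter:
    """Hypothetical sketch of data staleness control; not the actual rollout-controller API."""
    max_staleness: int       # max allowed (train_version - rollout_version)
    train_version: int = 0
    buffer: list = field(default_factory=list)

    def add_rollout(self, sample: dict) -> bool:
        """Accept a rollout sample only if it is fresh enough to train on."""
        if self.train_version - sample["policy_version"] > self.max_staleness:
            return False  # generated by weights that are too old
        self.buffer.append(sample)
        return True

    def on_weights_updated(self) -> None:
        """Called after each parameter update; evict samples that became too stale."""
        self.train_version += 1
        self.buffer = [
            s for s in self.buffer
            if self.train_version - s["policy_version"] <= self.max_staleness
        ]
```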
### RL Training for Multi-turn Agent
AReaL-boba² lets you independently customize the [dataset](https://inclusionai.github.io/AReaL/customization/dataset.html), [rollout behavior](https://inclusionai.github.io/AReaL/customization/agent.html), and [training algorithm](https://inclusionai.github.io/AReaL/customization/algorithm.html) without modifying the heavyweight system-level code.
In particular, we provide a simple example of developing a multi-turn math agent for RL training; a minimal agent sketch is shown below. Follow the [step-by-step guide](https://inclusionai.github.io/AReaL/customization/agent.html) if you want to implement your own agentic RL project.
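As a taste of what the agent customization looks like, here is a minimal, self-contained sketch of a multi-turn math agent following the same act/reset pattern. It deliberately avoids the real `arealite` interfaces: the class name, prompt handling, and the `_generate` placeholder are all hypothetical, so please follow the step-by-step guide above for an actual implementation.

```python
import asyncio

class MultiTurnMathAgent:
    """Toy multi-turn math agent sketch (illustrative only; not the arealite Agent API)."""

    def __init__(self) -> None:
        self.history: list[str] = []

    async def aact(self, observation: str) -> str:
        # Keep a running dialogue so each turn can condition on earlier attempts and feedback.
        self.history.append(f"User: {observation}")
        prompt = "\n".join(self.history) + "\nAssistant:"
        answer = await self._generate(prompt)
        self.history.append(f"Assistant: {answer}")
        return answer

    async def areset(self) -> None:
        # Clear the agent's memory between episodes.
        self.history.clear()

    async def _generate(self, prompt: str) -> str:
        # Placeholder: a real agent would query the inference server via an LLM client here.
        await asyncio.sleep(0)
        return "<think>...</think> The answer is 42."
```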
## Getting Started
Obtain the training data:
- [Math](https://huggingface.co/datasets/inclusionAI/AReaL-boba-Data)
- [Code](https://huggingface.co/datasets/inclusionAI/AReaL-boba-2-RL-Code)
For the code training data, a simple preprocessing script is provided at `examples/data_preprocess/preprocess_training_data.py`:
```bash
python3 preprocess_training_data.py --data_path $original_data_path --output_path $training_data_path
```
Train Qwen3-1.7B locally (remember to modify `dataset.path` in the script below):
```bash
bash examples/run_async_ppo.sh
```
Evaluation:
```bash
cd evaluation
# Evaluate on the math benchmarks
python eval_and_aggregate.py \
    --model_path ${MODEL_PATH} \
    --output_path ${OUTPUT_PATH} \
    --data_names aime24,aime25 \
    --max_gen_tokens 32768
# Evaluate on the code benchmarks
python eval_and_aggregate.py \
    --model_path ${MODEL_PATH} \
    --output_path ${OUTPUT_PATH} \
    --data_names codeforces,lcb_v5 \
    --prompt_type qwen3-think-pure \
    --temperature 1.0
```
## Resources
+ [Documentation](https://inclusionai.github.io/AReaL/)
+ [Contributing](https://inclusionai.github.io/AReaL/contrib.html)
### Quickstart
+ [Installation](https://inclusionai.github.io/AReaL/tutorial/installation.html)
+ [Example: Improving the math capability of Qwen3 with PPO](https://inclusionai.github.io/AReaL/tutorial/quickstart.html)
### Benchmark and Reproduction
+ **Reproduce boba² Code Models**
- 🤗 **Model weights**: [8B-code](https://huggingface.co/inclusionAI/AReaL-boba-2-8B), [14B-code](https://huggingface.co/inclusionAI/AReaL-boba-2-14B), [8B-code-open](https://huggingface.co/inclusionAI/AReaL-boba-2-8B-subset), [14B-code-open](https://huggingface.co/inclusionAI/AReaL-boba-2-14B-subset)
- [Evaluation Guide](https://inclusionai.github.io/AReaL/tutorial/eval.html)
- [Training configs](https://github.com/inclusionAI/AReaL/tree/main/examples/configs/v0.3-qwen3-code) and [instructions](https://inclusionai.github.io/AReaL/references/reproduce.html)
+ [Scripts for Benchmark Training Throughput](https://github.com/inclusionAI/AReaL/tree/main/benchmark/verl_v0_3_0_post1_76084d3)
### Customization Guide
- [Use your own dataset](https://inclusionai.github.io/AReaL/customization/dataset.html)
- [Modifying the reward function and rollout behavior (multi-turn agentic RL)](https://inclusionai.github.io/AReaL/customization/agent.html)
- [Modifying PPO to GRPO](https://inclusionai.github.io/AReaL/customization/algorithm.html#grouped-advantage-normalization)
- [Developing the decoupled PPO loss](https://inclusionai.github.io/AReaL/customization/algorithm.html#the-decoupled-ppo-loss)
### System Code Walkthrough
+ [Trainer](https://inclusionai.github.io/AReaL/developer/trainer/model_worker.html)
+ [Model Backend and Algorithm Interface](https://inclusionai.github.io/AReaL/developer/trainer/algo_interface.html)
+ [Rollout Controller](https://inclusionai.github.io/AReaL/developer/rollout/gserver.html)
+ [Streaming generation and reward computation](https://inclusionai.github.io/AReaL/developer/rollout/rollout_worker.html)
## Future Plan
AReaL is under active development. We plan to have minor releases weekly and major releases monthly. Community engagement and contributions are extremely welcome. We are also **hiring interns and full-time employees** with open positions in both the US and China.
For the research and development plan already in place, please see the following list:
### System Development
- [x] Support for SGLang
- [x] RL training with coding problems
- [x] Asynchronous generation and RL training
- [ ] Optimizations for distributed training: expert parallelism for MoE and zero-bubble pipelining
- [ ] RL for vision-language models (VLM)
- [x] Multi-turn agentic RL
- [ ] Function calling and tool use
### Algorithm Development
- [x] RL training recipes for 1.5B and 7B models
- [x] A complete RL training recipe for 32B models
- [ ] Sample-efficient multi-task RL algorithms
- [ ] Agentic capabilities with end-to-end RL
- [ ] Stable RL training for larger MoE models
## Acknowledgement
We would like to note that major contributors are from the RL Lab at Ant Research and the Institute for Interdisciplinary Information Sciences, Tsinghua University.
Our team has also received invaluable assistance from the Data Intelligence Lab at Ant Research for data support and from the Super Computing Technology (SCT) team at Ant Group, particularly in the realm of large-scale cluster operations and maintenance.
We also appreciate all the pioneering works from the community, particularly the [ReaLHF](https://github.com/openpsi-project/ReaLHF) project from OpenPsi Inc. and other projects, including but not limited to [DeepScaleR](https://github.com/agentica-project/deepscaler), [Open-Reasoner-Zero](https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero/tree/main), [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF), [VeRL](https://github.com/volcengine/verl), [SGLang](https://github.com/sgl-project/sglang), [QwQ](https://github.com/QwenLM/QwQ), [Light-R1](https://github.com/Qihoo360/Light-R1) and [DAPO](https://github.com/BytedTsinghua-SIA/DAPO).
## Citation
```bibtex
@inproceedings{mei2025real,
author = {Mei, Zhiyu and Fu, Wei and Li, Kaiwei and Wang, Guangju and Zhang, Huanchen and Wu, Yi},
title = {ReaL: Efficient RLHF Training of Large Language Models with Parameter Reallocation},
booktitle = {Proceedings of the Eighth Conference on Machine Learning and Systems,
MLSys 2025, Santa Clara, CA, USA, May 12-15, 2025},
publisher = {mlsys.org},
year = {2025},
}
```
```bibtex
@misc{fu2025areal,
title={AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning},
author={Wei Fu and Jiaxuan Gao and Xujie Shen and Chen Zhu and Zhiyu Mei and Chuyi He and Shusheng Xu and Guo Wei and Jun Mei and Jiashu Wang and Tongkai Yang and Binhang Yuan and Yi Wu},
year={2025},
eprint={2505.24298},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2505.24298},
}
```

View File

@ -1,349 +0,0 @@
LICENSE
MANIFEST.in
README.md
pyproject.toml
setup.py
arealite/ppo_functional.py
arealite/api/cli_args.py
arealite/api/engine_api.py
arealite/api/env_api.py
arealite/api/io_struct.py
arealite/api/llm_client_api.py
arealite/api/llm_server_api.py
arealite/api/reward_api.py
arealite/api/rollout_api.py
arealite/api/trainer_api.py
arealite/api/vlm_client_api.py
arealite/api/workflow_api.py
arealite/cli/launch_server.py
arealite/cli/launch_trainer.py
arealite/dataset/__init__.py
arealite/engine/__init__.py
arealite/engine/constant.py
arealite/engine/fsdp_engine.py
arealite/engine/hf_engine.py
arealite/engine/sglang_engine.py
arealite/engine/sglang_remote.py
arealite/engine/vl_fsdp_engine.py
arealite/engine/sft/lm_engine.py
arealite/env/__init__.py
arealite/impl/engine/fsdp_wrapper.py
arealite/impl/engine/hf_wrapper.py
arealite/impl/rlvr/rlvr_collector.py
arealite/impl/rlvr/vl_rlvr_collector.py
arealite/impl/rlvr/rewards/areal_code.py
arealite/impl/rlvr/rewards/areal_math.py
arealite/impl/rlvr/rewards/clevr_count_70k.py
arealite/impl/rlvr/rewards/gsm8k.py
arealite/impl/trainer/grpo.py
arealite/impl/trainer/sft.py
arealite/impl/trainer/vl_grpo.py
arealite/impl/trainer/vl_sft.py
arealite/launcher/with_ray.py
arealite/launcher/with_scheduler.py
arealite/reward/__init__.py
arealite/system/rollout_controller.py
arealite/system/rollout_worker.py
arealite/system/sglang_client.py
arealite/system/sglang_server.py
arealite/system/vl_sglang_client.py
arealite/tests/test_engine.py
arealite/tests/test_fsdp_engine.py
arealite/tests/test_grpo.py
arealite/tests/test_rlvr_workflow.py
arealite/tests/test_rollout.py
arealite/tests/test_rollout_controller.py
arealite/tests/test_sft.py
arealite/tests/test_sglang_client.py
arealite/tests/test_sglang_engine.py
arealite/tests/test_utils.py
arealite/tests/test_vlm_grpo.py
arealite/tests/test_vlm_sft.py
arealite/tests/test_wrapper.py
arealite/tests/utils.py
arealite/utils/__init__.py
arealite/utils/data.py
arealite/utils/evaluator.py
arealite/utils/fs.py
arealite/utils/fsdp.py
arealite/utils/functional.py
arealite/utils/padding.py
arealite/utils/save_load.py
arealite/utils/saver.py
arealite/utils/stats_logger.py
arealite/utils/wrapper.py
arealite/workflow/rlvr.py
benchmark/verl_v0_3_0_post1_76084d3/build_cmd.py
csrc/cugae/gae.cu
csrc/interval_op/interval_op.cpp
csrc/interval_op/interval_op.cu
evaluation/aggregate_acc_from_generated.py
evaluation/cf_elo_caculator.py
evaluation/code_eval.py
evaluation/data_loader.py
evaluation/eval_and_aggregate.py
evaluation/evaluate.py
evaluation/examples.py
evaluation/grader.py
evaluation/math_eval.py
evaluation/math_utils.py
evaluation/model_utils.py
evaluation/parser.py
evaluation/python_executor.py
evaluation/rm_maj_eval.py
evaluation/trajectory.py
evaluation/utils.py
evaluation/code_verifier/local_verify.py
evaluation/code_verifier/testing_util.py
evaluation/latex2sympy/__init__.py
evaluation/latex2sympy/asciimath_printer.py
evaluation/latex2sympy/latex2sympy2.py
evaluation/latex2sympy/setup.py
evaluation/latex2sympy/gen/PSLexer.py
evaluation/latex2sympy/gen/PSListener.py
evaluation/latex2sympy/gen/PSParser.py
evaluation/latex2sympy/gen/__init__.py
evaluation/latex2sympy/sandbox/linalg_equations.py
evaluation/latex2sympy/sandbox/linalg_span.py
evaluation/latex2sympy/sandbox/matrix.py
evaluation/latex2sympy/sandbox/matrix_placeholders.py
evaluation/latex2sympy/sandbox/sandbox.py
evaluation/latex2sympy/sandbox/sandbox_equality.py
evaluation/latex2sympy/sandbox/sectan.py
evaluation/latex2sympy/sandbox/vector.py
evaluation/latex2sympy/tests/__init__.py
evaluation/latex2sympy/tests/abs_test.py
evaluation/latex2sympy/tests/all_bad_test.py
evaluation/latex2sympy/tests/all_good_test.py
evaluation/latex2sympy/tests/atom_expr_test.py
evaluation/latex2sympy/tests/binomial_test.py
evaluation/latex2sympy/tests/ceil_test.py
evaluation/latex2sympy/tests/complex_test.py
evaluation/latex2sympy/tests/context.py
evaluation/latex2sympy/tests/exp_test.py
evaluation/latex2sympy/tests/floor_test.py
evaluation/latex2sympy/tests/gcd_test.py
evaluation/latex2sympy/tests/greek_test.py
evaluation/latex2sympy/tests/grouping_test.py
evaluation/latex2sympy/tests/lcm_test.py
evaluation/latex2sympy/tests/left_right_cdot_test.py
evaluation/latex2sympy/tests/linalg_test.py
evaluation/latex2sympy/tests/max_test.py
evaluation/latex2sympy/tests/min_test.py
evaluation/latex2sympy/tests/mod_test.py
evaluation/latex2sympy/tests/overline_test.py
evaluation/latex2sympy/tests/pi_test.py
evaluation/latex2sympy/tests/trig_test.py
evaluation/latex2sympy/tests/variable_test.py
examples/arealite/gsm8k_sft.py
examples/arealite/dataset/clevr_count_70k.py
examples/arealite/dataset/gsm8k.py
examples/data_preprocess/codeforce_process.py
examples/data_preprocess/math_code_process.py
examples/data_preprocess/math_process.py
examples/data_preprocess/preprocess_training_data.py
examples/env/setup_env_and_start_train.py
examples/env/validate_installation.py
functioncall/__init__.py
functioncall/base/__init__.py
functioncall/base/call.py
functioncall/base/utils.py
functioncall/code/__init__.py
functioncall/code/local_verify.py
functioncall/code/verify.py
functioncall/code/function/handler.py
functioncall/code/function/testing_util.py
functioncall/math/__init__.py
functioncall/math/verify.py
functioncall/math/function/grader.py
functioncall/math/function/handler.py
functioncall/math/function/parser.py
functioncall/test/performance_eval.py
realhf/__init__.py
realhf/utils.py
realhf/version.py
realhf.egg-info/PKG-INFO
realhf.egg-info/SOURCES.txt
realhf.egg-info/dependency_links.txt
realhf.egg-info/requires.txt
realhf.egg-info/top_level.txt
realhf/api/cli_args.py
realhf/api/core/agent_api.py
realhf/api/core/config.py
realhf/api/core/data_api.py
realhf/api/core/dfg.py
realhf/api/core/env_api.py
realhf/api/core/model_api.py
realhf/api/core/system_api.py
realhf/api/from_hf/__init__.py
realhf/api/from_hf/gemma.py
realhf/api/from_hf/gpt2.py
realhf/api/from_hf/llama.py
realhf/api/from_hf/mistral.py
realhf/api/from_hf/mixtral.py
realhf/api/from_hf/qwen2.py
realhf/api/from_hf/qwen3.py
realhf/api/quickstart/__init__.py
realhf/api/quickstart/device_mesh.py
realhf/api/quickstart/entrypoint.py
realhf/api/quickstart/search.py
realhf/apps/__init__.py
realhf/apps/main.py
realhf/apps/quickstart.py
realhf/apps/remote.py
realhf/base/__init__.py
realhf/base/cluster.py
realhf/base/constants.py
realhf/base/datapack.py
realhf/base/gpu_utils.py
realhf/base/importing.py
realhf/base/logging.py
realhf/base/monitor.py
realhf/base/name_resolve.py
realhf/base/names.py
realhf/base/network.py
realhf/base/numpy_utils.py
realhf/base/pkg_version.py
realhf/base/prologue.py
realhf/base/ray_utils.py
realhf/base/recover.py
realhf/base/saveload_utils.py
realhf/base/security.py
realhf/base/seeding.py
realhf/base/slurm_utils.py
realhf/base/stats_tracker.py
realhf/base/testing.py
realhf/base/timeutil.py
realhf/base/topology.py
realhf/experiments/async_exp/async_ppo_math_exp.py
realhf/experiments/async_exp/async_rl_exp.py
realhf/experiments/common/check.py
realhf/experiments/common/common.py
realhf/experiments/common/math_code_eval_exp.py
realhf/experiments/common/null_exp.py
realhf/experiments/common/ppo_math_exp.py
realhf/experiments/common/sft_exp.py
realhf/experiments/common/utils.py
realhf/impl/agent/__init__.py
realhf/impl/agent/math_multi_turn_agent.py
realhf/impl/agent/math_single_step_agent.py
realhf/impl/agent/null_agent.py
realhf/impl/dataset/__init__.py
realhf/impl/dataset/math_code_dataset.py
realhf/impl/dataset/math_parser.py
realhf/impl/dataset/prompt_answer_dataset.py
realhf/impl/dataset/prompt_dataset.py
realhf/impl/dataset/rw_paired_dataset.py
realhf/impl/environment/__init__.py
realhf/impl/environment/math_code_single_step_env.py
realhf/impl/model/__init__.py
realhf/impl/model/backend/inference.py
realhf/impl/model/backend/megatron.py
realhf/impl/model/backend/mock_train.py
realhf/impl/model/backend/pipe_runner.py
realhf/impl/model/backend/sglang.py
realhf/impl/model/backend/vllm.py
realhf/impl/model/backend/thirdparty/megatron/__init__.py
realhf/impl/model/backend/thirdparty/megatron/v0_6_0/lr_schduler.py
realhf/impl/model/backend/thirdparty/vllm/__init__.py
realhf/impl/model/backend/thirdparty/vllm/context.py
realhf/impl/model/backend/thirdparty/vllm/custom_cache_manager.py
realhf/impl/model/backend/thirdparty/vllm/engine.py
realhf/impl/model/backend/thirdparty/vllm/executor.py
realhf/impl/model/comm/global_comm.py
realhf/impl/model/comm/param_realloc.py
realhf/impl/model/conversion/hf_registry.py
realhf/impl/model/interface/fused_interface.py
realhf/impl/model/interface/math_rw_interface.py
realhf/impl/model/interface/ppo_interface.py
realhf/impl/model/interface/sft_interface.py
realhf/impl/model/modules/__init__.py
realhf/impl/model/modules/activations.py
realhf/impl/model/modules/attn.py
realhf/impl/model/modules/embedding.py
realhf/impl/model/modules/mlp.py
realhf/impl/model/modules/rms.py
realhf/impl/model/modules/rotary.py
realhf/impl/model/modules/moe/__init__.py
realhf/impl/model/modules/moe/experts.py
realhf/impl/model/modules/moe/layer.py
realhf/impl/model/modules/moe/router.py
realhf/impl/model/modules/moe/token_dispatcher.py
realhf/impl/model/nn/flatten_param.py
realhf/impl/model/nn/real_llm_api.py
realhf/impl/model/nn/real_llm_base.py
realhf/impl/model/nn/real_llm_generate.py
realhf/impl/model/nn/real_llm_parallel.py
realhf/impl/model/parallelism/pipeline_parallel/instruction.py
realhf/impl/model/parallelism/pipeline_parallel/p2p.py
realhf/impl/model/parallelism/pipeline_parallel/static_schedule.py
realhf/impl/model/parallelism/pipeline_parallel/tensor_storage.py
realhf/impl/model/parallelism/tensor_parallel/mappings.py
realhf/impl/model/parallelism/tensor_parallel/modules.py
realhf/impl/model/parallelism/tensor_parallel/utils.py
realhf/impl/model/utils/cuda_graph.py
realhf/impl/model/utils/dpo_functional.py
realhf/impl/model/utils/functional.py
realhf/impl/model/utils/logits_warper.py
realhf/impl/model/utils/moe.py
realhf/impl/model/utils/padding.py
realhf/impl/model/utils/ppo_functional.py
realhf/impl/model/utils/random.py
realhf/scheduler/client.py
realhf/scheduler/evaluator.py
realhf/scheduler/local/client.py
realhf/scheduler/slurm/client.py
realhf/scheduler/slurm/utils.py
realhf/system/__init__.py
realhf/system/buffer.py
realhf/system/controller.py
realhf/system/data_manager.py
realhf/system/flops_counter.py
realhf/system/function_executor.py
realhf/system/generation_server.py
realhf/system/gserver_manager.py
realhf/system/master_worker.py
realhf/system/model_function_call.py
realhf/system/model_worker.py
realhf/system/partial_rollout.py
realhf/system/push_pull_stream.py
realhf/system/redistributor.py
realhf/system/request_reply_stream.py
realhf/system/rollout_worker.py
realhf/system/stream_dataset.py
realhf/system/worker_base.py
realhf/system/worker_control.py
tests/__init__.py
tests/fixtures.py
tests/agent/test_math_single_step_agent.py
tests/comm/test_data_transfer.py
tests/comm/test_param_realloc.py
tests/cpp_extensions/test_cugae.py
tests/cpp_extensions/test_grouped_gemm.py
tests/cpp_extensions/test_interval_ops.py
tests/data/test_dfg.py
tests/data/test_dual_clip.py
tests/data/test_epoch_counter.py
tests/data/test_load_data.py
tests/data/test_sequence_gather_split.py
tests/data/test_stats_tracker.py
tests/distributed/test_find_port.py
tests/distributed/test_name_resolve.py
tests/experiments/test_buffer_recover.py
tests/experiments/test_math_ppo.py
tests/experiments/test_sft.py
tests/experiments/utils.py
tests/interfaces/test_multi_task_reward.py
tests/legacy/test_sglang_tp.py
tests/legacy/test_vllm_tp.py
tests/model/test_cpu_inference.py
tests/model/test_distributed_load_hf.py
tests/reward/test_math_reward.py
tests/system/test_gserver_manager.py
tests/system/test_partial_rollout.py
tests/system/test_push_pull_stream.py
tests/system/test_stream_dataset.py
training/main_async_ppo.py
training/main_sft.py
training/main_sync_ppo.py
training/utils.py

View File

@ -1 +0,0 @@

View File

@ -1,83 +0,0 @@
torch>2.0.0
huggingface_hub
datasets
accelerate
transformers==4.51.1
numpy<2.0.0
scipy
pandas
matplotlib
seaborn
h5py
nltk
sentencepiece
einops
tqdm
rich
orjson>=3.10.16
pydantic
PyYAML
hydra-core==1.4.0.dev1
packaging
tabulate
gymnasium>=1.1.1
torchdata
autoflake
tensordict
wandb
tensorboardx
colorama
colorlog
psutil
pynvml
swanlab[dashboard]
ninja
numba
blosc
pybind11>=2.10.0
networkx==3.3
aiofiles
aiohttp>=3.11.10
httpx>=0.28.1
pyzmq
paramiko
etcd3
protobuf<3.21
ray
redis
fastapi>=0.115.12
uvicorn>=0.34.2
uvloop>=0.21.0
flask
build>=1.2.1
wheel>=0.43.0
setuptools<75.9,>=62.3.0
cookiecutter>2.1.1
distro-info>=1.0
python-debian>=0.1.49
func_timeout
regex
python_dateutil
word2number
Pebble
timeout-decorator
prettytable
pytest
ipython
jupyter-book
sphinx
sphinx-nefertiti
black==25.1.0
isort==5.13.2
clang-format==19.1.7
[dev]
pytest
black==25.1.0
isort==5.13.2
clang-format==19.1.7
[docs]
sphinx
sphinx-nefertiti
jupyter-book

View File

@ -1,14 +0,0 @@
arealite
assets
benchmark
blog
ci
csrc
docs
evaluation
examples
functioncall
patch
realhf
tests
training