mirror of https://github.com/inclusionAI/AReaL
0721_merge7
This commit is contained in:
parent
c29561498e
commit
f451dbd692
|
@ -1,148 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import abc
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, SupportsFloat
|
||||
|
||||
from gymnasium import Env
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.utils import seeding
|
||||
|
||||
from arealite.api.cli_args import (
|
||||
GenerationHyperparameters,
|
||||
RolloutCollectorConfig,
|
||||
TrainingArgs,
|
||||
)
|
||||
from arealite.api.io_struct import AgentInferInput, AgentInferOutput, Trajectory
|
||||
from arealite.api.llm_client_api import LLMClient
|
||||
|
||||
|
||||
class Agent(abc.ABC):
|
||||
def __init__(self, args: TrainingArgs):
|
||||
self.args = args
|
||||
|
||||
async def aact(self, inp: AgentInferInput) -> AgentInferOutput:
|
||||
"""Async version of act. Given an observation, return an action and data used for RL training."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Async version of reset. Resets the agent's memory."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# Re-export the gymnasium environment class
|
||||
class Environment(abc.ABC, Env):
|
||||
def __init__(self, args: TrainingArgs):
|
||||
self.args = args
|
||||
|
||||
@abc.abstractmethod
|
||||
def step(
|
||||
self, action: ActType
|
||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
|
||||
# Initialize the RNG if the seed is manually passed
|
||||
if seed is not None:
|
||||
self._np_random, self._np_random_seed = seeding.np_random(seed)
|
||||
|
||||
|
||||
class RolloutCollector(abc.ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: TrainingArgs,
|
||||
config: RolloutCollectorConfig,
|
||||
agent: Agent | None = None,
|
||||
env: Environment | None = None,
|
||||
reward_func: Callable | None = None,
|
||||
):
|
||||
self.args = args
|
||||
self.config = config
|
||||
|
||||
# Used in agentic scenarios
|
||||
self.agent = agent
|
||||
self.env = env
|
||||
|
||||
# Used in RLVR
|
||||
self.reward_func = reward_func
|
||||
|
||||
async def arun_episode(
|
||||
self,
|
||||
llm_client: LLMClient,
|
||||
gconfig: GenerationHyperparameters,
|
||||
env_option: Optional[Any] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> Trajectory:
|
||||
"""Async version of run_episode. Run a single episode and return the trajectory."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutCollectorFactory:
|
||||
args: TrainingArgs
|
||||
|
||||
def make_collector(self, config: RolloutCollectorConfig) -> RolloutCollector:
|
||||
if config.type == "rlvr":
|
||||
from arealite.impl.rlvr.rlvr_collector import RlvrCollector
|
||||
|
||||
rlvr_config = config.rlvr
|
||||
assert rlvr_config is not None
|
||||
if rlvr_config.reward_type == "areal-math":
|
||||
from arealite.impl.rlvr.rewards.areal_math import math_reward
|
||||
|
||||
reward_fn = functools.partial(
|
||||
math_reward, dataset_path=rlvr_config.solution_path
|
||||
)
|
||||
elif rlvr_config.reward_type == "areal-code":
|
||||
from arealite.impl.rlvr.rewards.areal_code import code_reward
|
||||
|
||||
reward_fn = functools.partial(
|
||||
code_reward, dataset_path=rlvr_config.solution_path
|
||||
)
|
||||
elif rlvr_config.reward_type == "gsm8k":
|
||||
from arealite.impl.rlvr.rewards.gsm8k import (
|
||||
gsm8k_reward_fn as reward_fn,
|
||||
)
|
||||
elif rlvr_config.reward_type == "clevr_count_70k":
|
||||
from arealite.impl.rlvr.rewards.clevr_count_70k import (
|
||||
clevr_count_70k_reward_fn as reward_fn,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unknown reward type: {rlvr_config.reward_type}"
|
||||
)
|
||||
|
||||
return RlvrCollector(
|
||||
self.args,
|
||||
config=config,
|
||||
reward_fn=reward_fn,
|
||||
)
|
||||
if config.type == "math_code_single_step":
|
||||
from arealite.impl.agentic.math_code_single_step import (
|
||||
MathCodeAgent,
|
||||
MathCodeSingleStepCollector,
|
||||
MathCodeSingleStepEnv,
|
||||
)
|
||||
|
||||
agent = MathCodeAgent(self.args)
|
||||
env = MathCodeSingleStepEnv(
|
||||
self.args,
|
||||
solution_path=config.math_code_single_step.solution_path,
|
||||
)
|
||||
|
||||
return MathCodeSingleStepCollector(
|
||||
self.args,
|
||||
config=config,
|
||||
agent=agent,
|
||||
env=env,
|
||||
)
|
||||
raise NotImplementedError(f"Unknown agent type: {config.type}")
|
|
@ -1,156 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import abc
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch.distributed as dist
|
||||
from datasets import Dataset
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import TrainerConfig, TrainingArgs
|
||||
from realhf.base import constants
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arealite.system.rollout_controller import RolloutController
|
||||
|
||||
# 4. use huggingface.trainerstate
|
||||
# TODO: how to do checkpointing?
|
||||
|
||||
# follow the signature of transformers.Trainer if possible
|
||||
|
||||
|
||||
|
||||
|
||||
class Trainer(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
args: TrainingArgs,
|
||||
trainer_config: TrainerConfig,
|
||||
train_dataset: Dataset,
|
||||
valid_dataset: Optional[Dataset] = None,
|
||||
rollout_controller: Optional["RolloutController"] = None,
|
||||
):
|
||||
self.args = args
|
||||
self.trainer_config = trainer_config
|
||||
|
||||
self.train_dataset = train_dataset
|
||||
self.valid_dataset = valid_dataset
|
||||
|
||||
self.rollout_controller = rollout_controller
|
||||
|
||||
self.train_dataloader = None
|
||||
self.valid_dataloader = None
|
||||
|
||||
def create_train_dataloader(self):
|
||||
cfg = self.args.train_dataset
|
||||
if dist.is_initialized():
|
||||
batch_size = cfg.batch_size // dist.get_world_size()
|
||||
else:
|
||||
batch_size = cfg.batch_size
|
||||
self.train_dataloader = StatefulDataLoader(
|
||||
dataset=self.train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=cfg.shuffle,
|
||||
pin_memory=cfg.pin_memory,
|
||||
num_workers=cfg.num_workers,
|
||||
drop_last=True,
|
||||
collate_fn=lambda x: x
|
||||
|
||||
)
|
||||
|
||||
def create_valid_dataloader(self):
|
||||
if self.args.valid_dataset is None:
|
||||
return
|
||||
cfg = self.args.valid_dataset
|
||||
if dist.is_initialized():
|
||||
batch_size = cfg.batch_size // dist.get_world_size()
|
||||
else:
|
||||
batch_size = cfg.batch_size
|
||||
self.valid_dataloader = StatefulDataLoader(
|
||||
dataset=self.valid_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=cfg.shuffle,
|
||||
pin_memory=cfg.pin_memory,
|
||||
num_workers=cfg.num_workers,
|
||||
drop_last=True,
|
||||
collate_fn=lambda x: x
|
||||
)
|
||||
|
||||
@property
|
||||
def local_train_batch_size(self):
|
||||
if not dist.is_initialized():
|
||||
return self.args.train_dataset.batch_size
|
||||
return self.args.train_dataset.batch_size // dist.get_world_size()
|
||||
|
||||
# TODO: check HF trainer signature
|
||||
def train(self, resume_from_checkpoint: Optional[Union[str, bool]] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_save_checkpoint_path(
|
||||
self, epoch: int, step: int, globalstep: int, name: str = "model"
|
||||
):
|
||||
path = os.path.join(
|
||||
constants.get_save_path(self.args),
|
||||
name,
|
||||
f"epoch{epoch}epochstep{step}globalstep{globalstep}",
|
||||
)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerFactory:
|
||||
args: TrainingArgs
|
||||
|
||||
def make_trainer(
|
||||
self,
|
||||
config: TrainerConfig,
|
||||
train_dataset: Dataset,
|
||||
valid_dataset: Optional[Dataset] = None,
|
||||
rollout_controller: Optional["RolloutController"] = None,
|
||||
) -> Trainer:
|
||||
if config.type == "grpo":
|
||||
from arealite.impl.trainer.grpo import SpmdGRPOTrainer
|
||||
|
||||
return SpmdGRPOTrainer(
|
||||
self.args,
|
||||
config,
|
||||
train_dataset=train_dataset,
|
||||
valid_dataset=valid_dataset,
|
||||
rollout_controller=rollout_controller,
|
||||
)
|
||||
elif config.type == "sft":
|
||||
from arealite.impl.trainer.sft import SFTTrainer
|
||||
|
||||
return SFTTrainer(
|
||||
self.args,
|
||||
config,
|
||||
train_dataset=train_dataset,
|
||||
valid_dataset=valid_dataset,
|
||||
rollout_controller=rollout_controller,
|
||||
)
|
||||
elif config.type == "vl_sft":
|
||||
from arealite.impl.trainer.vl_sft import VL_SFTTrainer
|
||||
|
||||
return VL_SFTTrainer(
|
||||
self.args,
|
||||
config,
|
||||
train_dataset=train_dataset,
|
||||
valid_dataset=valid_dataset,
|
||||
rollout_controller=rollout_controller,
|
||||
)
|
||||
elif config.type == "vl_grpo":
|
||||
from arealite.impl.trainer.vl_grpo import VL_SpmdGRPOTrainer
|
||||
|
||||
return VL_SpmdGRPOTrainer(
|
||||
self.args,
|
||||
config,
|
||||
train_dataset=train_dataset,
|
||||
valid_dataset=valid_dataset,
|
||||
rollout_controller=rollout_controller,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown trainer type: {config.type}")
|
|
@ -1,20 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import sys
|
||||
|
||||
from arealite.api.cli_args import TrainingArgs, prepare_training_args
|
||||
from arealite.api.llm_server_api import LLMServerFactory
|
||||
from realhf.base import seeding
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for launching the LLM server."""
|
||||
cfg: TrainingArgs = prepare_training_args(sys.argv[1:])[0]
|
||||
seeding.set_random_seed(cfg.seed, "llm_server")
|
||||
server = LLMServerFactory(cfg).make_server(cfg.rollout.llm_service)
|
||||
server.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,58 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.elastic.multiprocessing.errors import record
|
||||
|
||||
from arealite.api.cli_args import TrainingArgs, prepare_training_args
|
||||
from arealite.api.dataset_api import DatasetFactory
|
||||
from arealite.api.rollout_api import RolloutCollectorFactory
|
||||
from arealite.api.trainer_api import TrainerFactory
|
||||
from arealite.system.rollout_controller import RolloutController
|
||||
from realhf.base import seeding
|
||||
|
||||
|
||||
@record
|
||||
def main():
|
||||
"""Main entry point for launching the trainer."""
|
||||
cfg: TrainingArgs = prepare_training_args(sys.argv[1:])[0]
|
||||
rank = int(os.getenv("RANK", "0"))
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
seeding.set_random_seed(cfg.seed, f"trainer{rank}")
|
||||
|
||||
# Initialize the global pytorch distributed communication group.
|
||||
dist.init_process_group("nccl")
|
||||
|
||||
# Load and split dataset
|
||||
dataset_factory = DatasetFactory(cfg)
|
||||
train_dataset = dataset_factory.make_dataset(cfg.train_dataset, rank, world_size)
|
||||
valid_dataset = None
|
||||
if cfg.valid_dataset is not None:
|
||||
valid_dataset = dataset_factory.make_dataset(
|
||||
cfg.valid_dataset, rank, world_size
|
||||
)
|
||||
|
||||
# Create rollout controller for online training and evaluation.
|
||||
rollout_controller = None
|
||||
if cfg.rollout is not None:
|
||||
rollout_factory = RolloutCollectorFactory(cfg)
|
||||
collector = rollout_factory.make_collector(cfg.rollout.collector)
|
||||
rollout_controller = RolloutController(cfg, cfg.rollout, collector=collector)
|
||||
|
||||
# If trainer is given, run RL or offline training.
|
||||
if cfg.trainer is not None:
|
||||
trainer_factory = TrainerFactory(cfg)
|
||||
trainer = trainer_factory.make_trainer(
|
||||
cfg.trainer,
|
||||
train_dataset=train_dataset,
|
||||
valid_dataset=valid_dataset,
|
||||
rollout_controller=rollout_controller,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,171 +0,0 @@
|
|||
# Basic experiment info
|
||||
experiment_name: gsm8k-test
|
||||
trial_name: my-trial-3
|
||||
seed: 1
|
||||
mode: local
|
||||
wandb:
|
||||
mode: disabled
|
||||
entity: null
|
||||
project: null
|
||||
name: null
|
||||
job_type: null
|
||||
group: null
|
||||
notes: null
|
||||
tags: null
|
||||
config: null
|
||||
tensorboard:
|
||||
path: null
|
||||
|
||||
exp_ctrl:
|
||||
total_train_epochs: 5
|
||||
save_freq_epochs: 1
|
||||
save_freq_steps: null
|
||||
save_freq_secs: null
|
||||
ckpt_freq_epochs: null
|
||||
ckpt_freq_steps: null
|
||||
ckpt_freq_secs: 600
|
||||
eval_freq_epochs: null
|
||||
eval_freq_steps: null
|
||||
eval_freq_secs: null
|
||||
benchmark_steps: null
|
||||
benchmark_n_seqs: null
|
||||
|
||||
# whether to allow persistent servers
|
||||
shutdown_server_on_exit: true
|
||||
|
||||
# Allocation and parallelism
|
||||
allocation_mode: sglang.d4p1t1+d4p1t1
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
|
||||
# Cluster configuration
|
||||
ray_temp_path: /tmp/ray
|
||||
cluster:
|
||||
cluster_name: local
|
||||
fileroot: /tmp/arealite/
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /tmp/arealite/name_resolve/
|
||||
|
||||
# Datasets
|
||||
train_dataset:
|
||||
path: json
|
||||
name: null
|
||||
split: train
|
||||
data_files: /storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl
|
||||
batch_size: 32
|
||||
shuffle: True
|
||||
preprocessor:
|
||||
type: areal
|
||||
|
||||
valid_dataset: null
|
||||
|
||||
# Rollout config
|
||||
rollout:
|
||||
collector:
|
||||
type: rlvr
|
||||
rlvr:
|
||||
reward_type: areal-math
|
||||
solution_path: /storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl
|
||||
num_workers: 1
|
||||
max_concurrent_rollouts: null
|
||||
max_head_offpolicyness: 0
|
||||
filter_reward_lb: -10000
|
||||
filter_reward_ub: 10000
|
||||
server_backend: sglang
|
||||
model_path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
|
||||
gconfig:
|
||||
n_samples: 16
|
||||
max_new_tokens: 512
|
||||
min_new_tokens: 0
|
||||
top_p: 1.0
|
||||
top_k: 1000000
|
||||
temperature: 1.0
|
||||
llm_client:
|
||||
schedule_policy: round_robin
|
||||
request_timeout: 3600
|
||||
request_retries: 3
|
||||
llm_service:
|
||||
served_model_name: null
|
||||
health_check_interval: 5
|
||||
startup_timeout: 300
|
||||
max_unhealth_count: 3
|
||||
graceful_shutdown_on_unhealthy: true
|
||||
sglang:
|
||||
dtype: "bfloat16"
|
||||
enable_mixed_chunk: false
|
||||
enable_torch_compile: false
|
||||
torch_compile_max_bs: 32
|
||||
cuda_graph_max_bs: null
|
||||
cuda_graph_bs: null
|
||||
triton_attention_reduce_in_fp32: false
|
||||
triton_attention_num_kv_splits: 8
|
||||
num_continuous_decode_steps: 1
|
||||
attention_backend: "flashinfer"
|
||||
sampling_backend: null
|
||||
context_length: 32768
|
||||
mem_fraction_static: 0.9
|
||||
max_running_requests: null
|
||||
chunked_prefill_size: -1
|
||||
max_prefill_tokens: 32768
|
||||
schedule_policy: "lpm"
|
||||
schedule_conservativeness: 1.0
|
||||
cpu_offload_gb: 0
|
||||
kv_cache_dtype: "auto"
|
||||
log_level: "warning"
|
||||
log_level_http: "warning"
|
||||
log_requests: false
|
||||
log_requests_level: 0
|
||||
show_time_cost: false
|
||||
enable_metrics: true
|
||||
decode_log_interval: 1
|
||||
|
||||
# Trainer
|
||||
trainer:
|
||||
type: grpo
|
||||
grpo:
|
||||
async_training: true
|
||||
actor:
|
||||
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
|
||||
init_from_scratch: false
|
||||
gradient_checkpointing: false
|
||||
bf16: true
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 1.0e-6
|
||||
weight_decay: 0.05
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
eps: 1.0e-08
|
||||
min_lr_ratio: 0.0
|
||||
lr_scheduler_type: constant
|
||||
warmup_steps_proportion: 0.001
|
||||
initial_loss_scale: 4294967296.0
|
||||
min_loss_scale: 1.0
|
||||
loss_scale_window: 5.0
|
||||
hysteresis: 2
|
||||
gradient_clipping: 1.0
|
||||
backend:
|
||||
type: fsdp
|
||||
ref: null
|
||||
mb_spec:
|
||||
max_tokens_per_mb: 10240
|
||||
# Algorithm
|
||||
group_adv_norm: False
|
||||
ppo_n_minibatches: 4
|
||||
eps_clip: 0.2
|
||||
c_clip: null
|
||||
reward_scaling: 10.0
|
||||
reward_bias: -0.5
|
||||
max_reward_clip: 20.0
|
||||
mask_no_eos_with_zero: false
|
||||
discount: 1.0
|
||||
gae_lambda: 1.0
|
||||
adv_norm: true
|
||||
kl_ctl: 0.0
|
||||
recompute_logprob: true
|
||||
use_decoupled_loss: true
|
||||
behav_imp_weight_cap: null
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
# Basic experiment info
|
||||
experiment_name: test-sft
|
||||
trial_name: test-trial
|
||||
seed: 1
|
||||
mode: ray
|
||||
n_nodes: 1
|
||||
n_gpus_per_node: 8
|
||||
|
||||
wandb:
|
||||
mode: disabled
|
||||
entity: null
|
||||
project: null
|
||||
name: null
|
||||
job_type: null
|
||||
group: null
|
||||
notes: null
|
||||
tags: null
|
||||
config: null
|
||||
|
||||
tensorboard:
|
||||
path: null
|
||||
|
||||
exp_ctrl:
|
||||
total_train_epochs: 2
|
||||
save_freq_epochs: 1
|
||||
save_freq_steps: null
|
||||
save_freq_secs: null
|
||||
ckpt_freq_epochs: null
|
||||
ckpt_freq_steps: null
|
||||
ckpt_freq_secs: 600
|
||||
eval_freq_epochs: 1
|
||||
eval_freq_steps: null
|
||||
eval_freq_secs: null
|
||||
benchmark_steps: null
|
||||
benchmark_n_seqs: null
|
||||
|
||||
ray_temp_path: /tmp/ray
|
||||
cluster:
|
||||
cluster_name: local
|
||||
fileroot: /tmp/arealite/
|
||||
n_nodes: 32
|
||||
n_gpus_per_node: 8
|
||||
name_resolve:
|
||||
type: nfs
|
||||
nfs_record_root: /tmp/arealite/nfs_record_root/
|
||||
|
||||
train_dataset:
|
||||
path: openai/gsm8k
|
||||
preprocessor:
|
||||
type: gsm8k_sft
|
||||
name: main
|
||||
split: train
|
||||
data_files: null
|
||||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
|
||||
valid_dataset:
|
||||
path: openai/gsm8k
|
||||
preprocessor:
|
||||
type: gsm8k_sft
|
||||
name: main
|
||||
split: test
|
||||
data_files: null
|
||||
batch_size: 128
|
||||
shuffle: true
|
||||
pin_memory: true
|
||||
num_workers: 4
|
||||
|
||||
trainer:
|
||||
type: sft
|
||||
sft:
|
||||
model:
|
||||
path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
|
||||
init_from_scratch: false
|
||||
gradient_checkpointing: false
|
||||
bf16: false
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 2.0e-05
|
||||
weight_decay: 0.05
|
||||
beta1: 0.9
|
||||
beta2: 0.95
|
||||
eps: 1.0e-05
|
||||
min_lr_ratio: 0.0
|
||||
lr_scheduler_type: constant
|
||||
warmup_steps_proportion: 0.001
|
||||
initial_loss_scale: 4294967296.0
|
||||
min_loss_scale: 1.0
|
||||
loss_scale_window: 5.0
|
||||
hysteresis: 2
|
||||
gradient_clipping: 1.0
|
||||
backend:
|
||||
type: fsdp
|
||||
fsdp:
|
||||
wrap_policy:
|
||||
transformer_layer_cls_to_wrap: null
|
||||
offload_params: false
|
||||
mb_spec:
|
||||
n_mbs: 1
|
||||
max_tokens_per_mb: 10240
|
||||
|
||||
rollout: null
|
|
@ -1,87 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from arealite.api.cli_args import prepare_training_args
|
||||
from arealite.api.io_struct import AllocationMode
|
||||
from arealite.api.llm_server_api import LLMServiceRegistry
|
||||
from realhf.base import constants, name_resolve, names
|
||||
from realhf.scheduler.client import JobException, JobState
|
||||
from realhf.scheduler.client import make as make_scheduler
|
||||
|
||||
|
||||
def main():
|
||||
cfg, config_file = prepare_training_args(sys.argv[1:])
|
||||
if cfg.shutdown_server_on_exit:
|
||||
name_resolve.clear_subtree(
|
||||
names.trial_root(
|
||||
experiment_name=cfg.experiment_name, trial_name=cfg.trial_name
|
||||
)
|
||||
)
|
||||
|
||||
# Launch inference and training jobs
|
||||
alloc_mode = AllocationMode.from_str(cfg.allocation_mode)
|
||||
assert cfg.mode == "local"
|
||||
scheduler = make_scheduler(cfg)
|
||||
BASE_ENVIRONS = constants.get_env_vars(cfg)
|
||||
for k, v in BASE_ENVIRONS.items():
|
||||
os.environ[k] = v
|
||||
|
||||
# discover existing servers
|
||||
existing_servers = LLMServiceRegistry(
|
||||
cfg.experiment_name, cfg.trial_name
|
||||
).get_healthy_servers()
|
||||
# Launch LLM servers.
|
||||
if len(existing_servers) == 0:
|
||||
n_gpus_per_instance = alloc_mode.gen_pp_size * alloc_mode.gen_tp_size
|
||||
servers_to_launch = alloc_mode.gen_dp_size - len(existing_servers)
|
||||
scheduler.submit_array(
|
||||
worker_type="llm_server",
|
||||
cmd=f"python3 arealite/cli/launch_server.py --config {str(config_file)}",
|
||||
count=servers_to_launch,
|
||||
cpu=cfg.cpu_per_inf_proc * n_gpus_per_instance,
|
||||
gpu=n_gpus_per_instance,
|
||||
mem=cfg.mem_per_inf_proc * n_gpus_per_instance,
|
||||
env_vars=BASE_ENVIRONS,
|
||||
container_image=cfg.cluster.gpu_infer_image,
|
||||
)
|
||||
# Launch trainers.
|
||||
scheduler.submit(
|
||||
worker_type="trainer",
|
||||
cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} arealite/cli/launch_trainer.py --config {str(config_file)}",
|
||||
cpu=cfg.cpu_per_train_proc * alloc_mode.train_world_size,
|
||||
gpu=alloc_mode.train_world_size,
|
||||
mem=cfg.cpu_per_train_proc * cfg.mem_per_train_proc,
|
||||
container_image=cfg.cluster.gpu_image,
|
||||
nodelist=cfg.nodelist,
|
||||
exclude=cfg.exclude,
|
||||
env_vars=BASE_ENVIRONS,
|
||||
hostfile=False,
|
||||
multiprog=False,
|
||||
)
|
||||
|
||||
# Waiting for the job.
|
||||
try:
|
||||
scheduler.wait(
|
||||
check_status=(
|
||||
JobState.CANCELLED,
|
||||
JobState.FAILED,
|
||||
JobState.NOT_FOUND,
|
||||
JobState.COMPLETED,
|
||||
),
|
||||
remove_status=(),
|
||||
)
|
||||
except (KeyboardInterrupt, JobException, TimeoutError):
|
||||
kill_signal = (
|
||||
"SIGKILL" if cfg.mode == "slurm" else "SIGTERM"
|
||||
) # use sigkill to terminate slurm jobs
|
||||
if cfg.shutdown_server_on_exit:
|
||||
scheduler.stop_all(kill_signal)
|
||||
else:
|
||||
scheduler.stop("trainer")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,195 +0,0 @@
|
|||
import functools
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from realhf.base import pkg_version
|
||||
|
||||
|
||||
def actor_loss_fn(
|
||||
logprobs: torch.Tensor,
|
||||
old_logprobs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
eps_clip: float,
|
||||
loss_mask: torch.Tensor,
|
||||
c_clip: Optional[float] = None,
|
||||
proximal_logprobs: Optional[torch.Tensor] = None,
|
||||
behav_imp_weight_cap: Optional[float] = None,
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
denorm_logprobs = (
|
||||
proximal_logprobs if proximal_logprobs is not None else old_logprobs
|
||||
)
|
||||
loss_mask_count = loss_mask.count_nonzero() or 1
|
||||
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
|
||||
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
|
||||
pg_loss1 = -advantages * ratio
|
||||
pg_loss2 = -advantages * clipped_ratio
|
||||
clip_mask = pg_loss1.detach() < pg_loss2.detach()
|
||||
pg_loss = torch.max(pg_loss1, pg_loss2)
|
||||
if c_clip is not None:
|
||||
assert c_clip > 1.0, c_clip
|
||||
pg_loss3 = torch.sign(advantages) * c_clip * advantages
|
||||
dual_clip_mask = pg_loss3.detach() < pg_loss.detach()
|
||||
pg_loss = torch.min(pg_loss, pg_loss3)
|
||||
else:
|
||||
dual_clip_mask = torch.zeros_like(clip_mask)
|
||||
if proximal_logprobs is not None:
|
||||
behav_kl = proximal_logprobs - old_logprobs
|
||||
behav_imp_weight = behav_kl.exp()
|
||||
behav_mask = (
|
||||
(behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
|
||||
if behav_imp_weight_cap is not None
|
||||
else loss_mask
|
||||
)
|
||||
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
|
||||
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
|
||||
pg_loss = pg_loss * behav_imp_weight
|
||||
logging_loss = pg_loss.detach()
|
||||
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
|
||||
clip_mask.logical_and_(loss_mask)
|
||||
dual_clip_mask.logical_and_(loss_mask)
|
||||
stat = dict(
|
||||
loss=logging_loss,
|
||||
importance_weight=ratio.detach(),
|
||||
approx_kl=(logprobs - denorm_logprobs).detach(),
|
||||
clip_mask=clip_mask,
|
||||
dual_clip_mask=dual_clip_mask,
|
||||
)
|
||||
if proximal_logprobs is not None:
|
||||
stat["behave_imp_weight"] = behav_imp_weight
|
||||
stat["behave_approx_kl"] = behav_kl
|
||||
stat["behave_mask"] = behav_mask
|
||||
return pg_loss, stat
|
||||
|
||||
|
||||
def _huber_loss(x: torch.Tensor, y: torch.Tensor, delta: float):
|
||||
diff = torch.abs(x - y)
|
||||
return torch.where(diff < delta, 0.5 * diff**2, delta * (diff - 0.5 * delta))
|
||||
|
||||
|
||||
def _mse_loss(x: torch.Tensor, y: torch.Tensor):
|
||||
return 0.5 * (x - y) ** 2
|
||||
|
||||
|
||||
def critic_loss_fn(
|
||||
value: torch.Tensor,
|
||||
old_value: torch.Tensor,
|
||||
target_value: torch.Tensor,
|
||||
value_eps_clip: float,
|
||||
loss_mask: torch.Tensor,
|
||||
loss_fn_type: str = "mse",
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
if loss_fn_type == "huber":
|
||||
loss_fn = functools.partial(_huber_loss, delta=10.0)
|
||||
elif loss_fn_type == "mse":
|
||||
loss_fn = _mse_loss
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown loss fn type: {loss_fn_type}")
|
||||
value_loss_original = loss_fn(value, target_value)
|
||||
value_clipped = old_value + (value - old_value).clamp(
|
||||
-value_eps_clip, value_eps_clip
|
||||
)
|
||||
value_loss_clipped = loss_fn(value_clipped, target_value)
|
||||
value_loss = torch.max(value_loss_original, value_loss_clipped)
|
||||
with torch.no_grad():
|
||||
clip_mask = value_loss_clipped.detach() > value_loss_original.detach()
|
||||
clip_mask.logical_and_(loss_mask)
|
||||
stat = dict(clip_mask=clip_mask, loss=value_loss.detach())
|
||||
value_loss = torch.where(loss_mask, value_loss, 0).sum() / loss_mask.count_nonzero()
|
||||
return value_loss, stat
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_packed_rewards(
|
||||
kl_ctl: float,
|
||||
clip_reward_value: float,
|
||||
log_probs: torch.Tensor,
|
||||
ref_log_probs: torch.Tensor,
|
||||
reward_score: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
seq_no_eos_mask: torch.Tensor,
|
||||
mask_no_eos_with_zero: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
tot_rewards = -kl_ctl * (log_probs - ref_log_probs)
|
||||
tot_rewards[cu_seqlens[1:] - 1] = 0
|
||||
kl_rewards = tot_rewards.clone()
|
||||
reward_score = reward_score.clip(-clip_reward_value, clip_reward_value)
|
||||
indices = torch.clip(cu_seqlens[1:] - 2, min=0)
|
||||
if mask_no_eos_with_zero:
|
||||
tot_rewards[indices] += torch.where(seq_no_eos_mask, 0, reward_score)
|
||||
else:
|
||||
tot_rewards[indices] += reward_score
|
||||
return kl_rewards, tot_rewards
|
||||
|
||||
|
||||
def pygae1d_nolp_misalign(
|
||||
rewards: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
cu_seqlens_: torch.Tensor,
|
||||
bootstrap: torch.Tensor,
|
||||
gamma: float,
|
||||
lam: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cu_seqlens = cu_seqlens_.clone()
|
||||
cu_seqlens[1:] += torch.ones_like(cu_seqlens_[1:]).cumsum(0)
|
||||
bs = cu_seqlens_.shape[0] - 1
|
||||
assert values.shape[0] == rewards.shape[0] + bs
|
||||
advantages_reversed = []
|
||||
returns_reversed = []
|
||||
for i in reversed(range(bs)):
|
||||
v_offset = cu_seqlens[i]
|
||||
r_offset, r_end = cu_seqlens_[i], cu_seqlens_[i + 1]
|
||||
assert cu_seqlens[i + 1] - v_offset - 1 == r_end - r_offset
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(r_end - r_offset)):
|
||||
nextvalues = values[v_offset + t + 1]
|
||||
if t == r_end - r_offset - 1:
|
||||
nextvalues *= bootstrap[i]
|
||||
delta = rewards[r_offset + t] + gamma * nextvalues - values[v_offset + t]
|
||||
lastgaelam = delta + gamma * lam * lastgaelam
|
||||
advantages_reversed.append(lastgaelam)
|
||||
returns_reversed.append(lastgaelam + values[v_offset + t])
|
||||
advantages = torch.stack(advantages_reversed[::-1])
|
||||
returns = torch.stack(returns_reversed[::-1])
|
||||
return advantages, returns
|
||||
|
||||
|
||||
def cugae1d_nolp_misalign_func(
|
||||
rewards: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
truncate: torch.Tensor,
|
||||
gamma: float,
|
||||
lam: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if pkg_version.is_available("cugae"):
|
||||
from cugae import cugae1d_nolp_misalign_func as gae_1d_nolp_misalign
|
||||
else:
|
||||
from realhf._C.cugae import gae_1d_nolp_misalign
|
||||
assert len(rewards.shape) == len(values.shape) == len(cu_seqlens.shape) == 1
|
||||
assert cu_seqlens[0] == 0 and cu_seqlens[-1] == rewards.shape[0]
|
||||
return gae_1d_nolp_misalign(rewards, values, cu_seqlens, truncate, gamma, lam)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_packed_advantages_and_returns(
|
||||
gamma: float,
|
||||
lam: float,
|
||||
values: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
short1cu_seqlens: torch.Tensor,
|
||||
seq_no_eos_mask: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if rewards.get_device() == -1:
|
||||
return pygae1d_nolp_misalign(
|
||||
rewards, values, short1cu_seqlens, seq_no_eos_mask, gamma, lam
|
||||
)
|
||||
try:
|
||||
return cugae1d_nolp_misalign_func(
|
||||
rewards, values, short1cu_seqlens.int(), seq_no_eos_mask.bool(), gamma, lam
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return pygae1d_nolp_misalign(
|
||||
rewards, values, short1cu_seqlens, seq_no_eos_mask, gamma, lam
|
||||
)
|
Binary file not shown.
Before Width: | Height: | Size: 2.7 MiB |
Binary file not shown.
Before Width: | Height: | Size: 892 KiB |
File diff suppressed because one or more lines are too long
|
@ -1,10 +0,0 @@
|
|||
{"prompt": "<\uff5cUser\uff5c>\nBaron Munchausen told a story. \"There were a whole crowd of us. We reached a crossroads. Then half of our group turned left, a third turned right, and a fifth went straight.\" \"But wait, the Duke remarked, the sum of half, a third, and a fifth isn't equal to one, so you are lying!\" The Baron replied, \"I'm not lying, I'm rounding. For example, there are 17 people. I say that a third turned. Should one person split in your opinion? No, with rounding, six people turned. From whole numbers, the closest to the fraction $17 / 3$ is 6. And if I say that half of the 17 people turned, it means 8 or 9 people.\" It is known that Baron Munchausen never lies. What is the largest number of people that could have been in the crowd?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "00006d8f079c739f", "solutions": ["\\boxed{37}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>What is the unit digit of the product\n\n$$\n(5+1)\\left(5^{3}+1\\right)\\left(5^{6}+1\\right)\\left(5^{12}+1\\right) ?\n$$\n\n(a) 0 \n(b) 1 \n(c) 2 \n(d) 5 \n(e) 6\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "000316109ea516b3", "solutions": ["\\boxed{e}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>Given points \\( A(4,0) \\) and \\( B(2,2) \\) are inside the ellipse \\( \\frac{x^{2}}{25}+\\frac{y^{2}}{9}=1 \\), and \\( M \\) is a point on the ellipse, find the maximum value of \\( |MA| + |MB| \\).\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "000adcfa66ee4270", "solutions": ["\\boxed{10+2\\sqrt{10}}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>There is a schoolbag containing 12 cards labeled $1, 1, 2, 2, \\cdots, 6, 6$. A person draws one card at a time without replacement. If a card is drawn that has the same number as a previously drawn card, both cards are discarded. The process ends when the person has 3 single cards in hand or all cards in the schoolbag have been drawn. Find the probability that all cards in the schoolbag are drawn.\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "001354647264e663", "solutions": ["\\boxed{\\frac{9}{385}}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>For the sequence of numbers \\( n_{1}, n_{2}, n_{3}, \\ldots \\), the relation \\( n_{i} = 2 n_{i-1} + a \\) holds for all \\( i > 1 \\). If \\( n_{2} = 5 \\) and \\( n_{8} = 257 \\), what is \\( n_{5} \\)?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "0014142e5f3c28a7", "solutions": ["\\boxed{33}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>Three players play tic-tac-toe together. In other words, the three players take turns placing an \"A\", \"B\", and \"C\", respectively, in one of the free spots of a \\(3 \\times 3\\) grid, and the first player to have three of their label in a row, column, or diagonal wins. How many possible final boards are there where the player who goes third wins the game? (Rotations and reflections are considered different boards, but the order of placement does not matter.)\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "0017c4e9f72d26eb", "solutions": ["\\boxed{148}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>Let \\( a_{1}, a_{2}, \\cdots, a_{2014} \\) be a permutation of the positive integers \\( 1, 2, \\cdots, 2014 \\). Define\n\\[ S_{k} = a_{1} + a_{2} + \\cdots + a_{k} \\quad (k=1, 2, \\cdots, 2014). \\]\n\nWhat is the maximum number of odd numbers among \\( S_{1}, S_{2}, \\cdots, S_{2014} \\)?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "00231541a71983cd", "solutions": ["\\boxed{1511}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>\nThe polynomial \\( G(x) \\) with real coefficients takes the value 2022 at exactly five distinct points \\( x_{1}<x_{2}<x_{3}<x_{4}<x_{5} \\). It is known that the graph of the function \\( y=G(x) \\) is symmetric with respect to the line \\( x=-7 \\).\n\n(a) Find \\( x_{1}+x_{3}+x_{5} \\).\n\n(b) What is the minimum degree that \\( G(x) \\) can have?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "002ba4c0d1ad1b54", "solutions": ["\\boxed{6}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>The square of a natural number has 202 digits. The first 100 digits are 1, followed by 101 digits of 2. Determine the last digit and the number.\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "0041f14cc37aee13", "solutions": ["\\boxed{5}"]}
|
||||
{"prompt": "<\uff5cUser\uff5c>The descriptors 'even', 'factors of 240', 'multiple of 3', 'odd', 'prime' and 'square' are to be placed in some order as row and column headings around a grid in positions \\(a, b, c, d, e,\\) and \\(f\\). The digits 1 through 9 are to be placed in the empty cells inside the grid so that each digit satisfies both the relevant row and column headings.\n(i) Show that it is possible to complete the grid.\n(ii) In how many different ways can the grid be completed?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c><think>", "task": "math", "query_id": "00514ff45cc98a48", "solutions": ["\\boxed{72}"]}
|
|
@ -1,172 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
|
||||
from arealite.api.cli_args import (
|
||||
DatasetPreprocessor,
|
||||
GenerationHyperparameters,
|
||||
GSM8KPreprocessor,
|
||||
MathCodeSingleStepConfig,
|
||||
RLVRConfig,
|
||||
RolloutCollectorConfig,
|
||||
SGLangConfig,
|
||||
TrainingArgs,
|
||||
)
|
||||
from arealite.api.io_struct import Trajectory
|
||||
from arealite.api.llm_client_api import LLMClientFactory
|
||||
from arealite.api.llm_server_api import LLMServerFactory
|
||||
from arealite.api.rollout_api import RolloutCollectorFactory
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import name_resolve, seeding
|
||||
|
||||
EXPR_NAME = "test_rollout"
|
||||
TRIAL_NAME = "test_rollout"
|
||||
MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer():
|
||||
yield load_hf_tokenizer(MODEL_PATH)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def args():
|
||||
args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
|
||||
args.rollout.model_path = MODEL_PATH
|
||||
seeding.set_random_seed(args.seed, EXPR_NAME)
|
||||
name_resolve.reconfigure(args.cluster.name_resolve)
|
||||
yield args
|
||||
name_resolve.reset()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sglang_server(args):
|
||||
args.rollout.sglang = SGLangConfig()
|
||||
server = LLMServerFactory(args).make_server(args.rollout.llm_service)
|
||||
server._startup()
|
||||
yield
|
||||
server._graceful_exit(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task", ["math", "code"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_rlvr_rollout(args, sglang_server, tokenizer, task):
|
||||
jsonl_file = Path(__file__).parent / "data" / f"rlvr_{task}_dataset.jsonl"
|
||||
args.rollout.server_backend = "sglang"
|
||||
args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16)
|
||||
args.rollout.collector = RolloutCollectorConfig(
|
||||
type="rlvr",
|
||||
rlvr=RLVRConfig(reward_type=f"areal-{task}", solution_path=jsonl_file),
|
||||
)
|
||||
llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client)
|
||||
collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector)
|
||||
|
||||
# Test the rollout collector with the provided JSONL data
|
||||
with open(jsonl_file, "r") as f:
|
||||
for i, l in enumerate(f.readlines()):
|
||||
data = json.loads(l)
|
||||
env_option = dict(
|
||||
query_id=data["query_id"],
|
||||
input_ids=tokenizer.encode(data["prompt"]),
|
||||
prompt=data["prompt"],
|
||||
)
|
||||
res = await collector.arun_episode(
|
||||
llm_client=llm_client,
|
||||
gconfig=gconfig,
|
||||
env_option=env_option,
|
||||
)
|
||||
assert isinstance(res, Trajectory)
|
||||
assert isinstance(res.data, dict)
|
||||
assert res.prompt == env_option
|
||||
shape = res.data["input_ids"].shape
|
||||
for k in ["prompt_mask", "logprobs", "versions"]:
|
||||
assert res.data[k].shape == shape
|
||||
assert res.stats.episode_length == 1
|
||||
assert res.stats.total_reward in [0, 1], res.stats.total_reward
|
||||
assert res.stats.start_time < datetime.now().timestamp()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gsm8k_rollout(args, sglang_server, tokenizer):
|
||||
args.rollout.server_backend = "sglang"
|
||||
args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16)
|
||||
args.rollout.collector = RolloutCollectorConfig(
|
||||
type="rlvr", rlvr=RLVRConfig(reward_type="gsm8k")
|
||||
)
|
||||
collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector)
|
||||
|
||||
args.train_dataset.path = "openai/gsm8k"
|
||||
args.train_dataset.name = "main"
|
||||
args.train_dataset.split = "train"
|
||||
args.train_dataset.preprocessor = DatasetPreprocessor(
|
||||
"gsm8k_rl", gsm8k=GSM8KPreprocessor("strict")
|
||||
)
|
||||
|
||||
from arealite.api.dataset_api import DatasetFactory
|
||||
|
||||
llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client)
|
||||
dataset = (
|
||||
DatasetFactory(args)
|
||||
.make_dataset(args.train_dataset, rank=0, world_size=1)
|
||||
.select(range(10))
|
||||
)
|
||||
for i in range(len(dataset)):
|
||||
env_option = dataset[i]
|
||||
res = await collector.arun_episode(
|
||||
llm_client=llm_client,
|
||||
gconfig=gconfig,
|
||||
env_option=env_option,
|
||||
)
|
||||
assert isinstance(res, Trajectory)
|
||||
assert isinstance(res.data, dict)
|
||||
assert res.prompt == env_option
|
||||
shape = res.data["input_ids"].shape
|
||||
for k in ["prompt_mask", "logprobs", "versions"]:
|
||||
assert res.data[k].shape == shape
|
||||
assert res.stats.episode_length == 1
|
||||
assert res.stats.total_reward in [0, 1], res.stats.total_reward
|
||||
assert res.stats.start_time < datetime.now().timestamp()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task", ["math", "code"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_math_code_agentic_rollout(args, task, sglang_server, tokenizer):
|
||||
jsonl_file = Path(__file__).parent / "data" / f"rlvr_{task}_dataset.jsonl"
|
||||
args.rollout.server_backend = "sglang"
|
||||
args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16)
|
||||
args.rollout.collector = RolloutCollectorConfig(
|
||||
type="math_code_single_step",
|
||||
math_code_single_step=MathCodeSingleStepConfig(solution_path=jsonl_file),
|
||||
)
|
||||
|
||||
collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector)
|
||||
llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client)
|
||||
|
||||
# Test the rollout collector with the provided JSONL data
|
||||
with open(jsonl_file, "r") as f:
|
||||
for i, l in enumerate(f.readlines()):
|
||||
data = json.loads(l)
|
||||
env_option = dict(
|
||||
query_id=data["query_id"],
|
||||
input_ids=tokenizer.encode(data["prompt"]),
|
||||
)
|
||||
res = await collector.arun_episode(
|
||||
llm_client=llm_client,
|
||||
gconfig=gconfig,
|
||||
env_option=env_option,
|
||||
)
|
||||
assert isinstance(res, Trajectory)
|
||||
assert isinstance(res.data, dict)
|
||||
assert res.prompt == env_option
|
||||
shape = res.data["input_ids"].shape
|
||||
for k in ["prompt_mask", "logprobs", "versions"]:
|
||||
assert res.data[k].shape == shape
|
||||
assert res.stats.episode_length == 1
|
||||
assert res.stats.total_reward in [0, 1], res.stats.total_reward
|
||||
assert res.stats.start_time < datetime.now().timestamp()
|
|
@ -1,209 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
from datasets import load_dataset
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from arealite.api.cli_args import RLVRConfig, SGLangConfig, TrainingArgs
|
||||
from arealite.api.io_struct import Trajectory
|
||||
from arealite.api.llm_server_api import LLMServerFactory
|
||||
from arealite.api.rollout_api import RolloutCollectorFactory
|
||||
from arealite.system.rollout_controller import RolloutController
|
||||
from arealite.tests.utils import mock_rollout_output
|
||||
from arealite.utils import concat_padded_tensors
|
||||
from realhf.api.core.data_api import load_hf_tokenizer
|
||||
from realhf.base import name_resolve, names, seeding
|
||||
|
||||
EXPR_NAME = "test_rollout_controller"
|
||||
TRIAL_NAME = "test_rollout_controller"
|
||||
MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def args():
|
||||
args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
|
||||
seeding.set_random_seed(args.seed, EXPR_NAME)
|
||||
args.rollout.model_path = MODEL_PATH
|
||||
args.rollout.llm_client.tokenizer_path = MODEL_PATH
|
||||
args.train_dataset.batch_size = 2
|
||||
args.rollout.collector.rlvr = RLVRConfig(
|
||||
solution_path=str(Path(__file__).parent / "data" / f"rlvr_math_dataset.jsonl")
|
||||
)
|
||||
start_method = mp.get_start_method()
|
||||
mp.set_start_method("fork", force=True)
|
||||
name_resolve.reconfigure(args.cluster.name_resolve)
|
||||
yield args
|
||||
name_resolve.reset()
|
||||
mp.set_start_method(start_method, force=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sglang_server(args):
|
||||
args.rollout.sglang = SGLangConfig()
|
||||
server = LLMServerFactory(args).make_server(args.rollout.llm_service)
|
||||
server._startup()
|
||||
yield
|
||||
server._graceful_exit(0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataloader(args):
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
split="train",
|
||||
data_files=str(Path(__file__).parent / "data" / f"rlvr_math_dataset.jsonl"),
|
||||
)
|
||||
tokenizer = load_hf_tokenizer(MODEL_PATH)
|
||||
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), batched=True)
|
||||
yield StatefulDataLoader(
|
||||
dataset,
|
||||
batch_size=args.train_dataset.batch_size,
|
||||
collate_fn=lambda x: x,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 4])
|
||||
@pytest.mark.parametrize("n_samples", [1, 2])
|
||||
def test_generate_batch(args, sglang_server, dataloader, n_samples, num_workers):
|
||||
args = deepcopy(args)
|
||||
args.rollout.num_workers = num_workers
|
||||
args.rollout.gconfig.n_samples = n_samples
|
||||
args.rollout.gconfig.max_new_tokens = 16
|
||||
rollout_factory = RolloutCollectorFactory(args)
|
||||
collector = rollout_factory.make_collector(args.rollout.collector)
|
||||
rollout_controller = RolloutController(args, args.rollout, collector=collector)
|
||||
|
||||
data = next(iter(dataloader))
|
||||
batch_size = len(data)
|
||||
result = rollout_controller.generate_batch(batch_size, env_options=data)
|
||||
|
||||
assert len(result) == batch_size * n_samples
|
||||
assert all(isinstance(traj, Trajectory) for traj in result)
|
||||
for traj in result:
|
||||
shape = traj.data["input_ids"].shape
|
||||
assert len(shape) == 2
|
||||
for v in traj.data.values():
|
||||
assert v.shape == shape or len(v.shape) == 1
|
||||
data = concat_padded_tensors([traj.data for traj in result])
|
||||
assert data["input_ids"].shape[0] == batch_size * n_samples
|
||||
shape = data["input_ids"].shape
|
||||
assert len(shape) == 2
|
||||
for v in data.values():
|
||||
assert v.shape == shape or len(v.shape) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 3])
|
||||
@pytest.mark.parametrize("n_samples", [1, 2, 4])
|
||||
def test_mock_trajs(batch_size, n_samples):
|
||||
# Test the consistency with mocked rollout output
|
||||
result = mock_rollout_output(batch_size, n_samples)
|
||||
assert len(result) == batch_size * n_samples
|
||||
assert all(isinstance(traj, Trajectory) for traj in result)
|
||||
for traj in result:
|
||||
shape = traj.data["input_ids"].shape
|
||||
assert len(shape) == 2
|
||||
for v in traj.data.values():
|
||||
assert v.shape == shape or len(v.shape) == 1
|
||||
data = concat_padded_tensors([traj.data for traj in result])
|
||||
assert data["input_ids"].shape[0] == batch_size * n_samples
|
||||
shape = data["input_ids"].shape
|
||||
assert len(shape) == 2
|
||||
for v in data.values():
|
||||
assert v.shape == shape or len(v.shape) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [1, 4, 16])
|
||||
@pytest.mark.parametrize("num_workers", [1, 2, 5])
|
||||
def test_async_rollout(args, sglang_server, dataloader, n_samples, num_workers):
|
||||
args = deepcopy(args)
|
||||
args.rollout.gconfig.n_samples = n_samples
|
||||
args.rollout.gconfig.max_new_tokens = 16
|
||||
args.train_dataset.batch_size = 2
|
||||
args.rollout.max_concurrent_rollouts = 16
|
||||
args.rollout.num_workers = num_workers
|
||||
rollout_factory = RolloutCollectorFactory(args)
|
||||
collector = rollout_factory.make_collector(args.rollout.collector)
|
||||
rollout_controller = RolloutController(args, args.rollout, collector=collector)
|
||||
|
||||
# start loop
|
||||
rollout_controller.start_generate_loop()
|
||||
assert hasattr(rollout_controller, "_collector_thread")
|
||||
assert rollout_controller._collector_thread.is_alive()
|
||||
|
||||
# Submit data to workers
|
||||
data = next(iter(dataloader))
|
||||
rollout_controller.submit(data)
|
||||
|
||||
# wait for batch
|
||||
batch_size = 2
|
||||
result = rollout_controller.prepare_batch(batch_size)
|
||||
assert len(result) == batch_size * n_samples
|
||||
assert all(isinstance(traj, Trajectory) for traj in result)
|
||||
for traj in result:
|
||||
shape = traj.data["input_ids"].shape
|
||||
assert len(shape) == 2
|
||||
for v in traj.data.values():
|
||||
assert v.shape == shape or len(v.shape) == 1
|
||||
data = concat_padded_tensors([traj.data for traj in result])
|
||||
assert data["input_ids"].shape[0] == batch_size * n_samples
|
||||
shape = data["input_ids"].shape
|
||||
assert len(shape) == 2
|
||||
for v in data.values():
|
||||
assert v.shape == shape or len(v.shape) == 1
|
||||
|
||||
# exit
|
||||
rollout_controller.stop_generate_loop()
|
||||
assert rollout_controller._exiting.is_set()
|
||||
assert not rollout_controller._collector_thread.is_alive()
|
||||
assert not rollout_controller._worker_processes
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ofp", [1, 2, 4, 16])
|
||||
def test_async_staleness_control(args, sglang_server, dataloader, ofp):
|
||||
args = deepcopy(args)
|
||||
args.rollout.gconfig.n_samples = 2
|
||||
args.rollout.gconfig.max_new_tokens = 4
|
||||
args.rollout.max_head_offpolicyness = ofp
|
||||
args.rollout.max_concurrent_rollouts = 100
|
||||
rollout_factory = RolloutCollectorFactory(args)
|
||||
collector = rollout_factory.make_collector(args.rollout.collector)
|
||||
rollout_controller = RolloutController(args, args.rollout, collector=collector)
|
||||
name = names.model_version(args.experiment_name, args.trial_name, "actor")
|
||||
name_resolve.add(name, str(0), replace=True)
|
||||
|
||||
# start loop
|
||||
rollout_controller.start_generate_loop()
|
||||
batch_size = args.train_dataset.batch_size
|
||||
|
||||
gen = iter(dataloader)
|
||||
rollout_controller.submit(next(gen))
|
||||
rollout_controller.submit(next(gen))
|
||||
# wait for some time
|
||||
time.sleep(15)
|
||||
assert len(rollout_controller._buffer) == min(
|
||||
batch_size * 2, batch_size * (ofp + 1)
|
||||
)
|
||||
|
||||
# Update model version
|
||||
name = names.model_version(args.experiment_name, args.trial_name, "actor")
|
||||
name_resolve.add(name, str(1), replace=True)
|
||||
print("Updated model version", flush=True)
|
||||
|
||||
# submit again
|
||||
rollout_controller.submit(next(gen))
|
||||
rollout_controller.submit(next(gen))
|
||||
# wait for some time
|
||||
time.sleep(15)
|
||||
assert len(rollout_controller._buffer) == min(
|
||||
batch_size * 4, batch_size * (ofp + 2)
|
||||
)
|
||||
|
||||
# exit
|
||||
rollout_controller.stop_generate_loop()
|
|
@ -1,108 +0,0 @@
|
|||
"""Test script for FSDP Engine implementation."""
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from arealite.api.cli_args import (
|
||||
DatasetConfig,
|
||||
DatasetPreprocessor,
|
||||
EngineBackendConfig,
|
||||
EngineConfig,
|
||||
OptimizerConfig,
|
||||
SFTTrainerConfig,
|
||||
TrainerConfig,
|
||||
TrainingArgs,
|
||||
)
|
||||
from arealite.api.dataset_api import DatasetFactory
|
||||
from arealite.api.trainer_api import TrainerFactory
|
||||
|
||||
|
||||
def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
|
||||
"""Mock loss function for testing."""
|
||||
return torch.mean(logits)
|
||||
|
||||
|
||||
def mock_loss_weight_fn(logits: torch.Tensor, input_data: Dict) -> float:
|
||||
"""Mock loss weight function for testing."""
|
||||
return float(input_data["attention_mask"].sum())
|
||||
|
||||
|
||||
def test_sft():
|
||||
"""Test SFTTrainer"""
|
||||
# environment variables for torch distributed
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["RANK"] = "0"
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "7777"
|
||||
|
||||
train_dataset = DatasetConfig(
|
||||
path="openai/gsm8k",
|
||||
preprocessor=DatasetPreprocessor("gsm8k_sft"),
|
||||
name="main",
|
||||
split="train",
|
||||
batch_size=8,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
valid_dataset = DatasetConfig(
|
||||
path="openai/gsm8k",
|
||||
preprocessor=DatasetPreprocessor("gsm8k_sft"),
|
||||
name="main",
|
||||
split="test",
|
||||
batch_size=8,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
engine_config = EngineConfig(
|
||||
path="/storage/openpsi/models/Qwen__Qwen3-1.7B/",
|
||||
gradient_checkpointing=False,
|
||||
optimizer=OptimizerConfig(),
|
||||
backend=EngineBackendConfig(type="hf"),
|
||||
)
|
||||
|
||||
sft_config = SFTTrainerConfig(
|
||||
model=engine_config,
|
||||
)
|
||||
|
||||
train_config = TrainerConfig(
|
||||
type="sft",
|
||||
sft=sft_config,
|
||||
)
|
||||
|
||||
args = TrainingArgs(
|
||||
experiment_name="test-sft",
|
||||
trial_name="test",
|
||||
mode="local",
|
||||
n_nodes=1,
|
||||
n_gpus_per_node=1,
|
||||
train_dataset=train_dataset,
|
||||
valid_dataset=valid_dataset,
|
||||
trainer=train_config,
|
||||
)
|
||||
|
||||
rollout_controller = None
|
||||
dataset_factory = DatasetFactory(args)
|
||||
train_dataset = dataset_factory.make_dataset(args.train_dataset, 0, 1)
|
||||
train_dataset = train_dataset.select(range(100))
|
||||
valid_dataset = None
|
||||
if args.valid_dataset is not None:
|
||||
valid_dataset = dataset_factory.make_dataset(args.valid_dataset, 0, 1)
|
||||
valid_dataset = valid_dataset.select(range(100))
|
||||
if args.trainer is not None:
|
||||
trainer_factory = TrainerFactory(args)
|
||||
trainer = trainer_factory.make_trainer(
|
||||
args.trainer,
|
||||
train_dataset=train_dataset,
|
||||
valid_dataset=valid_dataset,
|
||||
rollout_controller=rollout_controller,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
print("All tests passed!")
|
||||
test_sft()
|
|
@ -1,112 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from arealite.api.cli_args import (
|
||||
EngineBackendConfig,
|
||||
EngineConfig,
|
||||
GenerationHyperparameters,
|
||||
LLMClientConfig,
|
||||
OptimizerConfig,
|
||||
SGLangConfig,
|
||||
TrainingArgs,
|
||||
)
|
||||
from arealite.api.engine_api import EngineFactory
|
||||
from arealite.api.io_struct import FinetuneSpec, LLMRequest, LLMResponse
|
||||
from arealite.api.llm_client_api import LLMClient
|
||||
from arealite.api.llm_server_api import LLMServerFactory
|
||||
from realhf.base import name_resolve, seeding
|
||||
|
||||
EXPR_NAME = "test_sglang_client"
|
||||
TRIAL_NAME = "test_sglang_client"
|
||||
MODEL_PATH = "/storage/openpsi/models/Qwen__Qwen3-1.7B/"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def args():
|
||||
args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME)
|
||||
args.rollout.model_path = MODEL_PATH
|
||||
seeding.set_random_seed(args.seed, EXPR_NAME)
|
||||
name_resolve.reconfigure(args.cluster.name_resolve)
|
||||
yield args
|
||||
name_resolve.reset()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sglang_server(args):
|
||||
args.rollout.sglang = SGLangConfig(mem_fraction_static=0.3)
|
||||
server = LLMServerFactory(args).make_server(args.rollout.llm_service)
|
||||
server._startup()
|
||||
yield
|
||||
server._graceful_exit(0)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sglang_client(args, sglang_server):
|
||||
from arealite.system.sglang_client import SGLangClient
|
||||
|
||||
args.rollout.server_backend = "sglang"
|
||||
llm_client = LLMClientConfig()
|
||||
client = SGLangClient(args, client_config=llm_client)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sglang_generate(sglang_client):
|
||||
req = LLMRequest(
|
||||
rid=str(uuid.uuid4()),
|
||||
text="hello! how are you today",
|
||||
gconfig=GenerationHyperparameters(max_new_tokens=16),
|
||||
)
|
||||
resp = await sglang_client.agenerate(req)
|
||||
assert isinstance(resp, LLMResponse)
|
||||
assert resp.input_tokens == req.input_ids
|
||||
assert (
|
||||
len(resp.output_logprobs)
|
||||
== len(resp.output_tokens)
|
||||
== len(resp.output_versions)
|
||||
)
|
||||
assert isinstance(resp.completion, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sglang_update_weights_from_disk(sglang_client: LLMClient):
|
||||
servers = sglang_client.get_healthy_servers()
|
||||
assert len(servers) == 1
|
||||
await sglang_client.aupdate_weights_from_disk(
|
||||
server_info=servers[0], path=MODEL_PATH
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def engine(sglang_server):
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["RANK"] = "0"
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "7777"
|
||||
engine_config = EngineConfig(
|
||||
path=MODEL_PATH,
|
||||
gradient_checkpointing=False,
|
||||
optimizer=OptimizerConfig(),
|
||||
backend=EngineBackendConfig(type="fsdp"),
|
||||
)
|
||||
|
||||
mock_args = TrainingArgs(n_nodes=1, n_gpus_per_node=1)
|
||||
|
||||
engine_factory = EngineFactory(mock_args)
|
||||
engine = engine_factory.make_engine(engine_config)
|
||||
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
|
||||
engine.init_distributed(None, ft_spec)
|
||||
print("✓ Engine created successfully")
|
||||
yield engine
|
||||
|
||||
|
||||
def test_sglang_update_weights_from_distributed(
|
||||
engine, sglang_server, sglang_client: LLMClient
|
||||
):
|
||||
engine.update_weights_to(sglang_client)
|
File diff suppressed because one or more lines are too long
|
@ -1,36 +0,0 @@
|
|||
import random
|
||||
|
||||
import torch
|
||||
|
||||
from arealite.api.io_struct import Trajectory, TrajStats
|
||||
|
||||
|
||||
def mock_rollout_output(bs, n_samples):
|
||||
trajs = []
|
||||
min_seqlen, max_seqlen = 8, 16
|
||||
for _ in range(bs * n_samples):
|
||||
input_len = random.randint(min_seqlen, max_seqlen)
|
||||
prompt_len = random.randint(1, min_seqlen - 1)
|
||||
input_ids = torch.randint(0, 100, (input_len,))
|
||||
prompt_mask = torch.tensor([1] * prompt_len + [0] * (input_len - prompt_len))
|
||||
logprobs = -torch.randn(input_len).abs()
|
||||
versions = torch.zeros(input_len)
|
||||
traj = Trajectory(
|
||||
prompt=None,
|
||||
data=dict(
|
||||
input_ids=input_ids.unsqueeze(0),
|
||||
prompt_mask=prompt_mask.unsqueeze(0),
|
||||
logprobs=logprobs.unsqueeze(0),
|
||||
versions=versions.unsqueeze(0),
|
||||
rewards=torch.tensor([random.random()]),
|
||||
),
|
||||
stats=TrajStats(
|
||||
start_time=0,
|
||||
total_reward=0,
|
||||
episode_length=1,
|
||||
info={},
|
||||
),
|
||||
)
|
||||
trajs.append(traj)
|
||||
|
||||
return trajs
|
|
@ -1,317 +0,0 @@
|
|||
Metadata-Version: 2.4
|
||||
Name: realhf
|
||||
Version: 0.3.0.dev0
|
||||
Summary: ReaL: Efficient RLHF Training of Large Language Models with Parameter Reallocation
|
||||
Keywords: distributed-systems,reinforcement-learning-from-human-feedback,large-language-models,llm-training
|
||||
Classifier: Development Status :: 2 - Pre-Alpha
|
||||
Classifier: Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.2
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Requires-Python: >=3.10
|
||||
Description-Content-Type: text/markdown
|
||||
License-File: LICENSE
|
||||
Requires-Dist: torch>2.0.0
|
||||
Requires-Dist: huggingface_hub
|
||||
Requires-Dist: datasets
|
||||
Requires-Dist: accelerate
|
||||
Requires-Dist: transformers==4.51.1
|
||||
Requires-Dist: numpy<2.0.0
|
||||
Requires-Dist: scipy
|
||||
Requires-Dist: pandas
|
||||
Requires-Dist: matplotlib
|
||||
Requires-Dist: seaborn
|
||||
Requires-Dist: h5py
|
||||
Requires-Dist: nltk
|
||||
Requires-Dist: sentencepiece
|
||||
Requires-Dist: einops
|
||||
Requires-Dist: tqdm
|
||||
Requires-Dist: rich
|
||||
Requires-Dist: orjson>=3.10.16
|
||||
Requires-Dist: pydantic
|
||||
Requires-Dist: PyYAML
|
||||
Requires-Dist: hydra-core==1.4.0.dev1
|
||||
Requires-Dist: packaging
|
||||
Requires-Dist: tabulate
|
||||
Requires-Dist: gymnasium>=1.1.1
|
||||
Requires-Dist: torchdata
|
||||
Requires-Dist: autoflake
|
||||
Requires-Dist: tensordict
|
||||
Requires-Dist: wandb
|
||||
Requires-Dist: tensorboardx
|
||||
Requires-Dist: colorama
|
||||
Requires-Dist: colorlog
|
||||
Requires-Dist: psutil
|
||||
Requires-Dist: pynvml
|
||||
Requires-Dist: swanlab[dashboard]
|
||||
Requires-Dist: ninja
|
||||
Requires-Dist: numba
|
||||
Requires-Dist: blosc
|
||||
Requires-Dist: pybind11>=2.10.0
|
||||
Requires-Dist: networkx==3.3
|
||||
Requires-Dist: aiofiles
|
||||
Requires-Dist: aiohttp>=3.11.10
|
||||
Requires-Dist: httpx>=0.28.1
|
||||
Requires-Dist: pyzmq
|
||||
Requires-Dist: paramiko
|
||||
Requires-Dist: etcd3
|
||||
Requires-Dist: protobuf<3.21
|
||||
Requires-Dist: ray
|
||||
Requires-Dist: redis
|
||||
Requires-Dist: fastapi>=0.115.12
|
||||
Requires-Dist: uvicorn>=0.34.2
|
||||
Requires-Dist: uvloop>=0.21.0
|
||||
Requires-Dist: flask
|
||||
Requires-Dist: build>=1.2.1
|
||||
Requires-Dist: wheel>=0.43.0
|
||||
Requires-Dist: setuptools<75.9,>=62.3.0
|
||||
Requires-Dist: cookiecutter>2.1.1
|
||||
Requires-Dist: distro-info>=1.0
|
||||
Requires-Dist: python-debian>=0.1.49
|
||||
Requires-Dist: func_timeout
|
||||
Requires-Dist: regex
|
||||
Requires-Dist: python_dateutil
|
||||
Requires-Dist: word2number
|
||||
Requires-Dist: Pebble
|
||||
Requires-Dist: timeout-decorator
|
||||
Requires-Dist: prettytable
|
||||
Requires-Dist: pytest
|
||||
Requires-Dist: ipython
|
||||
Requires-Dist: jupyter-book
|
||||
Requires-Dist: sphinx
|
||||
Requires-Dist: sphinx-nefertiti
|
||||
Requires-Dist: black==25.1.0
|
||||
Requires-Dist: isort==5.13.2
|
||||
Requires-Dist: clang-format==19.1.7
|
||||
Provides-Extra: dev
|
||||
Requires-Dist: pytest; extra == "dev"
|
||||
Requires-Dist: black==25.1.0; extra == "dev"
|
||||
Requires-Dist: isort==5.13.2; extra == "dev"
|
||||
Requires-Dist: clang-format==19.1.7; extra == "dev"
|
||||
Provides-Extra: docs
|
||||
Requires-Dist: sphinx; extra == "docs"
|
||||
Requires-Dist: sphinx-nefertiti; extra == "docs"
|
||||
Requires-Dist: jupyter-book; extra == "docs"
|
||||
Dynamic: license-file
|
||||
|
||||
<h1 align="center">
|
||||
<em>AReaL</em>: Ant Reasoning Reinforcement Learning for LLMs
|
||||
</h1>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://arxiv.org/pdf/2505.24298"><b>Paper</b></a> | <a href="https://inclusionai.github.io/AReaL/"><b>Documentation</b></a> | <a href="https://deepwiki.com/inclusionAI/AReaL"><b>Ask DeepWiki</b></a> | <a href="https://huggingface.co/collections/inclusionAI/areal-boba-2-683f0e819ccb7bb2e1b2f2d5"><b>🤗 Models & Data</b></a> |
|
||||
<a href="./assets/wechat_qrcode.png" target="_blank"><b>WeChat Group</b></a> |
|
||||
</p>
|
||||
|
||||
<img align="right" alt="ReaL" src="/assets/logo.png" width="20%">
|
||||
|
||||
AReaL (Ant Reasoning RL) is an open-source **fully asynchronous reinforcement learning training system** for large reasoning models developed at **the RL Lab, Ant Research**. Built upon the open-source project [RealHF](https://github.com/openpsi-project/ReaLHF), we are fully committed to open-source by providing training details, data, and infrastructure required to reproduce results along with the model itself. AReaL aims to help everyone build their own AI agents easily and affordably. Our team loves milk tea because it's delicious, customizable, and affordable. We hope you enjoy our project just like how you enjoy real-world milk tea (cheers).
|
||||
|
||||
**AReaL Highlights**
|
||||
|
||||
+ 🔥 <span style="color: red; font-weight: bold;">**[NEW] Asynchronous RL:**</span> With algorithm-system co-design, AReaL supports fully asynchronous RL for **the fastest training**! Experimental support for multi-turn agentic RL is also provided.
|
||||
+ 🛠️ **Open & Reproducible**: We continuously release _all code, datasets, and training recipes_ for RL training of LLMs.
|
||||
+ 🚀 **Scalability**: AReaL can seamlessly adapt to different computational resource settings, ranging from a single node to 1K GPUs.
|
||||
+ 🔪 **Cutting-Edge Performance:** AReaL can produce models with cutting-edge reasoning capabilities in math and coding. We are also actively working on agentic tasks.
|
||||
|
||||
## News
|
||||
|
||||
**[2025/06/03] (v0.3, boba²)** We release **boba²** (double-boba) for fully asynchronous RL training, which achieves a **2.77x speedup while obtaining on-par or even better training performance** compared to synchronous systems. Moreover, asynchronous RL makes it extremely easy to set up multi-turn agentic RL training! Check out [our v0.3 overview blog](/blog/AReaL_v0_3.md) and the [research paper](https://arxiv.org/pdf/2505.24298).
|
||||
|
||||
**[2025/03/31] (v0.2, boba)** Here comes our next milestone release - boba! Please call it A-ReaL-boba! This release includes much faster training with SGLang support and SOTA 7B and 32B models on math reasoning. Check our [v0.2 technical blog](/blog/AReaL_v0_2.md).
|
||||
|
||||
**[2025/02/24] (v0.1)** Our initial release includes reproducible results for 1.5B and 7B LRMs. Check our [v0.1 technical blog](/blog/AReaL_v0_1.md).
|
||||
|
||||
## Release Highlights
|
||||
|
||||
In our AReaL-boba² (A-ReaL-double-boba) release, we highlight the top 3 most important features:
|
||||
|
||||
+ A fully asynchronous RL training pipeline with **system and RL algorithm co-design**, achieving over 2.77x speedup without any performance drop. Check the [benchmark scripts and instructions here](https://github.com/inclusionAI/AReaL/tree/main/benchmark/verl_v0_3_0_post1_76084d3).
|
||||
|
||||
+ SOTA coding models, i.e., a 14B model with a **69.1 score on LCB-v5**. To reproduce, check the [configs](https://github.com/inclusionAI/AReaL/tree/main/examples/configs/v0.3-qwen3-code) and [instructions](https://inclusionai.github.io/AReaL/references/reproduce.html).
|
||||
|
||||
+ Experimental support for **multi-turn** agentic RL training. Check our [complete example](https://inclusionai.github.io/AReaL/customization/agent.html).
|
||||
|
||||
For the complete system design and more training details, please check [our v0.3 blog](/blog/AReaL_v0_3.md) and our [research paper](https://arxiv.org/pdf/2505.24298).
|
||||
|
||||
**Jump to the [quickstart section](https://github.com/inclusionAI/AReaL?tab=readme-ov-file#getting-started) if you want to quickly run an experiment and get your hands dirty!** 😈
|
||||
|
||||
### Overview of Asynchronous RL Training
|
||||
|
||||
During the synchronous RL training process, a generation step must wait until the longest sequence completes within the batch of LLM outputs. Due to the varying output lengths for LRMs, a synchronous RL system suffers from massive GPU idle time, leading to training inefficiency. Some recent works ([DeepCoder](https://pretty-radio-b75.notion.site/DeepCoder-A-Fully-Open-Source-14B-Coder-at-O3-mini-Level-1cf81902c14680b3bee5eb349a512a51), [Intellect](https://www.primeintellect.ai/blog/intellect-2)) propose overlapping a single training step with a single generation step to accelerate training. However, the largest bottleneck remains unchanged: the samples within a batch are still from the same model version, leading to waiting and GPU idle time.
|
||||
|
||||

|
||||
|
||||
*Fig.1. Left: Execution timeline of synchronous RL training. Right: Execution timeline of one-step overlap RL system.*
|
||||
|
||||
AReaL adopts a fully asynchronous RL training framework that completely decouples generation from training. In AReaL, LLM generation runs in a streaming manner, with each rollout worker continuously producing outputs without waiting. Meanwhile, trainer workers perform parallel model updates upon receiving training batches.
|
||||
|
||||

|
||||
|
||||
*Fig 2. Execution timeline of our fully asynchronous RL system.*
|
||||
|
||||
AReaL follows a system-algorithm co-design principle: on the system side, AReaL efficiently syncs model parameters and carefully controls the staleness of each training sample; on the algorithm side, AReaL improves the objective of PPO to make async-RL stable.
|
||||
|
||||
We compare the scalability of **asynchronous RL** training based on our AReaL-boba² system with **classical synchronous RL** training (we adopt the fastest open-source system veRL, main branch on 05/07/2025) across different model sizes and different numbers of H800 GPUs. AReaL demonstrates much improved scaling capabilities with respect to training throughput. This is also partially due to AReaL decoupling training and generation, leading to much fewer GPU memory fragments.
|
||||
|
||||

|
||||
|
||||
*Fig.3 The scaling trend of asynchronous RL (based on AReaL-boba2) and classical synchronous RL (based on veRL) with different model sizes. Dotted lines indicate ideal linear scaling.*
|
||||
|
||||
### SOTA Code Generation Model by AReaL-boba²
|
||||
|
||||
We use **Qwen3** as our base model. After asynchronous RL training, we achieve SOTA results on LiveCodeBench, Codeforces, and CodeContests benchmarks.
|
||||
|
||||
| **Model (8B)** | **LiveCodeBench v5**<br/>**(2024.10-2025.2)** | **Codeforces** | **CodeContests** |
|
||||
| :---: | :---: | :---: | :---: |
|
||||
| Qwen3-8B | 58.8 | 1879/96.7% | 31.4 |
|
||||
| DeepSeek-R1-0528-Qwen3-8B | 58.4 | 1945/97.3% | 31.0 |
|
||||
| [🤗 AReaL-boba²-8B-Open](https://huggingface.co/inclusionAI/AReaL-boba-2-8B-subset) | 62.0 | 1933/97.2% | **41.4** |
|
||||
| [🤗 AReaL-boba²-8B](https://huggingface.co/inclusionAI/AReaL-boba-2-8B) | **63.0** | **1962/97.5%** | 40.8 |
|
||||
|
||||
| **Model (14B)** | **LiveCodeBench v5**<br/>**(2024.10-2025.2)** | **Codeforces** | **CodeContests** |
|
||||
| :---: | :---: | :---: | :---: |
|
||||
| Qwen3-14B | 65.4 | 1978/97.7% | 38.3 |
|
||||
| DeepCoder-14B-Preview | 60.6 | 1936/95.3% | 40.1 |
|
||||
| [🤗 AReaL-boba²-14B-Open](https://huggingface.co/inclusionAI/AReaL-boba-2-14B-subset) | 67.3 | 1990/97.8% | **46.2** |
|
||||
| [🤗 AReal-boba²-14B](https://huggingface.co/inclusionAI/AReaL-boba-2-14B) | **69.1** | **2044/98.2%** | 46.1 |
|
||||
|
||||
| **Larger Models** | **LiveCodeBench v5**<br/>**(2024.10-2025.2)** | **Codeforces** | **CodeContests** |
|
||||
| :---: | :---: | :---: | :---: |
|
||||
| Qwen3-235B | 70.7 | 2056 | - |
|
||||
| DeepSeek-R1 | 64.3 | 2029 | - |
|
||||
| OpenAI-o3-mini (Medium) | 66.3 | 2036 | - |
|
||||
|
||||
*Table 1: Coding Task Performance Comparison. AReaL-boba²-8B/14B-Open denotes training results on open-source data. AReaL-boba²-8B/14B models are trained with an additional small amount of internal data and achieve SOTA performance on LiveCodeBench, Codeforces & CodeContests.*
|
||||
|
||||
We highlight the [tutorials](https://inclusionai.github.io/AReaL/customization/dataset.html) and [code walkthroughs](https://inclusionai.github.io/AReaL/developer/overview.html) about the following key features for asynchronous training:
|
||||
|
||||
+ [Streaming generation and reward computation](https://inclusionai.github.io/AReaL/developer/rollout/rollout_worker.html)
|
||||
+ [Interruptible rollout](https://inclusionai.github.io/AReaL/developer/rollout/gserver.html)
|
||||
+ [Data staleness control with the rollout controller](https://inclusionai.github.io/AReaL/developer/rollout/gserver.html)
|
||||
+ [The adoption of decoupled PPO loss](https://inclusionai.github.io/AReaL/customization/algorithm.html)
|
||||
|
||||
### RL Training for Multi-turn Agent
|
||||
|
||||
AReaL-boba² allows you to independently customize the [dataset](https://inclusionai.github.io/AReaL/customization/dataset.html), [rollout behavior](https://inclusionai.github.io/AReaL/customization/agent.html), and the [training algorithm](https://inclusionai.github.io/AReaL/customization/algorithm.html), without needing to modify the heavy system-level code.
|
||||
|
||||
In particular, we show a simple example to develop a multi-turn math agent for RL training. Please see the learning curve below and reference the [step-by-step guide](https://inclusionai.github.io/AReaL/customization/agent.html) if you want to implement your own agentic RL project.
|
||||
|
||||
## Getting Started
|
||||
Obtain the training data:
|
||||
- [Math](https://huggingface.co/datasets/inclusionAI/AReaL-boba-Data)
|
||||
- [Code](https://huggingface.co/datasets/inclusionAI/AReaL-boba-2-RL-Code)
|
||||
|
||||
For code training data, a simple preprocessing script was provided in `examples/data_preprocess/preprocess_training_data.py`:
|
||||
```bash
|
||||
python3 preprocess_training_data.py --data_path $original_data_path --output_path $training_data_path
|
||||
```
|
||||
|
||||
Train Qwen3 1.7B locally (Remember to modify `dataset.path` in the script below):
|
||||
```bash
|
||||
bash examples/run_async_ppo.sh
|
||||
```
|
||||
|
||||
Evaluation:
|
||||
|
||||
```bash
|
||||
cd evaluation
|
||||
# Evaluate the model
|
||||
python eval_and_aggregate.py \
|
||||
--model_path ${MODEL_PATH} \
|
||||
--output_path ${OUTPUT_PATH} \
|
||||
--data_names aime24,aime25 \
|
||||
--max_gen_tokens 32768 \
|
||||
--data_names codeforces,lcb_v5 \
|
||||
--prompt_type qwen3-think-pure \
|
||||
--temperature 1.0
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
+ [Documentation](https://inclusionai.github.io/AReaL/)
|
||||
+ [Contributing](https://inclusionai.github.io/AReaL/contrib.html)
|
||||
|
||||
### Quickstart
|
||||
|
||||
+ [Installation](https://inclusionai.github.io/AReaL/tutorial/installation.html)
|
||||
+ [Example: Improving the math capability of Qwen3 with PPO](https://inclusionai.github.io/AReaL/tutorial/quickstart.html)
|
||||
|
||||
### Benchmark and Reproduction
|
||||
|
||||
+ **Reproduce boba² Code Models**
|
||||
- 🤗 **Model weights**: [8B-code](https://huggingface.co/inclusionAI/AReaL-boba-2-8B), [14B-code](https://huggingface.co/inclusionAI/AReaL-boba-2-14B), [8B-code-open](https://huggingface.co/inclusionAI/AReaL-boba-2-8B-subset), [14B-code-open](https://huggingface.co/inclusionAI/AReaL-boba-2-14B-subset)
|
||||
- [Evaluation Guide](https://inclusionai.github.io/AReaL/tutorial/eval.html)
|
||||
- [Training configs](https://github.com/inclusionAI/AReaL/tree/main/examples/configs/v0.3-qwen3-code) and [instructions](https://inclusionai.github.io/AReaL/references/reproduce.html)
|
||||
+ [Scripts for Benchmark Training Throughput](https://github.com/inclusionAI/AReaL/tree/main/benchmark/verl_v0_3_0_post1_76084d3)
|
||||
|
||||
### Customization Guide
|
||||
|
||||
- [Use your own dataset](https://inclusionai.github.io/AReaL/customization/dataset.html)
|
||||
- [Modifying the reward function and rollout behavior (multi-turn agentic RL)](https://inclusionai.github.io/AReaL/customization/agent.html)
|
||||
- [Modifying PPO to GRPO](https://inclusionai.github.io/AReaL/customization/algorithm.html#grouped-advantage-normalization)
|
||||
- [Developing the decoupled PPO loss](https://inclusionai.github.io/AReaL/customization/algorithm.html#the-decoupled-ppo-loss)
|
||||
|
||||
### System Code Walkthrough
|
||||
|
||||
+ [Trainer](https://inclusionai.github.io/AReaL/developer/trainer/model_worker.html)
|
||||
+ [Model Backend and Algorithm Interface](https://inclusionai.github.io/AReaL/developer/trainer/algo_interface.html)
|
||||
+ [Rollout Controller](https://inclusionai.github.io/AReaL/developer/rollout/gserver.html)
|
||||
+ [Streaming generation and reward computation](https://inclusionai.github.io/AReaL/developer/rollout/rollout_worker.html)
|
||||
|
||||
## Future Plan
|
||||
|
||||
AReaL is under active development. We plan to have minor releases weekly and major releases monthly. Community engagement and contributions are extremely welcome. We are also **hiring interns and full-time employees** with open positions in both the US and China.
|
||||
|
||||
For the research and development plan already in place, please see the following list:
|
||||
|
||||
### System Development
|
||||
|
||||
- [x] Support for SGLang
|
||||
- [x] RL training with coding problems
|
||||
- [x] Asynchronous generation and RL training
|
||||
- [ ] Optimizations for distributed training: expert parallel for MOE and zero-bubble pipelining
|
||||
- [ ] RL for vision-language models (VLM)
|
||||
- [x] Multi-turn agentic RL
|
||||
- [ ] Function calling and tool use
|
||||
|
||||
### Algorithm Development
|
||||
|
||||
- [x] RL training recipes for 1.5B and 7B models
|
||||
- [x] A complete RL training recipe for 32B models
|
||||
- [ ] Sample-efficient multi-task RL algorithms
|
||||
- [ ] Agentic capabilities with end-to-end RL
|
||||
- [ ] Stable RL training for larger MOE models
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
We would like to note that major contributors are from the RL Lab at Ant Research and the Institute for Interdisciplinary Information Sciences, Tsinghua University.
|
||||
|
||||
Our team has also received invaluable assistance from the Data Intelligence Lab at Ant Research for data support and from the Super Computing Technology (SCT) team at Ant Group, particularly in the realm of large-scale cluster operations and maintenance.
|
||||
|
||||
We also appreciate all the pioneering works from the community, particularly the [ReaLHF](https://github.com/openpsi-project/ReaLHF) project from OpenPsi Inc. and other projects, including but not limited to [DeepScaleR](https://github.com/agentica-project/deepscaler), [Open-Reasoner-Zero](https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero/tree/main), [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF), [VeRL](https://github.com/volcengine/verl), [SGLang](https://github.com/sgl-project/sglang), [QwQ](https://github.com/QwenLM/QwQ), [Light-R1](https://github.com/Qihoo360/Light-R1) and [DAPO](https://github.com/BytedTsinghua-SIA/DAPO).
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{mei2025real,
|
||||
author = {Mei, Zhiyu and Fu, Wei and Li, Kaiwei and Wang, Guangju and Zhang, Huanchen and Wu, Yi},
|
||||
title = {ReaL: Efficient RLHF Training of Large Language Models with Parameter Reallocation},
|
||||
booktitle = {Proceedings of the Eighth Conference on Machine Learning and Systems,
|
||||
MLSys 2025, Santa Clara, CA, USA, May 12-15, 2025},
|
||||
publisher = {mlsys.org},
|
||||
year = {2025},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{fu2025areal,
|
||||
title={AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning},
|
||||
author={Wei Fu and Jiaxuan Gao and Xujie Shen and Chen Zhu and Zhiyu Mei and Chuyi He and Shusheng Xu and Guo Wei and Jun Mei and Jiashu Wang and Tongkai Yang and Binhang Yuan and Yi Wu},
|
||||
year={2025},
|
||||
eprint={2505.24298},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.LG},
|
||||
url={https://arxiv.org/abs/2505.24298},
|
||||
}
|
||||
```
|
|
@ -1,349 +0,0 @@
|
|||
LICENSE
|
||||
MANIFEST.in
|
||||
README.md
|
||||
pyproject.toml
|
||||
setup.py
|
||||
arealite/ppo_functional.py
|
||||
arealite/api/cli_args.py
|
||||
arealite/api/engine_api.py
|
||||
arealite/api/env_api.py
|
||||
arealite/api/io_struct.py
|
||||
arealite/api/llm_client_api.py
|
||||
arealite/api/llm_server_api.py
|
||||
arealite/api/reward_api.py
|
||||
arealite/api/rollout_api.py
|
||||
arealite/api/trainer_api.py
|
||||
arealite/api/vlm_client_api.py
|
||||
arealite/api/workflow_api.py
|
||||
arealite/cli/launch_server.py
|
||||
arealite/cli/launch_trainer.py
|
||||
arealite/dataset/__init__.py
|
||||
arealite/engine/__init__.py
|
||||
arealite/engine/constant.py
|
||||
arealite/engine/fsdp_engine.py
|
||||
arealite/engine/hf_engine.py
|
||||
arealite/engine/sglang_engine.py
|
||||
arealite/engine/sglang_remote.py
|
||||
arealite/engine/vl_fsdp_engine.py
|
||||
arealite/engine/sft/lm_engine.py
|
||||
arealite/env/__init__.py
|
||||
arealite/impl/engine/fsdp_wrapper.py
|
||||
arealite/impl/engine/hf_wrapper.py
|
||||
arealite/impl/rlvr/rlvr_collector.py
|
||||
arealite/impl/rlvr/vl_rlvr_collector.py
|
||||
arealite/impl/rlvr/rewards/areal_code.py
|
||||
arealite/impl/rlvr/rewards/areal_math.py
|
||||
arealite/impl/rlvr/rewards/clevr_count_70k.py
|
||||
arealite/impl/rlvr/rewards/gsm8k.py
|
||||
arealite/impl/trainer/grpo.py
|
||||
arealite/impl/trainer/sft.py
|
||||
arealite/impl/trainer/vl_grpo.py
|
||||
arealite/impl/trainer/vl_sft.py
|
||||
arealite/launcher/with_ray.py
|
||||
arealite/launcher/with_scheduler.py
|
||||
arealite/reward/__init__.py
|
||||
arealite/system/rollout_controller.py
|
||||
arealite/system/rollout_worker.py
|
||||
arealite/system/sglang_client.py
|
||||
arealite/system/sglang_server.py
|
||||
arealite/system/vl_sglang_client.py
|
||||
arealite/tests/test_engine.py
|
||||
arealite/tests/test_fsdp_engine.py
|
||||
arealite/tests/test_grpo.py
|
||||
arealite/tests/test_rlvr_workflow.py
|
||||
arealite/tests/test_rollout.py
|
||||
arealite/tests/test_rollout_controller.py
|
||||
arealite/tests/test_sft.py
|
||||
arealite/tests/test_sglang_client.py
|
||||
arealite/tests/test_sglang_engine.py
|
||||
arealite/tests/test_utils.py
|
||||
arealite/tests/test_vlm_grpo.py
|
||||
arealite/tests/test_vlm_sft.py
|
||||
arealite/tests/test_wrapper.py
|
||||
arealite/tests/utils.py
|
||||
arealite/utils/__init__.py
|
||||
arealite/utils/data.py
|
||||
arealite/utils/evaluator.py
|
||||
arealite/utils/fs.py
|
||||
arealite/utils/fsdp.py
|
||||
arealite/utils/functional.py
|
||||
arealite/utils/padding.py
|
||||
arealite/utils/save_load.py
|
||||
arealite/utils/saver.py
|
||||
arealite/utils/stats_logger.py
|
||||
arealite/utils/wrapper.py
|
||||
arealite/workflow/rlvr.py
|
||||
benchmark/verl_v0_3_0_post1_76084d3/build_cmd.py
|
||||
csrc/cugae/gae.cu
|
||||
csrc/interval_op/interval_op.cpp
|
||||
csrc/interval_op/interval_op.cu
|
||||
evaluation/aggregate_acc_from_generated.py
|
||||
evaluation/cf_elo_caculator.py
|
||||
evaluation/code_eval.py
|
||||
evaluation/data_loader.py
|
||||
evaluation/eval_and_aggregate.py
|
||||
evaluation/evaluate.py
|
||||
evaluation/examples.py
|
||||
evaluation/grader.py
|
||||
evaluation/math_eval.py
|
||||
evaluation/math_utils.py
|
||||
evaluation/model_utils.py
|
||||
evaluation/parser.py
|
||||
evaluation/python_executor.py
|
||||
evaluation/rm_maj_eval.py
|
||||
evaluation/trajectory.py
|
||||
evaluation/utils.py
|
||||
evaluation/code_verifier/local_verify.py
|
||||
evaluation/code_verifier/testing_util.py
|
||||
evaluation/latex2sympy/__init__.py
|
||||
evaluation/latex2sympy/asciimath_printer.py
|
||||
evaluation/latex2sympy/latex2sympy2.py
|
||||
evaluation/latex2sympy/setup.py
|
||||
evaluation/latex2sympy/gen/PSLexer.py
|
||||
evaluation/latex2sympy/gen/PSListener.py
|
||||
evaluation/latex2sympy/gen/PSParser.py
|
||||
evaluation/latex2sympy/gen/__init__.py
|
||||
evaluation/latex2sympy/sandbox/linalg_equations.py
|
||||
evaluation/latex2sympy/sandbox/linalg_span.py
|
||||
evaluation/latex2sympy/sandbox/matrix.py
|
||||
evaluation/latex2sympy/sandbox/matrix_placeholders.py
|
||||
evaluation/latex2sympy/sandbox/sandbox.py
|
||||
evaluation/latex2sympy/sandbox/sandbox_equality.py
|
||||
evaluation/latex2sympy/sandbox/sectan.py
|
||||
evaluation/latex2sympy/sandbox/vector.py
|
||||
evaluation/latex2sympy/tests/__init__.py
|
||||
evaluation/latex2sympy/tests/abs_test.py
|
||||
evaluation/latex2sympy/tests/all_bad_test.py
|
||||
evaluation/latex2sympy/tests/all_good_test.py
|
||||
evaluation/latex2sympy/tests/atom_expr_test.py
|
||||
evaluation/latex2sympy/tests/binomial_test.py
|
||||
evaluation/latex2sympy/tests/ceil_test.py
|
||||
evaluation/latex2sympy/tests/complex_test.py
|
||||
evaluation/latex2sympy/tests/context.py
|
||||
evaluation/latex2sympy/tests/exp_test.py
|
||||
evaluation/latex2sympy/tests/floor_test.py
|
||||
evaluation/latex2sympy/tests/gcd_test.py
|
||||
evaluation/latex2sympy/tests/greek_test.py
|
||||
evaluation/latex2sympy/tests/grouping_test.py
|
||||
evaluation/latex2sympy/tests/lcm_test.py
|
||||
evaluation/latex2sympy/tests/left_right_cdot_test.py
|
||||
evaluation/latex2sympy/tests/linalg_test.py
|
||||
evaluation/latex2sympy/tests/max_test.py
|
||||
evaluation/latex2sympy/tests/min_test.py
|
||||
evaluation/latex2sympy/tests/mod_test.py
|
||||
evaluation/latex2sympy/tests/overline_test.py
|
||||
evaluation/latex2sympy/tests/pi_test.py
|
||||
evaluation/latex2sympy/tests/trig_test.py
|
||||
evaluation/latex2sympy/tests/variable_test.py
|
||||
examples/arealite/gsm8k_sft.py
|
||||
examples/arealite/dataset/clevr_count_70k.py
|
||||
examples/arealite/dataset/gsm8k.py
|
||||
examples/data_preprocess/codeforce_process.py
|
||||
examples/data_preprocess/math_code_process.py
|
||||
examples/data_preprocess/math_process.py
|
||||
examples/data_preprocess/preprocess_training_data.py
|
||||
examples/env/setup_env_and_start_train.py
|
||||
examples/env/validate_installation.py
|
||||
functioncall/__init__.py
|
||||
functioncall/base/__init__.py
|
||||
functioncall/base/call.py
|
||||
functioncall/base/utils.py
|
||||
functioncall/code/__init__.py
|
||||
functioncall/code/local_verify.py
|
||||
functioncall/code/verify.py
|
||||
functioncall/code/function/handler.py
|
||||
functioncall/code/function/testing_util.py
|
||||
functioncall/math/__init__.py
|
||||
functioncall/math/verify.py
|
||||
functioncall/math/function/grader.py
|
||||
functioncall/math/function/handler.py
|
||||
functioncall/math/function/parser.py
|
||||
functioncall/test/performance_eval.py
|
||||
realhf/__init__.py
|
||||
realhf/utils.py
|
||||
realhf/version.py
|
||||
realhf.egg-info/PKG-INFO
|
||||
realhf.egg-info/SOURCES.txt
|
||||
realhf.egg-info/dependency_links.txt
|
||||
realhf.egg-info/requires.txt
|
||||
realhf.egg-info/top_level.txt
|
||||
realhf/api/cli_args.py
|
||||
realhf/api/core/agent_api.py
|
||||
realhf/api/core/config.py
|
||||
realhf/api/core/data_api.py
|
||||
realhf/api/core/dfg.py
|
||||
realhf/api/core/env_api.py
|
||||
realhf/api/core/model_api.py
|
||||
realhf/api/core/system_api.py
|
||||
realhf/api/from_hf/__init__.py
|
||||
realhf/api/from_hf/gemma.py
|
||||
realhf/api/from_hf/gpt2.py
|
||||
realhf/api/from_hf/llama.py
|
||||
realhf/api/from_hf/mistral.py
|
||||
realhf/api/from_hf/mixtral.py
|
||||
realhf/api/from_hf/qwen2.py
|
||||
realhf/api/from_hf/qwen3.py
|
||||
realhf/api/quickstart/__init__.py
|
||||
realhf/api/quickstart/device_mesh.py
|
||||
realhf/api/quickstart/entrypoint.py
|
||||
realhf/api/quickstart/search.py
|
||||
realhf/apps/__init__.py
|
||||
realhf/apps/main.py
|
||||
realhf/apps/quickstart.py
|
||||
realhf/apps/remote.py
|
||||
realhf/base/__init__.py
|
||||
realhf/base/cluster.py
|
||||
realhf/base/constants.py
|
||||
realhf/base/datapack.py
|
||||
realhf/base/gpu_utils.py
|
||||
realhf/base/importing.py
|
||||
realhf/base/logging.py
|
||||
realhf/base/monitor.py
|
||||
realhf/base/name_resolve.py
|
||||
realhf/base/names.py
|
||||
realhf/base/network.py
|
||||
realhf/base/numpy_utils.py
|
||||
realhf/base/pkg_version.py
|
||||
realhf/base/prologue.py
|
||||
realhf/base/ray_utils.py
|
||||
realhf/base/recover.py
|
||||
realhf/base/saveload_utils.py
|
||||
realhf/base/security.py
|
||||
realhf/base/seeding.py
|
||||
realhf/base/slurm_utils.py
|
||||
realhf/base/stats_tracker.py
|
||||
realhf/base/testing.py
|
||||
realhf/base/timeutil.py
|
||||
realhf/base/topology.py
|
||||
realhf/experiments/async_exp/async_ppo_math_exp.py
|
||||
realhf/experiments/async_exp/async_rl_exp.py
|
||||
realhf/experiments/common/check.py
|
||||
realhf/experiments/common/common.py
|
||||
realhf/experiments/common/math_code_eval_exp.py
|
||||
realhf/experiments/common/null_exp.py
|
||||
realhf/experiments/common/ppo_math_exp.py
|
||||
realhf/experiments/common/sft_exp.py
|
||||
realhf/experiments/common/utils.py
|
||||
realhf/impl/agent/__init__.py
|
||||
realhf/impl/agent/math_multi_turn_agent.py
|
||||
realhf/impl/agent/math_single_step_agent.py
|
||||
realhf/impl/agent/null_agent.py
|
||||
realhf/impl/dataset/__init__.py
|
||||
realhf/impl/dataset/math_code_dataset.py
|
||||
realhf/impl/dataset/math_parser.py
|
||||
realhf/impl/dataset/prompt_answer_dataset.py
|
||||
realhf/impl/dataset/prompt_dataset.py
|
||||
realhf/impl/dataset/rw_paired_dataset.py
|
||||
realhf/impl/environment/__init__.py
|
||||
realhf/impl/environment/math_code_single_step_env.py
|
||||
realhf/impl/model/__init__.py
|
||||
realhf/impl/model/backend/inference.py
|
||||
realhf/impl/model/backend/megatron.py
|
||||
realhf/impl/model/backend/mock_train.py
|
||||
realhf/impl/model/backend/pipe_runner.py
|
||||
realhf/impl/model/backend/sglang.py
|
||||
realhf/impl/model/backend/vllm.py
|
||||
realhf/impl/model/backend/thirdparty/megatron/__init__.py
|
||||
realhf/impl/model/backend/thirdparty/megatron/v0_6_0/lr_schduler.py
|
||||
realhf/impl/model/backend/thirdparty/vllm/__init__.py
|
||||
realhf/impl/model/backend/thirdparty/vllm/context.py
|
||||
realhf/impl/model/backend/thirdparty/vllm/custom_cache_manager.py
|
||||
realhf/impl/model/backend/thirdparty/vllm/engine.py
|
||||
realhf/impl/model/backend/thirdparty/vllm/executor.py
|
||||
realhf/impl/model/comm/global_comm.py
|
||||
realhf/impl/model/comm/param_realloc.py
|
||||
realhf/impl/model/conversion/hf_registry.py
|
||||
realhf/impl/model/interface/fused_interface.py
|
||||
realhf/impl/model/interface/math_rw_interface.py
|
||||
realhf/impl/model/interface/ppo_interface.py
|
||||
realhf/impl/model/interface/sft_interface.py
|
||||
realhf/impl/model/modules/__init__.py
|
||||
realhf/impl/model/modules/activations.py
|
||||
realhf/impl/model/modules/attn.py
|
||||
realhf/impl/model/modules/embedding.py
|
||||
realhf/impl/model/modules/mlp.py
|
||||
realhf/impl/model/modules/rms.py
|
||||
realhf/impl/model/modules/rotary.py
|
||||
realhf/impl/model/modules/moe/__init__.py
|
||||
realhf/impl/model/modules/moe/experts.py
|
||||
realhf/impl/model/modules/moe/layer.py
|
||||
realhf/impl/model/modules/moe/router.py
|
||||
realhf/impl/model/modules/moe/token_dispatcher.py
|
||||
realhf/impl/model/nn/flatten_param.py
|
||||
realhf/impl/model/nn/real_llm_api.py
|
||||
realhf/impl/model/nn/real_llm_base.py
|
||||
realhf/impl/model/nn/real_llm_generate.py
|
||||
realhf/impl/model/nn/real_llm_parallel.py
|
||||
realhf/impl/model/parallelism/pipeline_parallel/instruction.py
|
||||
realhf/impl/model/parallelism/pipeline_parallel/p2p.py
|
||||
realhf/impl/model/parallelism/pipeline_parallel/static_schedule.py
|
||||
realhf/impl/model/parallelism/pipeline_parallel/tensor_storage.py
|
||||
realhf/impl/model/parallelism/tensor_parallel/mappings.py
|
||||
realhf/impl/model/parallelism/tensor_parallel/modules.py
|
||||
realhf/impl/model/parallelism/tensor_parallel/utils.py
|
||||
realhf/impl/model/utils/cuda_graph.py
|
||||
realhf/impl/model/utils/dpo_functional.py
|
||||
realhf/impl/model/utils/functional.py
|
||||
realhf/impl/model/utils/logits_warper.py
|
||||
realhf/impl/model/utils/moe.py
|
||||
realhf/impl/model/utils/padding.py
|
||||
realhf/impl/model/utils/ppo_functional.py
|
||||
realhf/impl/model/utils/random.py
|
||||
realhf/scheduler/client.py
|
||||
realhf/scheduler/evaluator.py
|
||||
realhf/scheduler/local/client.py
|
||||
realhf/scheduler/slurm/client.py
|
||||
realhf/scheduler/slurm/utils.py
|
||||
realhf/system/__init__.py
|
||||
realhf/system/buffer.py
|
||||
realhf/system/controller.py
|
||||
realhf/system/data_manager.py
|
||||
realhf/system/flops_counter.py
|
||||
realhf/system/function_executor.py
|
||||
realhf/system/generation_server.py
|
||||
realhf/system/gserver_manager.py
|
||||
realhf/system/master_worker.py
|
||||
realhf/system/model_function_call.py
|
||||
realhf/system/model_worker.py
|
||||
realhf/system/partial_rollout.py
|
||||
realhf/system/push_pull_stream.py
|
||||
realhf/system/redistributor.py
|
||||
realhf/system/request_reply_stream.py
|
||||
realhf/system/rollout_worker.py
|
||||
realhf/system/stream_dataset.py
|
||||
realhf/system/worker_base.py
|
||||
realhf/system/worker_control.py
|
||||
tests/__init__.py
|
||||
tests/fixtures.py
|
||||
tests/agent/test_math_single_step_agent.py
|
||||
tests/comm/test_data_transfer.py
|
||||
tests/comm/test_param_realloc.py
|
||||
tests/cpp_extensions/test_cugae.py
|
||||
tests/cpp_extensions/test_grouped_gemm.py
|
||||
tests/cpp_extensions/test_interval_ops.py
|
||||
tests/data/test_dfg.py
|
||||
tests/data/test_dual_clip.py
|
||||
tests/data/test_epoch_counter.py
|
||||
tests/data/test_load_data.py
|
||||
tests/data/test_sequence_gather_split.py
|
||||
tests/data/test_stats_tracker.py
|
||||
tests/distributed/test_find_port.py
|
||||
tests/distributed/test_name_resolve.py
|
||||
tests/experiments/test_buffer_recover.py
|
||||
tests/experiments/test_math_ppo.py
|
||||
tests/experiments/test_sft.py
|
||||
tests/experiments/utils.py
|
||||
tests/interfaces/test_multi_task_reward.py
|
||||
tests/legacy/test_sglang_tp.py
|
||||
tests/legacy/test_vllm_tp.py
|
||||
tests/model/test_cpu_inference.py
|
||||
tests/model/test_distributed_load_hf.py
|
||||
tests/reward/test_math_reward.py
|
||||
tests/system/test_gserver_manager.py
|
||||
tests/system/test_partial_rollout.py
|
||||
tests/system/test_push_pull_stream.py
|
||||
tests/system/test_stream_dataset.py
|
||||
training/main_async_ppo.py
|
||||
training/main_sft.py
|
||||
training/main_sync_ppo.py
|
||||
training/utils.py
|
|
@ -1 +0,0 @@
|
|||
|
|
@ -1,83 +0,0 @@
|
|||
torch>2.0.0
|
||||
huggingface_hub
|
||||
datasets
|
||||
accelerate
|
||||
transformers==4.51.1
|
||||
numpy<2.0.0
|
||||
scipy
|
||||
pandas
|
||||
matplotlib
|
||||
seaborn
|
||||
h5py
|
||||
nltk
|
||||
sentencepiece
|
||||
einops
|
||||
tqdm
|
||||
rich
|
||||
orjson>=3.10.16
|
||||
pydantic
|
||||
PyYAML
|
||||
hydra-core==1.4.0.dev1
|
||||
packaging
|
||||
tabulate
|
||||
gymnasium>=1.1.1
|
||||
torchdata
|
||||
autoflake
|
||||
tensordict
|
||||
wandb
|
||||
tensorboardx
|
||||
colorama
|
||||
colorlog
|
||||
psutil
|
||||
pynvml
|
||||
swanlab[dashboard]
|
||||
ninja
|
||||
numba
|
||||
blosc
|
||||
pybind11>=2.10.0
|
||||
networkx==3.3
|
||||
aiofiles
|
||||
aiohttp>=3.11.10
|
||||
httpx>=0.28.1
|
||||
pyzmq
|
||||
paramiko
|
||||
etcd3
|
||||
protobuf<3.21
|
||||
ray
|
||||
redis
|
||||
fastapi>=0.115.12
|
||||
uvicorn>=0.34.2
|
||||
uvloop>=0.21.0
|
||||
flask
|
||||
build>=1.2.1
|
||||
wheel>=0.43.0
|
||||
setuptools<75.9,>=62.3.0
|
||||
cookiecutter>2.1.1
|
||||
distro-info>=1.0
|
||||
python-debian>=0.1.49
|
||||
func_timeout
|
||||
regex
|
||||
python_dateutil
|
||||
word2number
|
||||
Pebble
|
||||
timeout-decorator
|
||||
prettytable
|
||||
pytest
|
||||
ipython
|
||||
jupyter-book
|
||||
sphinx
|
||||
sphinx-nefertiti
|
||||
black==25.1.0
|
||||
isort==5.13.2
|
||||
clang-format==19.1.7
|
||||
|
||||
[dev]
|
||||
pytest
|
||||
black==25.1.0
|
||||
isort==5.13.2
|
||||
clang-format==19.1.7
|
||||
|
||||
[docs]
|
||||
sphinx
|
||||
sphinx-nefertiti
|
||||
jupyter-book
|
|
@ -1,14 +0,0 @@
|
|||
arealite
|
||||
assets
|
||||
benchmark
|
||||
blog
|
||||
ci
|
||||
csrc
|
||||
docs
|
||||
evaluation
|
||||
examples
|
||||
functioncall
|
||||
patch
|
||||
realhf
|
||||
tests
|
||||
training
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue