checkout previous impl

This commit is contained in:
博惟 2025-07-07 15:51:32 +08:00
parent 6710d5f275
commit 3a0f1e558c
15 changed files with 2993 additions and 0 deletions

View File

@@ -0,0 +1,232 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import os
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
from typing import Any, List, Optional, Tuple
import torch
from arealite.api.cli_args import GenerationHyperparameters, TrainingArgs
from arealite.api.io_struct import (
AgentInferInput,
AgentInferOutput,
LLMRequest,
Trajectory,
TrajStats,
)
from arealite.api.llm_client_api import LLMClient
from arealite.api.rollout_api import Agent, Environment, RolloutCollector
from arealite.utils import pad_sequences_to_tensors
from functioncall.code.local_verify import code_verify as local_code_verify
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
ENABLE_FUNCTION_CALL = bool(os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", ""))
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify
@lru_cache(maxsize=128)
def _load_metadata_cached(dataset_path: str):
"""Cached version of load_metadata to avoid reloading metadata each time."""
return load_metadata(dataset_path)
def extract_code(text, min_length=20):
"""Extract code blocks from text."""
code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
valid_blocks.append(clean_block)
if not valid_blocks:
return None
# return the last code block
return valid_blocks[-1]
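A minimal usage sketch of extract_code on a hypothetical completion string: the first fenced block below is shorter than min_length=20 and is skipped, so the body of the last valid block is returned.

sample_completion = (
    "Reasoning...\n"
    "```python\nprint('hi')\n```\n"
    "Final solution:\n"
    "```python\ndef add(a, b):\n    return a + b\n\nprint(add(1, 2))\n```"
)
# Prints the body of the second block: "def add(a, b): ..."
print(extract_code(sample_completion))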
@dataclass
class MathCodeAction:
query_id: str
answer: str
@dataclass
class MathCodeObs:
query_id: str
prompt_ids: List[int]
class MathCodeSingleStepEnv(Environment):
"""Math and Code single-step verification environment."""
def __init__(self, args: TrainingArgs, solution_path: str):
super().__init__(args)
self.id2info, _ = _load_metadata_cached(solution_path)
def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[Any, dict]:
"""Reset the environment."""
super().reset(seed=seed)
try:
prompt_ids = options["input_ids"]
query_id = options["query_id"]
except KeyError:
raise RuntimeError("`input_ids` and `query_id` must be set in env options.")
# Return dummy observation and info
return MathCodeObs(query_id=query_id, prompt_ids=prompt_ids), {}
def step(
self, action: MathCodeAction
) -> Tuple[MathCodeObs, float, bool, bool, dict]:
"""Execute one step in the environment."""
query_id = action.query_id
answer = action.answer
query_id = query_id.split("@")[0]
cur_task = self.id2info[query_id]["task"]
if cur_task == "math":
# Run math verification
format_reward = math_verify_call(self.id2info, [answer], [query_id])[0]
elif cur_task == "code":
# Extract code blocks and run code verification
extracted_answer = extract_code(answer)
format_reward = code_verify_call(
self.id2info, [extracted_answer], [query_id]
)[0]
else:
raise NotImplementedError(f"Task type '{cur_task}' not implemented")
# Return: observation, reward, terminated, truncated, info
terminated = True # Single step environment always terminates
truncated = False
info = {"task": cur_task, "query_id": query_id}
return (
None,
format_reward,
terminated,
truncated,
info,
)
class MathCodeAgent(Agent):
async def aact(self, inp: AgentInferInput) -> AgentInferOutput:
"""Async version of act. Given an observation, return an action."""
# Extract information from observation
obs: MathCodeObs = inp.obs
query_id = obs.query_id
prompt_ids = obs.prompt_ids
# Create LLM request
llm_req = LLMRequest(
rid=str(query_id) + "-" + str(uuid.uuid4()),
input_ids=prompt_ids,
gconfig=inp.gconfig,
)
# Generate response using async LLM client
llm_resp = await inp.llm_client.agenerate(llm_req)
# Extract answers from completion
answer = llm_resp.completion
return AgentInferOutput(
action=MathCodeAction(query_id=query_id, answer=answer),
llm_req=llm_req,
llm_resp=llm_resp,
)
def reset(self):
"""Resets the agent's memory."""
pass # Stateless agent, no memory to reset
async def areset(self):
"""Async version of reset. Resets the agent's memory."""
pass # Stateless agent, no memory to reset
class MathCodeSingleStepCollector(RolloutCollector):
async def arun_episode(
self,
llm_client: LLMClient,
gconfig: GenerationHyperparameters,
env_option: Optional[Any] = None,
seed: Optional[int] = None,
) -> Trajectory:
"""Async version of run_episode. Run a single episode and return the trajectory."""
# Reset the environment and the agent's memory.
obs, _ = self.env.reset(options=env_option, seed=seed)
await self.agent.areset()
data = []
rewards = []
tik = datetime.now().timestamp()
ret = 0.0
ep_len = 0
done = False
# Episode loop.
while not done:
# Take an action by sending a request to generation server.
agent_infer_in = AgentInferInput(
obs=obs, gconfig=gconfig, llm_client=llm_client
)
agent_infer_out = await self.agent.aact(agent_infer_in)
action = agent_infer_out.action
# Advance one step in the environment.
nex_obs, reward, terminated, truncated, _ = self.env.step(action)
# Collect the step data.
resp = agent_infer_out.llm_resp
input_len = len(resp.input_tokens)
output_len = len(resp.output_tokens)
input_ids = resp.input_tokens + resp.output_tokens
prompt_mask = [1] * input_len + [0] * output_len
logprobs = [0.0] * input_len + resp.output_logprobs
versions = [-1] * input_len + resp.output_versions
d = dict(
input_ids=torch.tensor(input_ids, dtype=torch.long),
prompt_mask=torch.tensor(prompt_mask, dtype=torch.bool),
logprobs=torch.tensor(logprobs, dtype=torch.float32),
versions=torch.tensor(versions, dtype=torch.long),
)
data.append(d)
rewards.append(reward)
ret += float(reward)
ep_len += 1
# Prepare information for the next step.
done = terminated or truncated
obs = nex_obs
return Trajectory(
prompt=env_option,
data=dict(rewards=torch.tensor(rewards), **pad_sequences_to_tensors(data)),
stats=TrajStats(
start_time=tik,
total_reward=ret,
episode_length=ep_len,
info={},
),
)

View File

@@ -0,0 +1,7 @@
from datasets import Dataset
def process_areal_dataset(dataset: Dataset, tokenizer):
return dataset.map(
lambda x: tokenizer(x["prompt"], return_attention_mask=False), batched=True
)
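A brief usage sketch, assuming a dataset with a "prompt" column and a locally available tokenizer (the checkpoint id below is only an example):

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # example checkpoint
ds = Dataset.from_dict({"prompt": ["1 + 1 =", "The capital of France is"]})
ds = process_areal_dataset(ds, tokenizer)
print(ds[0]["input_ids"])  # token ids added by the map above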

View File

@@ -0,0 +1,46 @@
from datasets import Dataset
def process_gsm8k_rl_dataset(dataset: Dataset, tokenizer, reward_mode):
def process_example(example, idx):
# Add query_id column
example["query_id"] = str(idx)
example["prompt"] = example["question"]
# used by the reward function
example["method"] = reward_mode
return example
dataset = dataset.map(
lambda example, idx: process_example(example, idx),
with_indices=True,
)
return dataset.map(
lambda x: tokenizer(x["question"], return_attention_mask=False), batched=True
)
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
def process_example(example, idx):
# Add query_id column
example["query_id"] = str(idx)
example["prompt"] = example["question"]
example["seq"] = example["prompt"] + example["answer"] + tokenizer.eos_token
return example
dataset = dataset.map(
lambda example, idx: process_example(example, idx),
with_indices=True,
)
def _tokenize(example):
example["prompt"] = tokenizer(example["prompt"], return_attention_mask=False)[
"input_ids"
]
example["seq"] = tokenizer(example["seq"], return_attention_mask=False)[
"input_ids"
]
return example
dataset = dataset.map(lambda x: _tokenize(x), batched=True)
return dataset
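A usage sketch for the RL variant, assuming the GSM8K dataset is fetched from the Hugging Face hub (the "openai/gsm8k" dataset id and the tokenizer checkpoint are illustrative):

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
train = load_dataset("openai/gsm8k", "main", split="train")
rl_ds = process_gsm8k_rl_dataset(train, tokenizer, reward_mode="strict")
# Each example now carries query_id, prompt, method, and tokenized input_ids.
print(rl_ds[0]["query_id"], rl_ds[0]["method"], len(rl_ds[0]["input_ids"]))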

View File

@@ -0,0 +1,553 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import asyncio
import functools
import math
import os
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PreTrainedModel,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import EngineConfig, FSDPConfig, MicroBatchSpec, TrainingArgs
from arealite.api.engine_api import SPMDWrapper
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient
from arealite.utils import (
get_state_dict_from_repo_id_or_path,
recorder_list,
split_dict_tensor_with_cu_seqlens,
unpack_sequence,
)
from realhf.api.cli_args import ParallelismConfig
from realhf.base import constants
from realhf.base.pkg_version import is_version_greater_or_equal
if is_version_greater_or_equal("torch", "2.6.0"):
from torch.distributed.fsdp import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
)
elif is_version_greater_or_equal("torch", "2.4.0"):
from torch.distributed._composable.fsdp import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
)
else:
fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = (
None,
None,
None,
None,
)
from torch.distributed.device_mesh import init_device_mesh
def fsdp2_clip_grad_norm_(
parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None
):
"""torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor"""
from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
else:
# prevent generators from being exhausted
parameters = list(parameters)
grads = [p.grad for p in parameters if p.grad is not None]
total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)
_clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
return total_norm
def create_fsdp_device_mesh(shard_size, world_size):
if shard_size < 0 or shard_size >= world_size:
device_mesh = init_device_mesh(
"cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)
)
else:
device_mesh = init_device_mesh(
"cuda",
mesh_shape=(world_size // shard_size, shard_size),
mesh_dim_names=("ddp", "fsdp"),
)
return device_mesh
def apply_fsdp2(model, fsdp_kwargs, wrap_policy):
"""model: AutoModelForCausalLM"""
assert (
CPUOffloadPolicy is not None
), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", list())
fsdp_transformer_layer_cls_to_wrap = (
wrap_policy.transformer_layer_cls_to_wrap if wrap_policy is not None else list()
)
if not fsdp_transformer_layer_cls_to_wrap:
fsdp_transformer_layer_cls_to_wrap = default_transformer_cls_names_to_wrap
if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]
assert (
len(fsdp_transformer_layer_cls_to_wrap) > 0
and fsdp_transformer_layer_cls_to_wrap[0] is not None
)
modules = []
for name, module in model.named_modules():
if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (
isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings
):
modules.append(module)
for idx, module in enumerate(modules):
fully_shard(module, **fsdp_kwargs)
fully_shard(
model, **fsdp_kwargs
) # fsdp2 will not reshard_after_forward for root module
def fsdp2_load_full_state_dict(
model: PreTrainedModel,
full_state: dict,
cpu_offload=None,
tie_word_embeddings=False,
):
"""
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place.
Args:
model (`torch.nn.Module`): The model to load the state dict into
full_state (`dict`): The full state dict to load, can only be on rank 0
"""
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
set_model_state_dict,
)
device = torch.cuda.current_device()
model = model.to(device=device, non_blocking=True)
cpu_offload = cpu_offload is not None
options = StateDictOptions(
full_state_dict=True,
cpu_offload=cpu_offload,
broadcast_from_rank0=True,
strict=not tie_word_embeddings,
)
set_model_state_dict(model, full_state, options=options)
if tie_word_embeddings:
model.tie_weights()
# rotary_emb is not in state_dict, so we need to broadcast it manually
for name, buf in model.named_buffers():
dist.broadcast(buf, src=0)
if cpu_offload:
model.to("cpu", non_blocking=True)
for buf in model.buffers():
buf.data = buf.data.to(device)
def get_cosine_schedule_with_warmup(
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
The minimum lr ratio w.r.t the maximum.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
coef = (1 - min_lr_ratio) * 0.5
intercept = (1 + min_lr_ratio) * 0.5
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return min_lr_ratio + (1.0 - min_lr_ratio) * (
float(current_step) / float(max(1, num_warmup_steps))
)
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
return max(min_lr_ratio, x * coef + intercept)
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
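A quick sanity check of the schedule above (a standalone sketch with illustrative hyperparameters; the tiny Linear module exists only to give the optimizer parameters):

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
sched = get_cosine_schedule_with_warmup(
    opt, num_warmup_steps=10, num_training_steps=100, min_lr_ratio=0.1
)
for step in range(100):
    opt.step()
    sched.step()
# The learning rate ramps linearly to 1e-3 over the first 10 steps, then decays
# along a half-cosine toward min_lr_ratio * 1e-3 = 1e-4.
print(sched.get_last_lr()[0])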
class FSDPEngine(SPMDWrapper):
"""Simplified FSDP engine for transformer models."""
def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
super().__init__(args, engine_config)
assert is_version_greater_or_equal(
"torch", "2.4.0"
), f"arealite only supports FSDP2, which requires torch>=2.4.0"
self.fsdp_config = engine_config.backend.fsdp
if self.fsdp_config is None:
self.fsdp_config = FSDPConfig()
self.optimizer_config = engine_config.optimizer
self.model = None
self.optimizer = None
self.model_config = None
self.device_mesh = None
self.cpu_offload = None
self.world_size = int(os.environ["WORLD_SIZE"])
def train(self, mode: bool = True):
"""Set the module in training mode."""
assert self.model is not None
self.model.train(mode=mode)
return self
def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
"""Initialize distributed communication and model."""
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.engine_config.path,
trust_remote_code=True,
)
with torch.device("cuda"):
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation="flash_attention_2",
)
# Simple auto wrap policy
# TODO: fix wrap policy
mixed_precision_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
cast_forward_inputs=True,
)
device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
self.device_mesh = device_mesh
# sharding_strategy = ShardingStrategy.FULL_SHARD
self.cpu_offload = (
CPUOffloadPolicy() if self.fsdp_config.offload_params else None
)
fsdp_kwargs = {
"mesh": device_mesh,
"mp_policy": mixed_precision_policy,
"offload_policy": self.cpu_offload,
"reshard_after_forward": True,
}
# Wrap with FSDP2
apply_fsdp2(model, fsdp_kwargs, self.fsdp_config.wrap_policy)
self.model = model
# Set up optimizer
if self.optimizer_config is not None:
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
def train_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict:
"""Train on a batch using gradient accumulation."""
# self._initialize_fsdp_train()
assert self.optimizer is not None
assert self.optimizer_config is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_inputs = split_dict_tensor_with_cu_seqlens(input_, mb_spec).mbs
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)
assert total_loss_weight != 0
dist.all_reduce(total_loss_weight)
# Process microbatches with gradient accumulation
for i, mb_input in enumerate(mb_inputs):
outputs = self.model(**mb_input)
loss = loss_fn(outputs.logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
# Scale loss for accumulation
# Revert gradient averaging across dp ranks
loss_scale *= self.world_size
loss *= loss_scale
loss.backward()
grad_norm = fsdp2_clip_grad_norm_(
self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
)
if not torch.isfinite(grad_norm):
self.optimizer.zero_grad()
update_successful = False
else:
self.optimizer.step()
update_successful = True
current_lr = self.lr_scheduler.get_last_lr()[0]
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
lr=current_lr,
)
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
@torch.no_grad()
def eval_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
loss = loss_fn(outputs.logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
) -> Any | None:
"""Forward pass with optional post-processing."""
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
if output_seqlens is None:
cu_seqlens = input_["cu_seqlens"]
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
if post_hook:
result = post_hook(outputs.logits, mb_input)
results.append(result)
else:
results.append(outputs.logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))
def get_hf_model_state_dict(self) -> Dict[str, torch.Tensor]:
"""Get model state dict for saving."""
if self.model is None:
raise RuntimeError("Model not initialized")
with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
return self.model.state_dict()
def save_model_to_hf(
self,
path: str,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
base_model_path: Optional[str] = None,
):
"""Save model in HuggingFace format."""
if self.model is None:
raise RuntimeError("Model not initialized")
os.makedirs(path, exist_ok=True)
# FSDP2 checkpoint saving
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
# Get full state dict with FSDP2
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(self.model, options=options)
# save huggingface model
if dist.get_rank() == 0:
os.makedirs(path, exist_ok=True)
self.model.save_pretrained(path, state_dict=state_dict)
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
dist.barrier()
def load_model_from_hf(self, path: str):
"""Load model from HuggingFace format."""
if dist.get_rank() == 0:
full_state = get_state_dict_from_repo_id_or_path(path)
else:
full_state = {}
fsdp2_load_full_state_dict(
self.model,
full_state,
self.cpu_offload,
tie_word_embeddings=self.model_config.tie_word_embeddings,
)
def save_optimizer_state(self, path: str):
"""Save optimizer state."""
if self.optimizer is None:
raise RuntimeError("Optimizer not initialized")
os.makedirs(path, exist_ok=True)
torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))
def load_optimizer_state(self, path: str):
"""Load optimizer state."""
if self.optimizer is None:
raise RuntimeError("Optimizer not initialized")
optimizer_path = os.path.join(path, "optimizer.pt")
if os.path.exists(optimizer_path):
self.optimizer.load_state_dict(
torch.load(optimizer_path, map_location="cpu")
)
else:
raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")
async def aupdate_weights_to(self, llm_client: LLMClient):
"""Async method to update weights to all healthy servers."""
path = constants.get_param_realloc_path(self.args)
self.save_model_to_hf(path)
tasks = [
llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
for server_info in llm_client.get_healthy_servers()
]
await asyncio.gather(*tasks)
def update_weights_to(self, llm_client: LLMClient):
"""Update the weights to the server by sending requests to the client."""
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(self.aupdate_weights_to(llm_client))
finally:
loop.close()
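A note on the loss weighting in train_batch above: each micro-batch loss is scaled by loss_weight_fn(mb) / total_loss_weight, where the denominator is all-reduced across data-parallel ranks, and the scale is further multiplied by the world size to cancel the backend's gradient averaging. A toy single-process illustration of the weighting (illustrative numbers only):

losses = [2.0, 1.0, 4.0]     # per-micro-batch token-mean losses
weights = [10, 30, 60]       # e.g. loss_mask.count_nonzero() per micro-batch
total_weight = sum(weights)  # all-reduced across ranks in train_batch
batch_loss = sum(l * w / total_weight for l, w in zip(losses, weights))
print(batch_loss)  # 2.9, the token-weighted mean over the whole batch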

View File

@@ -0,0 +1,315 @@
import asyncio
import functools
import math
import os
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from arealite.api.cli_args import (
EngineConfig,
MicroBatchSpec,
ParallelismConfig,
TrainingArgs,
)
from arealite.api.engine_api import SPMDWrapper
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient
from arealite.utils import (
get_state_dict_from_repo_id_or_path,
recorder_list,
split_dict_tensor_with_cu_seqlens,
unpack_sequence,
)
from realhf.base import constants
def get_cosine_schedule_with_warmup(
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
The minimum lr ratio w.r.t the maximum.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
coef = (1 - min_lr_ratio) * 0.5
intercept = (1 + min_lr_ratio) * 0.5
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
return max(0.0, x * coef + intercept)
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
class HFEngine(SPMDWrapper):
"""Simplified HF engine for transformer models."""
def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
super().__init__(args, engine_config)
self.model = None
self.optimizer = None
self.model_config = None
self.weight_update_group_initialized = False
def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
"""Initialize model in single node."""
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
if dist.get_world_size() > 1:
raise RuntimeError(
"Distributed training is not supported in this engine. "
"Please use FSDP for distributed training."
)
torch.cuda.set_device("cuda:0")
dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.engine_config.path,
trust_remote_code=True,
)
with torch.device("cuda"):
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation="flash_attention_2",
)
model = model.cuda()
self.model = model
# Set up optimizer
optimizer_config = self.engine_config.optimizer
if optimizer_config is not None:
assert (
optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = optimizer_config.lr
weight_decay = optimizer_config.weight_decay
beta1 = optimizer_config.beta1
beta2 = optimizer_config.beta2
eps = optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
optimizer_config.warmup_steps_proportion * total_train_steps
)
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=optimizer_config.min_lr_ratio,
)
def train(self, mode: bool = True):
"""Set the module in training mode."""
return self.model.train(mode)
def train_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict:
"""Train on a batch using gradient accumulation."""
assert self.optimizer is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
loss = loss_fn(outputs.logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
loss *= loss_scale
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.engine_config.optimizer.gradient_clipping,
norm_type=2.0,
error_if_nonfinite=False,
foreach=None,
)
current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return {
"grad_norm": grad_norm,
"lr": current_lr,
}
@torch.no_grad()
def eval_batch(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
loss = loss_fn(outputs.logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: Dict,
mb_spec: MicroBatchSpec,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
) -> Any | None:
"""Forward pass with optional post-processing."""
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
if output_seqlens is None:
cu_seqlens = input_["cu_seqlens"]
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
if post_hook:
result = post_hook(outputs.logits, mb_input)
results.append(result)
else:
results.append(outputs.logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))
def step_lr_scheduler(self):
"""Step the learning rate scheduler."""
return self.lr_scheduler.step()
def save_model_to_hf(
self,
path: str,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
base_model_path: Optional[str] = None,
):
"""Save model in HuggingFace format."""
if self.model is None:
raise RuntimeError("Model not initialized")
os.makedirs(path, exist_ok=True)
state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
self.model.save_pretrained(path, state_dict=state_dict)
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
def load_model_from_hf(self, path: str):
"""Load model from HuggingFace format."""
full_state = get_state_dict_from_repo_id_or_path(path)
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()
def save_optimizer_state(self, path: str):
"""Save optimizer state."""
if self.optimizer is None:
raise RuntimeError("Optimizer not initialized")
os.makedirs(path, exist_ok=True)
torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))
def load_optimizer_state(self, path: str):
"""Load optimizer state."""
if self.optimizer is None:
raise RuntimeError("Optimizer not initialized")
optimizer_path = os.path.join(path, "optimizer.pt")
if os.path.exists(optimizer_path):
self.optimizer.load_state_dict(
torch.load(optimizer_path, map_location="cpu")
)
else:
raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")
async def aupdate_weights_to(self, llm_client: LLMClient):
path = constants.get_param_realloc_path(self.args)
self.save_model_to_hf(path)
tasks = [
llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
for server_info in llm_client.get_healthy_servers()
]
await asyncio.gather(*tasks)
def update_weights_to(self, llm_client: LLMClient):
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(self.aupdate_weights_to(llm_client))
finally:
loop.close()
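Both engines consume packed (variable-length) batches delimited by cu_seqlens; forward() derives per-sequence output lengths directly from it. A small illustration with made-up lengths:

import torch

# Three packed sequences of lengths 3, 5, and 2 concatenated into one row.
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
print(output_seqlens)  # [3, 5, 2] -- used to unpack the logits per sequence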

View File

@@ -0,0 +1,47 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import re
from functools import lru_cache
from typing import List
from functioncall.code.local_verify import code_verify
from realhf.impl.dataset.math_code_dataset import load_metadata
@lru_cache(maxsize=1)
def _load_metadata(dataset_path: str):
"""Cached version of load_metadata to avoid reloading metadata each time."""
return load_metadata(dataset_path)
def extract_code(text, min_length=20):
"""Extract code blocks from text."""
code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
valid_blocks.append(clean_block)
if not valid_blocks:
return None
# return the last code block
return valid_blocks[-1]
def code_reward(
query_id: str,
prompt: str,
completion: str,
prompt_ids: List[int],
completion_ids: List[int],
dataset_path: str,
**kwargs,
) -> float:
id2info, _ = _load_metadata(dataset_path)
return code_verify(
id2info=id2info, generateds=[extract_code(completion)], query_ids=[query_id]
)[0]

View File

@@ -0,0 +1,27 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
from functools import lru_cache
from typing import List
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_line
@lru_cache(maxsize=1)
def _load_metadata(dataset_path: str):
"""Cached version of load_metadata to avoid reloading metadata each time."""
return load_metadata(dataset_path)
def math_reward(
query_id: str,
prompt: str,
completion: str,
prompt_ids: List[int],
completion_ids: List[int],
dataset_path: str,
**kwargs,
) -> float:
id2info, _ = _load_metadata(dataset_path)
return parse_line(id2info=id2info, generated=completion, query_id=query_id)

View File

@@ -0,0 +1,83 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
# Modified from verl.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List
def extract_solution(solution_str, method="strict"):
assert method in ["strict", "flexible"]
if method == "strict":
# this also tests the formatting of the model
solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
if len(solutions) == 0:
final_answer = None
else:
# take the last solution
final_answer = solutions[-1].replace(",", "").replace("$", "")
elif method == "flexible":
answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
final_answer = None
if len(answer) == 0:
            # no reward if there is no answer
pass
else:
invalid_str = ["", "."]
# find the last number that is not '.'
for final_answer in reversed(answer):
if final_answer not in invalid_str:
break
return final_answer
def compute_score(
solution_str, ground_truth, method="strict", format_score=0.0, score=1.0
):
"""The scoring function for GSM8k.
Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
Args:
solution_str: the solution text
ground_truth: the ground truth
method: the method to extract the solution, choices are 'strict' and 'flexible'
format_score: the score for the format
score: the score for the correct answer
"""
answer = extract_solution(solution_str=solution_str, method=method)
if answer is None:
return 0
else:
if answer == ground_truth:
return score
else:
return format_score
def gsm8k_reward_fn(
query_id: str,
prompt: str,
completion: str,
prompt_ids: List[int],
completion_ids: List[int],
answer: str,
method: str,
**kwargs,
) -> float:
return compute_score(completion, extract_solution(answer), method=method)
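A worked example of the scoring path: the strict extractor keys off the "#### " marker used by GSM8K answers, and gsm8k_reward_fn compares the number extracted from the completion against the one extracted from the ground-truth answer.

completion = "We add 19 and 23 to get 42.\n#### 42"
print(extract_solution("19 + 23 = 42\n#### 42"))  # "42"
print(compute_score(completion, "42"))            # 1.0 (correct answer)
print(compute_score("#### 41", "42"))             # 0.0 (parsable but wrong -> format_score)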

View File

@@ -0,0 +1,91 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
from datetime import datetime
from typing import Any, Callable, Dict, Optional
import torch
from arealite.api.cli_args import (
GenerationHyperparameters,
RolloutCollectorConfig,
TrainingArgs,
)
from arealite.api.io_struct import LLMRequest, Trajectory, TrajStats
from arealite.api.llm_client_api import LLMClient
from arealite.api.rollout_api import RolloutCollector
from realhf.base import logging
logger = logging.getLogger(__file__)
class RlvrCollector(RolloutCollector):
def __init__(
self,
args: TrainingArgs,
config: RolloutCollectorConfig,
reward_fn: Callable,
):
super().__init__(args, config, None, None)
self.reward_fn = reward_fn
async def arun_episode(
self,
llm_client: LLMClient,
gconfig: GenerationHyperparameters,
env_option: Optional[Dict[str, Any]] = None,
seed: Optional[int] = None,
) -> Trajectory:
"""Async version of run_episode. Run a single episode and return the trajectory."""
tik = datetime.now().timestamp()
prompt_ids = env_option["input_ids"]
query_id = env_option["query_id"]
req = LLMRequest(input_ids=prompt_ids, gconfig=gconfig)
# Use async LLM client
resp = await llm_client.agenerate(req)
# Run reward computation in executor to avoid blocking
reward_kwargs = env_option.copy()
reward_kwargs.pop("query_id")
reward_kwargs.pop("prompt")
reward = self.reward_fn(
query_id=query_id,
prompt=req.text,
completion=resp.completion,
prompt_ids=prompt_ids,
completion_ids=resp.output_tokens,
**reward_kwargs,
)
input_len = len(resp.input_tokens)
output_len = len(resp.output_tokens)
input_ids = resp.input_tokens + resp.output_tokens
prompt_mask = [1] * input_len + [0] * output_len
logprobs = [0.0] * input_len + resp.output_logprobs
versions = [-1] * input_len + resp.output_versions
# logger.info(
# f"Prompt: {req.text}, reward: {reward}\nCompletion: {resp.completion}"
# )
return Trajectory(
prompt=env_option,
data=dict(
# unsqueeze to add an additional batch dimension
input_ids=torch.tensor(input_ids).unsqueeze(0),
prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
logprobs=torch.tensor(logprobs).unsqueeze(0),
versions=torch.tensor(versions).unsqueeze(0),
# reward
rewards=torch.tensor([reward]),
),
stats=TrajStats(
start_time=tik,
total_reward=reward,
episode_length=1,
info={},
),
)
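The per-token layout built above packs prompt and completion into a single sequence: prompt positions get mask 1, zero logprobs, and version -1. A toy illustration with a 3-token prompt and a 2-token completion (hypothetical values):

input_tokens = [101, 102, 103]   # prompt ids
output_tokens = [7, 8]           # completion ids
output_logprobs = [-0.51, -0.23]
output_versions = [4, 4]         # weight version that produced each token

input_ids = input_tokens + output_tokens   # [101, 102, 103, 7, 8]
prompt_mask = [1] * 3 + [0] * 2            # [1, 1, 1, 0, 0]
logprobs = [0.0] * 3 + output_logprobs     # prompt positions carry no logprob
versions = [-1] * 3 + output_versions      # -1 marks tokens not generated by the server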

View File

@@ -0,0 +1,530 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import functools
import os
import time
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from datasets import Dataset
from arealite import ppo_functional
from arealite.api.cli_args import (
GRPOTrainerConfig,
MicroBatchSpec,
TrainerConfig,
TrainingArgs,
)
from arealite.api.engine_api import EngineFactory
from arealite.api.io_struct import FinetuneSpec, Trajectory
from arealite.api.llm_client_api import LLMClientFactory
from arealite.api.trainer_api import Trainer
from arealite.system.rollout_controller import RolloutController
from arealite.utils import (
calc_entropy,
close_wandb_tensorboard,
compute_varlen_position_indices,
concat_padded_tensors,
gather_logprobs,
init_stats_logging,
log_wandb_tensorboard,
masked_normalization,
record_timing,
split_dict_tensor_with_cu_seqlens,
to_device,
unpad_input,
)
from realhf.api.core.data_api import load_hf_tokenizer, tabulate_stats
from realhf.base import constants, logging, name_resolve, names, stats_tracker, timeutil
logger = logging.getLogger("GRPO Trainer", "system")
class SpmdGRPOTrainer(Trainer):
def __init__(
self,
args: TrainingArgs,
trainer_config: TrainerConfig,
train_dataset: Dataset,
valid_dataset: Optional[Dataset] = None,
rollout_controller: Optional[RolloutController] = None,
):
super().__init__(
args,
trainer_config,
train_dataset,
valid_dataset,
rollout_controller,
)
if self.rollout_controller is None:
raise ValueError("GRPO Trainer requires a rollout controller.")
assert trainer_config.grpo is not None
self.config: GRPOTrainerConfig = trainer_config.grpo
assert args.rollout is not None
assert self.config.actor is not None
# Create actor model
engine_factory = EngineFactory(args)
self.actor = engine_factory.make_engine(self.config.actor)
self.actor_tokenizer = load_hf_tokenizer(self.config.actor.path)
self.gconfig = args.rollout.gconfig
        # Create the reference model if specified
self.ref = None
if self.config.ref is not None:
self.ref = engine_factory.make_engine(self.config.ref)
# Create a client to generate responses and update weights
client_factory = LLMClientFactory(args)
self.llm_client = client_factory.make_client(args.rollout.llm_client)
# Algorithm related attributes
self.kl_ctl = self.config.kl_ctl
self.discount = self.config.discount
self.gae_lambda = self.config.gae_lambda
self.adv_norm = self.config.adv_norm
self.max_reward_clip = self.config.max_reward_clip
self.group_adv_norm = self.config.group_adv_norm
self.group_size = args.rollout.gconfig.n_samples
self.max_head_offpolicyness = args.rollout.max_head_offpolicyness
self.reward_bias = self.config.reward_bias
self.reward_scaling = self.config.reward_scaling
self.max_reward_clip = self.config.max_reward_clip
self.save_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=self.args.exp_ctrl.save_freq_epochs,
freq_step=self.args.exp_ctrl.save_freq_steps,
freq_sec=self.args.exp_ctrl.save_freq_secs,
)
self.eval_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=self.args.exp_ctrl.eval_freq_epochs,
freq_step=self.args.exp_ctrl.eval_freq_steps,
            freq_sec=self.args.exp_ctrl.eval_freq_secs,
)
self.summary_writer = init_stats_logging(args)
def train(self, resume_from_checkpoint=None):
# TODO: handle recover
self.create_train_dataloader()
assert self.rollout_controller is not None
assert self.train_dataloader is not None
total_epochs = self.args.exp_ctrl.total_train_epochs
steps_per_epoch = len(self.train_dataloader)
ft_spec = FinetuneSpec(
total_train_epochs=total_epochs,
dataset_size=len(self.train_dataset),
train_batch_size=self.args.train_dataset.batch_size,
)
# Setting up models.
self.actor.init_distributed(None, ft_spec)
self.actor.load_model_from_hf(self.config.actor.path)
self.actor.eval()
if self.ref is not None:
self.ref.init_distributed(None, ft_spec)
self.ref.load_model_from_hf(self.config.ref.path)
self.ref.eval()
self.llm_client.wait_until_servers_ready()
self.actor.update_weights_to(self.llm_client)
# Start rollout for asynchronous RL.
if self.config.async_training:
self.rollout_controller.start_generate_loop()
# Main RL training loop.
total_epochs = self.args.exp_ctrl.total_train_epochs
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
global_step = 0
warmup_steps = self.max_head_offpolicyness + 1
assert steps_per_epoch >= warmup_steps
start_time = time.monotonic()
for epoch in range(total_epochs):
for step, data in enumerate(self.train_dataloader):
timing_stats = {}
with record_timing("timeperf/rollout", timing_stats):
if self.config.async_training:
self.rollout_controller.submit(data)
# Submitted data will not actually be sent for rollout.
                        # The rollout controller over-subscribes the data to
                        # ensure that enough data is being generated.
if epoch == 0 and step < warmup_steps:
continue
                        # Wait until enough trajectories have been collected.
trajs = self.rollout_controller.prepare_batch(
batch_size=self.args.train_dataset.batch_size // world_size
)
else:
# Run batched rollout by submitting requests to LLM servers
trajs = self.rollout_controller.generate_batch(
batch_size=len(data),
env_options=data,
)
with record_timing("timeperf/train_step", timing_stats):
# Run RL training and update weights.
mb_stats = self._train_step(trajs)
self.actor.step_lr_scheduler()
with record_timing("timeperf/sync_weights", timing_stats):
# Synchronize weights to the client.
self.actor.update_weights_to(self.llm_client)
# Update model version
name = names.model_version(
self.args.experiment_name, self.args.trial_name, "actor"
)
name_resolve.add(name, str(global_step + 1), replace=True)
if self.save_ctl.check(
epochs=int(step == steps_per_epoch - 1), steps=1
):
if dist.get_rank() == 0:
logger.info("Saving model ...")
with record_timing("timeperf/save", timing_stats):
save_path = os.path.join(
constants.get_save_path(self.args), "actor"
)
self.actor.save_model_to_hf(
save_path,
tokenizer=self.actor_tokenizer,
base_model_path=self.config.actor.path,
)
assert len(mb_stats) == self.config.ppo_n_minibatches
log_step = self.config.ppo_n_minibatches * global_step
for i, stats in enumerate(mb_stats):
log_wandb_tensorboard(log_step + i, stats, self.summary_writer)
log_wandb_tensorboard(log_step, timing_stats, self.summary_writer)
if dist.get_rank() == 0:
logger.info(
f"Epoch {epoch+1}/{total_epochs} "
f"Step {step+1}/{steps_per_epoch} "
f"Train step {global_step + 1}/{total_epochs * steps_per_epoch - warmup_steps} done."
)
logger.info(
f"Detailed time stats: \n{tabulate_stats(timing_stats, floatfmt='.2f')}"
)
for i, stats in enumerate(mb_stats):
logger.info(
f"GRPO training stats ({i + 1}/{len(mb_stats)}):\n{tabulate_stats(stats)}"
)
global_step += 1
if dist.get_rank() == 0:
logger.info(
f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}."
)
if self.config.async_training:
self.rollout_controller.stop_generate_loop()
close_wandb_tensorboard(self.summary_writer)
def _train_step(self, trajs: List[Trajectory]):
rollout = concat_padded_tensors([traj.data for traj in trajs])
rollout = to_device(rollout, torch.cuda.current_device())
        # Mark which sequences do not have an EOS token, i.e., whose generation
        # was truncated at the configured maximum generation length
batch_tokens = rollout["input_ids"]
seq_no_eos_mask = (
batch_tokens[:, -1] != self.actor_tokenizer.eos_token_id
).logical_and(batch_tokens[:, -1] != self.actor_tokenizer.pad_token_id)
# Remove padding to use flash-attn
attn_mask = rollout["attention_mask"]
input_ids, _, cu_seqlens, max_seqlen = unpad_input(
rollout["input_ids"], attn_mask
)
position_ids = compute_varlen_position_indices(input_ids.shape[0], cu_seqlens)
# Transformer forward input data
model_inputs = dict(
input_ids=input_ids.unsqueeze(0),
attention_mask=None,
position_ids=position_ids.unsqueeze(0),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
use_cache=False,
)
old_logp, *_ = unpad_input(rollout["logprobs"], attn_mask)
prompt_mask, *_ = unpad_input(rollout["prompt_mask"], attn_mask)
# Shift logprobs and mask for computing loss.
loss_mask = prompt_mask.logical_not()
loss_mask = torch.roll(loss_mask, shifts=-1)
old_logp = torch.roll(old_logp, shifts=-1)
input_ids = model_inputs["input_ids"].squeeze(0)
n_seqs = seq_no_eos_mask.shape[0]
assert n_seqs == self.local_train_batch_size * self.group_size, (
n_seqs,
self.group_size,
self.local_train_batch_size,
)
# Run reference model forward
def calc_logprobs(logits, input_data):
logits = logits.squeeze(0).float()
labels = torch.roll(input_data["input_ids"].squeeze(0), shifts=-1)
logits /= self.gconfig.temperature
logprobs = gather_logprobs(logits, labels)
return logprobs.unsqueeze(0)
if self.ref is not None and self.config.kl_ctl != 0.0:
ref_logp = self.ref.forward(
model_inputs,
mb_spec=self.config.mb_spec,
post_hook=calc_logprobs,
).squeeze(0)
else:
ref_logp = torch.zeros_like(input_ids, dtype=torch.float32)
# Recompute logprobs using the current actor model.
prox_logp = None
if self.config.recompute_logprob:
_logp = self.actor.forward(
model_inputs,
mb_spec=self.config.mb_spec,
post_hook=calc_logprobs,
).squeeze(0)
if self.config.use_decoupled_loss:
prox_logp = _logp
else:
# Overwrite the logp returned by the inference engine
old_logp = _logp
# Compute rewards using the reward function in synchronous RLVR pipeline.
reward_score = rollout["rewards"]
reward_score = (reward_score + self.reward_bias) * self.reward_scaling
reward_score = torch.clip(reward_score, max=self.max_reward_clip)
if self.config.group_reward_norm:
for i in range(n_seqs // self.group_size):
s = slice(i * self.group_size, (i + 1) * self.group_size)
r = reward_score[s]
reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)
# Apply the mask to log probabilities.
ref_logp *= loss_mask
old_logp *= loss_mask
# Compute KL-regularized rewards and GAEs.
cu_seqlens = model_inputs["cu_seqlens"]
kl_rewards, rewards = ppo_functional.get_packed_rewards(
kl_ctl=self.kl_ctl,
clip_reward_value=self.max_reward_clip,
log_probs=old_logp,
ref_log_probs=ref_logp,
reward_score=reward_score,
cu_seqlens=cu_seqlens,
seq_no_eos_mask=seq_no_eos_mask,
mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
)
advantages, _ = ppo_functional.get_packed_advantages_and_returns(
gamma=self.discount,
lam=self.gae_lambda,
values=torch.zeros(
input_ids.shape[0] + n_seqs,
device=input_ids.device,
dtype=torch.float32,
),
rewards=rewards,
short1cu_seqlens=cu_seqlens,
seq_no_eos_mask=seq_no_eos_mask,
)
# Optionally perform advantage normalization.
if self.adv_norm:
if self.group_adv_norm:
n_samples = len(cu_seqlens) - 1
assert n_samples % self.group_size == 0
adv_list = []
for i in range(0, n_samples, self.group_size):
adv_list.append(
masked_normalization(
advantages[cu_seqlens[i] : cu_seqlens[i + self.group_size]],
loss_mask[cu_seqlens[i] : cu_seqlens[i + self.group_size]],
all_reduce=False,
)
)
advantages = torch.cat(adv_list, 0)
else:
advantages = masked_normalization(advantages, loss_mask)
        # Prepare data to be split into mini-batches.
global_batch = dict(
**model_inputs,
old_logp=old_logp,
advantages=advantages,
loss_mask=loss_mask,
prox_logp=prox_logp,
)
input_lens = model_inputs["cu_seqlens"][1:] - model_inputs["cu_seqlens"][:-1]
all_stats = []
with stats_tracker.scope("actor"):
########## Logging code starts ##########
result_denominators = {
"correct_n_seqs": (reward_score > 0).bool(),
"incorrect_n_seqs": (reward_score <= 0).bool(),
}
global_denominators = dict(
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
n_tokens=torch.ones_like(loss_mask, dtype=torch.bool),
n_valid_tokens=loss_mask.bool(),
**result_denominators,
)
stats_tracker.denominator(**global_denominators)
stats_tracker.stat(
correct_seq_len=input_lens.float(), denominator="correct_n_seqs"
)
stats_tracker.stat(
incorrect_seq_len=input_lens.float(), denominator="incorrect_n_seqs"
)
stats = dict(
advantages=advantages,
kl_rewards=kl_rewards,
final_reward=rewards,
)
stats_tracker.stat(**stats, denominator="n_valid_tokens")
prompt_lens = []
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
prompt_lens.append(prompt_mask[s:e].sum())
prompt_lens = torch.tensor(prompt_lens, device=reward_score.device)
seq_stats = dict(
no_eos_ratios=seq_no_eos_mask.float(),
task_reward=reward_score,
prompt_len=prompt_lens.float(),
seq_len=input_lens.float(),
)
stats_tracker.stat(**seq_stats, denominator="n_seqs")
scalars = dict(
mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
eps_clip=self.config.eps_clip,
use_prox_logp=prox_logp is not None,
)
if self.config.c_clip is not None:
scalars["c_clip"] = self.config.c_clip
scalars["use_dual_clip"] = 1
else:
scalars["use_dual_clip"] = 0
if self.config.behav_imp_weight_cap is not None:
scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap
stats_tracker.scalar(**scalars)
global_stats = stats_tracker.export()
for k in global_denominators:
global_stats.pop(f"actor/{k}")
########## Logging code ends ##########
mb_inputs = split_dict_tensor_with_cu_seqlens(
global_batch,
mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
)
for mb in mb_inputs.mbs:
model_inputs = {k: mb[k] for k in model_inputs}
train_stat = self.actor.train_batch(
mb,
loss_fn=functools.partial(
grpo_loss_fn,
temperature=self.gconfig.temperature,
eps_clip=self.config.eps_clip,
c_clip=self.config.c_clip,
behav_imp_weight_cap=self.config.behav_imp_weight_cap,
),
mb_spec=self.config.mb_spec,
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
)
stats_tracker.scalar(**train_stat)
all_stats.append(stats_tracker.export())
all_stats[0].update(global_stats)
return all_stats
def grpo_loss_fn(
logits: torch.Tensor,
input_data: Dict,
temperature: float,
eps_clip: float,
c_clip: float | None,
behav_imp_weight_cap: float | None,
):
"""Loss function for actor step, all inputs should be splitted into
pipeline micro batches, returns loss and logging stats."""
input_ids = input_data["input_ids"].squeeze(0)
cu_seqlens = input_data["cu_seqlens"]
old_logp = input_data["old_logp"]
advantages = input_data["advantages"]
loss_mask = input_data["loss_mask"]
prox_logp = input_data["prox_logp"]
logits = logits.squeeze(0).float()
logits /= temperature
logprobs = gather_logprobs(logits, torch.roll(input_ids, shifts=-1))
loss, stat = ppo_functional.actor_loss_fn(
logprobs=logprobs,
old_logprobs=old_logp,
advantages=advantages,
eps_clip=eps_clip,
loss_mask=loss_mask,
c_clip=c_clip,
proximal_logprobs=prox_logp,
behav_imp_weight_cap=behav_imp_weight_cap,
)
entropy = calc_entropy(logits=logits, cu_seqlens=cu_seqlens)
# Log training statistics
stats_tracker.denominator(
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=loss_mask.bool(),
clipped_tokens=stat["clip_mask"],
dual_clipped_tokens=stat["dual_clip_mask"],
)
stats_tracker.stat(
importance_weight=stat["importance_weight"],
approx_kl=stat["approx_kl"],
new_logp=logprobs.detach(),
old_logp=old_logp,
entropy=entropy.float(),
actor_loss=stat["loss"],
clip_ratio=stat["clip_mask"].float(),
dual_clip_ratio=stat["dual_clip_mask"].float(),
denominator="n_valid_tokens",
)
if "behave_imp_weight" in stat:
stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
stats_tracker.stat(
behave_imp_weight=stat["behave_imp_weight"],
behave_approx_kl=stat["behave_approx_kl"],
denominator="unclipped_behave_tokens",
)
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
clip_mask = stat["clip_mask"]
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
stats_tracker.stat(
clipped_new_logp=clipped_new_logp,
clipped_old_logp=clipped_old_logp,
denominator="clipped_tokens",
)
return loss
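For orientation, a simplified, self-contained sketch of the clipped policy-gradient objective that ppo_functional.actor_loss_fn computes in packed form; this omits the dual-clip and decoupled (proximal) variants handled above and is not the library implementation:

import torch

def clipped_pg_loss(logp, old_logp, adv, loss_mask, eps_clip=0.2):
    ratio = torch.exp(logp - old_logp)              # importance weight
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * adv
    per_token = -torch.min(unclipped, clipped)      # pessimistic PPO-clip objective
    return (per_token * loss_mask).sum() / loss_mask.sum()

logp = torch.tensor([-1.0, -0.5, -2.0])
old_logp = torch.tensor([-1.2, -0.7, -1.0])
adv = torch.tensor([0.5, -0.3, 1.0])
loss_mask = torch.tensor([1.0, 1.0, 0.0])           # last position is masked out
print(clipped_pg_loss(logp, old_logp, adv, loss_mask))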

View File

@@ -0,0 +1,269 @@
import time
from typing import Any, Dict, List, Optional
import torch
import torch.distributed as dist
import torch.utils.data
from datasets import Dataset
from arealite.api.cli_args import TrainerConfig, TrainingArgs
from arealite.api.engine_api import EngineFactory
from arealite.api.trainer_api import Trainer
from arealite.system.rollout_controller import RolloutController
from arealite.utils import (
close_wandb_tensorboard,
compute_varlen_position_indices,
gather_logprobs,
init_stats_logging,
list_of_dict2dict_of_list,
log_wandb_tensorboard,
record_timing,
)
from realhf.api.core.data_api import load_hf_tokenizer, tabulate_stats
from realhf.api.core.model_api import FinetuneSpec
from realhf.base import logging, stats_tracker, timeutil
logger = logging.getLogger("SFT Trainer")
def compute_packed_sft_loss(
logits: torch.Tensor,
input_: Dict[str, torch.Tensor],
) -> torch.Tensor:
packed_input_ids: torch.Tensor = input_["input_ids"].squeeze(dim=0)
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
input_lens: torch.Tensor = cu_seqlens[1:] - cu_seqlens[:-1]
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
prompt_mask = input_["prompt_mask"].squeeze(dim=0)
logits = logits.squeeze(dim=0).float()
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1))
logprobs = torch.where(prompt_mask, 0, logprobs)
loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()
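# Token-mean negative log-likelihood over response tokens: prompt positions are zeroed above
# and the sum is divided by the number of non-prompt tokens.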
with torch.no_grad():
seqlogp = torch.zeros(
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
)
for i in range(cu_seqlens.shape[0] - 1):
m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
cu_seqlens,
logprobs.shape,
)
seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
m.numel() - m.count_nonzero()
)
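# seqlogp[i] approximates the mean response log-probability of sequence i; exp(-seqlogp)
# is logged below as a per-sequence perplexity.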
# Logging stats
stats_tracker.denominator(
n_seqs=torch.ones(
cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
),
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
n_valid_tokens=prompt_mask.logical_not(),
prompt_tokens=prompt_mask,
)
stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
vocab_max_logits=vocab_max_logits,
denominator="n_tokens",
)
return loss
class SFTTrainer(Trainer):
def __init__(
self,
args: TrainingArgs,
trainer_config: TrainerConfig,
train_dataset: Dataset,
valid_dataset: Optional[Dataset] = None,
rollout_controller: Optional[RolloutController] = None,
):
super().__init__(
args, trainer_config, train_dataset, valid_dataset, rollout_controller
)
self.config = config = trainer_config.sft
assert config is not None
engine_factory = EngineFactory(args)
self.model = engine_factory.make_engine(config.model)
self.tokenizer = load_hf_tokenizer(config.model.path)
self.mb_spec = config.mb_spec
self.save_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=self.args.exp_ctrl.save_freq_epochs,
freq_step=self.args.exp_ctrl.save_freq_steps,
freq_sec=self.args.exp_ctrl.save_freq_secs,
)
self.eval_ctl = timeutil.EpochStepTimeFreqCtl(
freq_epoch=self.args.exp_ctrl.eval_freq_epochs,
freq_step=self.args.exp_ctrl.eval_freq_steps,
freq_sec=self.args.exp_ctrl.eval_freq_secs,
)
self.summary_writer = init_stats_logging(args)
def _tokenize(self, strs: List[str]):
# tokenize strings into unpadded tokens with lengths.
return self.tokenizer(
strs,
padding=False,
truncation=True,
return_length=True,
max_length=self.mb_spec.max_tokens_per_mb,
return_attention_mask=False,
)
def _get_packed_input(self, data: List[Dict[str, Any]]):
data: Dict[str, List[Any]] = list_of_dict2dict_of_list(data)
tokenized_seqs = data["seq"]
tokenized_prompts = data["prompt"]
prompt_lens = [len(prompt) for prompt in tokenized_prompts]
input_lens = [len(seq) for seq in tokenized_seqs]
input_lens = torch.tensor(input_lens, dtype=torch.int)
input_ids = [torch.tensor(seq, dtype=torch.long) for seq in tokenized_seqs]
prompt_mask = []
for input_len, prompt_len in zip(input_lens, prompt_lens):
assert input_len >= prompt_len, (input_len, prompt_len)
pm = [1] * prompt_len + [0] * (input_len - prompt_len)
prompt_mask.append(torch.tensor(pm, dtype=torch.bool))
cu_seqlens = torch.nn.functional.pad(
input_lens.cumsum(0, dtype=torch.int), (1, 0)
)
max_seqlen = int(torch.max(input_lens).item())
packed_input_ids = torch.cat(input_ids, dim=0)
prompt_mask = torch.cat(prompt_mask, dim=0)
total_seqlen = int(cu_seqlens[-1].item())
position_ids = compute_varlen_position_indices(total_seqlen, cu_seqlens)
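# Packed (varlen) layout: sequences are concatenated along one dimension and cu_seqlens marks
# the boundaries; compute_varlen_position_indices is expected to restart positions at each
# boundary, e.g. lens [3, 2] -> cu_seqlens [0, 3, 5], position_ids [0, 1, 2, 0, 1].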
return dict(
input_ids=packed_input_ids.unsqueeze(0).cuda(),
attention_mask=None,
position_ids=position_ids.unsqueeze(0).cuda(),
prompt_mask=prompt_mask.unsqueeze(0).cuda(),
cu_seqlens=cu_seqlens.cuda(),
max_seqlen=max_seqlen,
use_cache=False,
)
def train(self, resume_from_checkpoint=None):
self.create_train_dataloader()
total_epochs = self.args.exp_ctrl.total_train_epochs
steps_per_epoch = len(self.train_dataloader)
ft_spec = FinetuneSpec(
total_train_epochs=total_epochs,
dataset_size=len(self.train_dataset),
train_batch_size=self.args.train_dataset.batch_size,
)
self.model.init_distributed(None, ft_spec)
self.model.load_model_from_hf(self.config.model.path)
self.model.train()
if dist.get_rank() == 0:
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
global_step = 0
start_time = time.monotonic()
for epoch in range(total_epochs):
for step, data in enumerate(self.train_dataloader):
timing_stats = {}
with record_timing("timeperf/data_processing", timing_stats):
packed_input_data = self._get_packed_input(data)
with record_timing("timeperf/train_step", timing_stats):
with stats_tracker.scope("sft"):
stats = self.model.train_batch(
input_=packed_input_data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"]
.logical_not()
.count_nonzero(),
mb_spec=self.mb_spec,
)
self.model.step_lr_scheduler()
stats_tracker.scalar(**stats)
if self.save_ctl.check(
epochs=int(step == steps_per_epoch - 1), steps=1
):
if dist.get_rank() == 0:
logger.info("Saving model ...")
with record_timing("timeperf/save", timing_stats):
save_path = self.get_save_checkpoint_path(
epoch, step, global_step
)
self.model.save_model_to_hf(save_path, self.tokenizer)
if self.eval_ctl.check(
epochs=int(step == steps_per_epoch - 1), steps=1
):
if dist.get_rank() == 0:
logger.info("Running evaluation ...")
with record_timing("timeperf/eval", timing_stats):
self._eval(global_step)
training_stats = stats_tracker.export()
training_stats.update(timing_stats)
log_wandb_tensorboard(global_step, training_stats, self.summary_writer)
if dist.get_rank() == 0:
logger.info(
f"Epoch {epoch} Step {step} GlobalStep {global_step} done. Detailed time stats:"
f"\n{tabulate_stats(timing_stats, floatfmt='.2f')}"
)
global_step += 1
if dist.get_rank() == 0:
logger.info(
f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}."
)
close_wandb_tensorboard(self.summary_writer)
def _eval(self, global_step):
self.create_valid_dataloader()
if self.valid_dataloader is None:
return
self.eval_data_generator = iter(self.valid_dataloader)
n_steps = len(self.valid_dataloader)
losses = []
start_time = time.monotonic()
for step in range(n_steps):
data = next(self.eval_data_generator)
packed_input_data = self._get_packed_input(data)
with stats_tracker.scope("sft-eval"):
avg_loss = self.model.eval_batch(
input_=packed_input_data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x["prompt_mask"]
.logical_not()
.count_nonzero(),
mb_spec=self.mb_spec,
)
losses.append(avg_loss)
val_loss = torch.mean(torch.stack(losses))
logger.info(
f"Global step: {global_step} evaluation time cost {time.monotonic() - start_time:.2f} "
f"val_loss={val_loss:.4f}"
)
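# Minimal usage sketch (assumes the args/config objects are constructed elsewhere from
# arealite.api.cli_args):
#   trainer = SFTTrainer(args, trainer_config, train_dataset, valid_dataset)
#   trainer.train()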

View File

@@ -0,0 +1,228 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import asyncio
import threading
import time
import traceback
from queue import Empty as QueueEmpty
from typing import Any, List, Optional
import numpy as np
# NOTE: the multiprocessing start method should be "fork" rather than "spawn".
import torch.multiprocessing as mp
from arealite.api.cli_args import RolloutConfig, TrainingArgs
from arealite.api.io_struct import Trajectory
from arealite.api.llm_client_api import LLMClientFactory
from arealite.api.rollout_api import RolloutCollector
from arealite.system.rollout_worker import RolloutWorker
from realhf.base import datapack, logging, network
from realhf.system.push_pull_stream import ZMQJsonPuller, ZMQJsonPusher
logger = logging.getLogger("Rollout Controller")
class RolloutController:
def __init__(
self,
args: TrainingArgs,
config: RolloutConfig,
collector: RolloutCollector,
):
self.args = args
self.config = config
self.gconfig = config.gconfig
self.collector = collector
# Process-based execution
self._exiting = mp.Event()
self._lock = mp.Lock()
self._buffer: List[List[Trajectory]] = []
self._version = 0
# Worker processes for asynchronous rollout
self._worker_processes: List[mp.Process] = []
self.llm_client = LLMClientFactory(args).make_client(config.llm_client)
# PushPull communication for data to workers
self._data_pusher = None
self._data_pusher_port = None
self._puller = None
self._puller_port = None
self._collector_thread = None
################### User Interfaces Start #################
def generate_batch(
self,
batch_size: int,
env_options: Optional[List[Any]] = None,
seeds: Optional[List[int]] = None,
) -> List[Trajectory]:
"""Run episodes in batch using the collector directly (for compatibility)."""
if env_options is None:
env_options = [None] * batch_size
else:
assert len(env_options) == batch_size
if seeds is None:
seeds = [None] * batch_size
else:
assert len(seeds) == batch_size
async def run_parallel_gen():
worker = RolloutWorker(
worker_id=0,
args=self.args,
config=self.config,
llm_client=self.llm_client,
)
tasks = [
worker._run_grouped_episode_async(None, env_option, seed)
for env_option, seed in zip(env_options, seeds)
]
results = await asyncio.gather(*tasks)
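# Each result is (rid, trajectories); concatenating the per-episode lists flattens the
# groups into a single batch.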
return sum([r[1] for r in results], [])
return asyncio.run(run_parallel_gen())
def start_generate_loop(self):
"""Start worker processes that run generation loops."""
logger.info("Starting worker processes...")
# Start background thread to collect data from workers
self._puller_port = network.find_free_port(
experiment_name=self.args.experiment_name, trial_name=self.args.trial_name
)
self._collector_thread = threading.Thread(
target=self._collect_from_workers, daemon=True
)
self._collector_thread.start()
# Start worker processes
self._data_pusher_port = network.find_free_port(
experiment_name=self.args.experiment_name, trial_name=self.args.trial_name
)
self._data_pusher = ZMQJsonPusher(
host="localhost", port=self._data_pusher_port, bind=True
)
logger.info(f"RolloutController sending data on port {self._data_pusher_port}")
num_workers = self.config.num_workers
for worker_id in range(num_workers):
process = mp.Process(
target=_run_worker_process,
args=(
worker_id,
self.args,
self.config,
self._puller_port,
self._data_pusher_port,
),
)
process.start()
self._worker_processes.append(process)
logger.info(f"Started worker process {worker_id}")
def submit(self, data):
"""Submit data to worker processes for processing."""
if self._data_pusher is None:
raise RuntimeError(
"Data pusher not initialized. Call start_generate_loop() first."
)
# Convert data to JSON-compatible format
assert isinstance(data, list)
for d in data:
self._data_pusher.push(d)
logger.debug(f"Submitted {len(data)} data to workers")
def prepare_batch(self, batch_size: int) -> List[Trajectory]:
"""Prepare and wait for a batch of trajectories."""
buf_size = -1
while buf_size < batch_size:
with self._lock:
buf_size = len(self._buffer)
time.sleep(0.1)
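# Oldest-first ordering: trajectory groups are sorted by their mean episode start time so
# that batches drain the buffer roughly FIFO, bounding trajectory staleness.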
with self._lock:
self._buffer = sorted(
self._buffer, key=lambda x: np.mean([xx.stats.start_time for xx in x])
)
data, self._buffer = self._buffer[:batch_size], self._buffer[batch_size:]
return datapack.flat2d(data)
def stop_generate_loop(self):
"""Stop worker processes and cleanup."""
logger.info("Stopping worker processes...")
self._exiting.set()
# Stop worker processes gracefully first, then forcefully if needed
for i, process in enumerate(self._worker_processes):
if process.is_alive():
logger.info(f"Terminating worker process {i}...")
try:
process.terminate()
process.join(timeout=1.0)
except Exception:
process.kill()
self._worker_processes.clear()
if self._collector_thread is not None:
# Wait for the thread to finish (with optional timeout)
self._collector_thread.join(timeout=1.0)
# Close communication channels
if self._puller:
self._puller.close()
if self._data_pusher:
self._data_pusher.close()
logger.info("Cleanup completed")
################## User Interfaces End ##################
def _collect_from_workers(self):
"""Background thread to collect trajectories from workers."""
# Bind the puller on the port allocated in start_generate_loop()
self._puller = ZMQJsonPuller(host="localhost", port=self._puller_port)
logger.info(f"RolloutController listening on port {self._puller_port}")
while not self._exiting.is_set():
try:
# Pull data from workers
data = self._puller.pull(timeout_ms=100)
# Convert back to Trajectory objects
trajs = [
Trajectory.from_json_compatible(traj_data)
for traj_data in data["trajs"]
]
# Add to buffer
with self._lock:
self._buffer.append(trajs)
logger.debug(
f"Received {len(trajs)} trajectories from worker {data['worker_id']}"
)
except QueueEmpty:
# No data available, continue
time.sleep(0.1)
continue
except Exception as e:
if not self._exiting.is_set():
logger.error(f"Error in collector thread: {e}")
logger.error(traceback.format_exc())
break
def _run_worker_process(worker_id: int, args, config, puller_port, data_pusher_port):
worker = RolloutWorker(
worker_id=worker_id,
args=args,
config=config,
pusher_host="localhost",
pusher_port=puller_port,
data_puller_host="localhost",
data_puller_port=data_pusher_port,
)
logger.info(f"Worker {worker_id} starting generation loop...")
worker.run_generation_loop()
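# Typical driver-side flow (a sketch using only the interfaces defined above):
#   controller = RolloutController(args, config, collector)
#   controller.start_generate_loop()
#   controller.submit(list_of_env_options)  # JSON-compatible items consumed by the workers
#   trajs = controller.prepare_batch(batch_size)
#   controller.stop_generate_loop()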

View File

@@ -0,0 +1,234 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import asyncio
import queue
from typing import Any, Dict, List, Optional
import numpy as np
import torch.distributed as dist
from arealite.api.cli_args import RolloutConfig, TrainingArgs
from arealite.api.io_struct import Trajectory
from arealite.api.llm_client_api import LLMClient, LLMClientFactory
from arealite.api.rollout_api import RolloutCollectorFactory
from realhf.base import logging, name_resolve, names
from realhf.base.monitor import RolloutStat
from realhf.system.push_pull_stream import ZMQJsonPuller, ZMQJsonPusher
logger = logging.getLogger("RolloutWorker")
ROLLOUT_POLL_WAIT_TIME = 0.4
class RolloutWorker:
"""Standalone rollout worker that runs continuous generation loop."""
def __init__(
self,
worker_id: int,
args: TrainingArgs,
config: RolloutConfig,
llm_client: LLMClient | None = None,
pusher_host: Optional[str] = "localhost",
pusher_port: Optional[int] = 5555,
data_puller_host: Optional[str] = "localhost",
data_puller_port: Optional[int] = 5556,
):
self.worker_id = worker_id
self.args = args
self.config = config
self.gconfig = config.gconfig
# For staleness control
self.train_batch_size = args.train_dataset.batch_size
self.max_concurrent_rollouts = (
config.max_concurrent_rollouts or self.train_batch_size
)
self.pusher_host = pusher_host
self.pusher_port = pusher_port
self.data_puller_host = data_puller_host
self.data_puller_port = data_puller_port
self._shutdown = False
self.pusher = None
self.data_puller = None
if llm_client is None:
llm_client = LLMClientFactory(args).make_client(config.llm_client)
self.llm_client = llm_client
def _cleanup(self):
"""Clean up resources."""
if self.pusher:
self.pusher.close()
if self.data_puller:
self.data_puller.close()
def run_generation_loop(self):
"""Run the continuous generation loop like the original _generate_loop."""
try:
asyncio.run(self._generate_loop())
finally:
self._cleanup()
async def _run_grouped_episode_async(
self, rid: int, data: Any, seed: Optional[int] = None
):
"""Run grouped episode asynchronously."""
tasks = []
for _ in range(self.gconfig.n_samples):
# Create collector
factory = RolloutCollectorFactory(self.args)
collector = factory.make_collector(self.config.collector)
tasks += [
collector.arun_episode(
llm_client=self.llm_client,
gconfig=self.gconfig.new(n_samples=1),
env_option=data,
seed=seed,
)
]
trajs = await asyncio.gather(*tasks)
return rid, trajs
def _get_model_version(self) -> int:
name = names.model_version(
self.args.experiment_name,
self.args.trial_name,
"actor",
)
try:
return int(name_resolve.get(name))
except name_resolve.NameEntryNotFoundError:
return 0
async def _generate_loop(self):
"""Main generation loop - similar to original RolloutController._generate_loop."""
data = None
# Communication with main process
self.pusher = ZMQJsonPusher(host=self.pusher_host, port=self.pusher_port)
self.data_puller = ZMQJsonPuller(
host=self.data_puller_host,
port=self.data_puller_port,
bind=False,
)
rollout_stat = RolloutStat()
rollout_tasks: Dict[int, asyncio.Task] = {}
rid = 0
try:
while not self._shutdown:
# Load next data from controller
if data is None:
try:
data = self.data_puller.pull(timeout_ms=50)
logger.debug(f"Get data from puller: {data}")
except queue.Empty:
logger.debug(f"No data from puller stream.")
# Check capacity
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
cannot_rollout_reason = []
capacity = max(1, self.max_concurrent_rollouts // world_size)
can_rollout = len(rollout_tasks) < capacity
if not can_rollout:
cannot_rollout_reason.append(
f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
)
# Staleness control
version = self._get_model_version()
ofp = self.config.max_head_offpolicyness
sample_cnt = rollout_stat.accepted + rollout_stat.running
expected_version = sample_cnt // self.train_batch_size
not_staled = expected_version <= ofp + version
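# Example: with train batch size 256 and 512 accepted + running samples, expected_version is
# 2, so new rollouts are submitted only if 2 <= current_version + max_head_offpolicyness.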
can_rollout &= not_staled
if not not_staled:
cannot_rollout_reason.append(
f"Staled: expected version ({expected_version}) = "
f"global sample cnt ({sample_cnt}) // batch size ({self.train_batch_size}), "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
if not can_rollout:
logger.debug(
f"Worker {self.worker_id}: Cannot submit new rollouts. "
+ "\n".join(cannot_rollout_reason)
)
# Create new rollout task
if can_rollout and data is not None:
task = asyncio.create_task(
self._run_grouped_episode_async(rid, data)
)
rollout_tasks[rid] = task
rollout_stat.submitted += 1
rollout_stat.running += 1
logger.debug(
f"Worker {self.worker_id}: Submit rollout rid {rid}. "
f"Submit: {rollout_stat.submitted}, "
f"running: {rollout_stat.running}, "
f"accepted: {rollout_stat.accepted}."
)
rid += 1
data = None
# Wait for rollout completion
tasks = list(rollout_tasks.values())
done = []
if tasks:
done, _ = await asyncio.wait(
tasks,
timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED,
)
else:
await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
# Collect done results
for task in done:
task_rid, trajs = await task
trajs: List[Trajectory]
rollout_tasks.pop(task_rid)
rollout_stat.running -= 1
# Filter data according to episodic return
ret = np.mean([traj.stats.total_reward for traj in trajs])
accepted = ret >= self.config.filter_reward_lb
accepted &= ret <= self.config.filter_reward_ub
if accepted:
# Push trajectories to main process
trajectory_data = {
"worker_id": self.worker_id,
"trajs": [traj.to_json_compatible() for traj in trajs],
}
self.pusher.push(trajectory_data)
rollout_stat.accepted += 1
logger.debug(
f"Worker {self.worker_id}: Finish rollout {task_rid}. "
f"Submit: {rollout_stat.submitted}, "
f"running: {rollout_stat.running}, "
f"accepted: {rollout_stat.accepted}."
)
finally:
# Cancel remaining tasks
for task in rollout_tasks.values():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

View File

@@ -0,0 +1,165 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import time
from arealite.api.io_struct import LLMRequest, LLMResponse, LLMServerInfo
from arealite.api.llm_client_api import LLMClient
from realhf.base import logging, pkg_version
logger = logging.getLogger(__name__)
if pkg_version.is_available("sglang"):
if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
else:
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
class SGLangClient(LLMClient):
"""SGLang implementation of LLMClient."""
async def agenerate(self, req: LLMRequest) -> LLMResponse:
"""Async version of generate using aiohttp."""
# Ensure the request carries a text prompt; decode from input_ids if needed.
if not req.text:
assert req.input_ids is not None
req.text = self.tokenizer.decode(req.input_ids)
# Prepare request payload
gconfig = req.gconfig
stop_token_ids = gconfig.stop_token_ids
if self.tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(self.tokenizer.eos_token_id)
if self.tokenizer.pad_token_id not in stop_token_ids:
stop_token_ids.append(self.tokenizer.pad_token_id)
assert gconfig.n_samples == 1
sample_params = {
"top_p": gconfig.top_p,
"top_k": gconfig.top_k,
"max_new_tokens": gconfig.max_new_tokens,
"temperature": 0.0 if gconfig.greedy else gconfig.temperature,
"stop_token_ids": stop_token_ids,
}
payload = {
"rid": req.rid,
"text": req.text,
"sampling_params": sample_params,
"return_logprob": True,
"stream": False,
}
# Make request
start_time = time.perf_counter()
accumulated_output_tokens = []
accumulated_output_logprobs = []
accumulated_versions = []
# Deal with rollout interruption
completion = ""
stop_reason = "length"
while (
stop_reason != "stop"
and len(accumulated_output_tokens) < gconfig.max_new_tokens
):
# loop until the generation is complete
response, server_info = await self.arequest_with_retry(
endpoint="/generate",
payload=payload,
method="POST",
max_retries=3,
timeout=self.client_config.request_timeout,
)
result = await response.json()
# Parse response
completion += result["text"]
meta_info = result["meta_info"]
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
# Update accumulated outputs
accumulated_output_tokens.extend(output_tokens)
accumulated_output_logprobs.extend(output_logprobs)
accumulated_versions.extend([server_info.version] * len(output_tokens))
# Check if generation is complete
finish_reason = meta_info["finish_reason"]
stop_reason = finish_reason["type"]
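# If generation stopped early (e.g. interrupted by a weight update), resend the request with
# the prompt plus everything generated so far, accumulating tokens/logprobs/versions across
# segments.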
payload["text"] += completion
latency = time.perf_counter() - start_time
return LLMResponse(
completion=completion,
input_tokens=req.input_ids,
output_tokens=accumulated_output_tokens,
output_logprobs=accumulated_output_logprobs,
output_versions=accumulated_versions,
stop_reason=stop_reason,
latency=latency,
ttft=latency, # Simplified for non-streaming
)
async def aupdate_weights_from_disk(self, server_info: LLMServerInfo, path: str):
server_url = f"http://{server_info.host}:{server_info.port}"
response, _ = await self.arequest_with_retry(
endpoint="/update_weights_from_disk",
payload=dict(model_path=path, allow_interrupt=True),
method="POST",
max_retries=3,
timeout=self.client_config.request_timeout,
target_server=server_info,
)
res = await response.json()
assert res["success"]
if "num_paused_requests" in res:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updating weights for server {server_url}"
)
self.registry.update_heartbeat(
server_info.server_id, "healthy", version=server_info.version + 1
)
async def ainit_weight_update_group(self, server_info, group_meta):
payload = dict(
master_address=group_meta.master_address,
master_port=group_meta.master_port,
rank_offset=group_meta.rank_offset,
world_size=group_meta.world_size,
group_name=group_meta.group_name,
backend=group_meta.backend,
)
response, _ = await self.arequest_with_retry(
endpoint="/init_weights_update_group",
payload=payload,
method="POST",
max_retries=3,
timeout=self.client_config.request_timeout,
target_server=server_info,
)
res = await response.json()
assert res["success"], res["message"]
async def aupdate_weights_from_distributed(self, server_info, weight_meta):
payload = dict(
name=weight_meta.param_name,
dtype=weight_meta.dtype,
shape=weight_meta.shape,
)
response, _ = await self.arequest_with_retry(
endpoint="/update_weights_from_distributed",
payload=payload,
method="POST",
max_retries=3,
timeout=self.client_config.request_timeout,
target_server=server_info,
)
res = await response.json()
assert res["success"], res["message"]

View File

@@ -0,0 +1,166 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
import os
import subprocess
import sys
from pathlib import Path
import requests
from arealite.api.cli_args import LLMServiceConfig, SGLangConfig
from arealite.api.io_struct import AllocationMode, LLMServerInfo
from arealite.api.llm_server_api import LLMServer
from realhf.base import gpu_utils, logging, network, pkg_version
logger = logging.getLogger(__name__)
def apply_sglang_patch():
"""Apply SGLang patch if available."""
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = ""
try:
sglang_meta = subprocess.check_output(
"python3 -m pip show sglang", shell=True
).decode("ascii")
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]).parent)
if target_path and Path(patch_path).exists():
proc = subprocess.Popen(
["git", "apply", patch_path],
cwd=target_path,
stderr=sys.stdout,
stdout=sys.stdout,
)
proc.wait()
logger.info(f"Applied SGLang patch at {target_path}")
except (subprocess.CalledProcessError, FileNotFoundError):
pass
class SGLangServer(LLMServer):
"""SGLang implementation of LLMServer."""
def __init__(self, args, service_config: LLMServiceConfig):
super().__init__(args, service_config)
self.server_info: LLMServerInfo | None = None
self.base_gpu_id = 0
self.config = args.rollout.sglang
self.alloc_mode = AllocationMode.from_str(args.allocation_mode)
def _resolve_base_gpu_id(self):
# Determine GPU configuration
import ray
tp_size = self.alloc_mode.gen_tp_size
pp_size = self.alloc_mode.gen_pp_size
mp_size = tp_size * pp_size
if ray.is_initialized():
self.base_gpu_id = 0
elif "CUDA_VISIBLE_DEVICES" in os.environ:
if len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1:
self.base_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"])
elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == mp_size:
self.base_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0])
else:
logger.warning(
f"Unknown how to resolve cuda visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}, "
f"setting base_gpu_id to 0."
)
self.base_gpu_id = 0
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(gpu_utils.gpu_count()))
)
elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
# torchrun
self.base_gpu_id = int(os.environ["RANK"]) % gpu_utils.gpu_count()
elif gpu_utils.gpu_count() == mp_size:
self.base_gpu_id = 0
else:
logger.warning("Unknown GPU configuration, setting base_gpu_id to 0. ")
self.base_gpu_id = 0
def launch_server(self) -> LLMServerInfo | None:
# Apply SGLang patch
apply_sglang_patch()
self._resolve_base_gpu_id()
# Get host and ports
host_ip = network.gethostip()
host = "localhost" if not self.config.enable_metrics else host_ip
ports = network.find_multiple_free_ports(
2,
low=10000,
high=60000,
experiment_name=self.registry.expr_name,
trial_name=self.registry.trial_name,
)
server_port = ports[0]
nccl_port = ports[1]
# Build command
tp_size = self.alloc_mode.gen_tp_size
cmd = SGLangConfig.build_cmd(
sglang_config=self.config,
model_path=self.args.rollout.model_path,
tp_size=tp_size,
base_gpu_id=self.base_gpu_id,
dist_init_addr=f"{host}:{nccl_port}",
served_model_name=self.service_config.served_model_name,
skip_tokenizer_init=False,
)
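# `cmd` is assumed to be a complete `python -m sglang.launch_server ...` command line
# assembled from SGLangConfig; the serving port is appended below before spawning it.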
# Launch process
full_command = f"{cmd} --port {server_port}"
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
self.process = subprocess.Popen(
full_command.split(),
text=True,
stdout=sys.stdout,
stderr=sys.stdout,
)
# Create server info
self.server_info = LLMServerInfo(
server_id=self.server_id,
host=host,
port=server_port,
status="starting",
version=0,
)
return self.server_info
def check_health(self) -> bool:
"""Check if the SGLang server is healthy."""
if not self.server_info or not self.process:
return False
# Check if process is still running
if self.process.poll() is not None:
return False
try:
# Check server endpoint
base_url = f"http://{self.server_info.host}:{self.server_info.port}"
response = requests.get(
f"{base_url}/metrics",
timeout=30,
)
if response.status_code != 200:
return False
# Update server load
for line in response.text.split("\n"):
if line.startswith("sglang:num_running_reqs"):
self.load = float(line.split(" ")[1])
break
return True
except requests.exceptions.RequestException:
return False