mirror of https://github.com/inclusionAI/AReaL
checkout previous impl
parent 6710d5f275
commit 3a0f1e558c
@@ -0,0 +1,232 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

import os
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
from typing import Any, List, Optional, Tuple

import torch

from arealite.api.cli_args import GenerationHyperparameters, TrainingArgs
from arealite.api.io_struct import (
    AgentInferInput,
    AgentInferOutput,
    LLMRequest,
    Trajectory,
    TrajStats,
)
from arealite.api.llm_client_api import LLMClient
from arealite.api.rollout_api import Agent, Environment, RolloutCollector
from arealite.utils import pad_sequences_to_tensors
from functioncall.code.local_verify import code_verify as local_code_verify
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_lines_in_parallel

ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify


@lru_cache(maxsize=128)
def _load_metadata_cached(dataset_path: str):
    """Cached version of load_metadata to avoid reloading metadata each time."""
    return load_metadata(dataset_path)


def extract_code(text, min_length=20):
    """Extract code blocks from text."""
    code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
    code_blocks = re.findall(code_pattern, text, re.DOTALL)
    valid_blocks = []
    for block in code_blocks:
        clean_block = block.strip()
        if len(clean_block) < min_length:
            continue
        valid_blocks.append(clean_block)

    if not valid_blocks:
        return None
    # return the last code block
    return valid_blocks[-1]


@dataclass
class MathCodeAction:
    query_id: str
    answer: str


@dataclass
class MathCodeObs:
    query_id: str
    prompt_ids: List[int]


class MathCodeSingleStepEnv(Environment):
    """Math and Code single-step verification environment."""

    def __init__(self, args: TrainingArgs, solution_path: str):
        super().__init__(args)
        self.id2info, _ = _load_metadata_cached(solution_path)

    def reset(
        self, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[Any, dict]:
        """Reset the environment."""
        super().reset(seed=seed)
        try:
            prompt_ids = options["input_ids"]
            query_id = options["query_id"]
        except KeyError:
            raise RuntimeError("`input_ids` and `query_id` must be set in env options.")
        # Return dummy observation and info
        return MathCodeObs(query_id=query_id, prompt_ids=prompt_ids), {}

    def step(
        self, action: MathCodeAction
    ) -> Tuple[MathCodeObs, float, bool, bool, dict]:
        """Execute one step in the environment."""
        query_id = action.query_id
        answer = action.answer

        query_id = query_id.split("@")[0]
        cur_task = self.id2info[query_id]["task"]

        if cur_task == "math":
            # Run math verification
            format_reward = math_verify_call(self.id2info, [answer], [query_id])[0]
        elif cur_task == "code":
            # Extract code blocks and run code verification
            extracted_answer = extract_code(answer)
            format_reward = code_verify_call(
                self.id2info, [extracted_answer], [query_id]
            )[0]
        else:
            raise NotImplementedError(f"Task type '{cur_task}' not implemented")

        # Return: observation, reward, terminated, truncated, info
        terminated = True  # Single step environment always terminates
        truncated = False
        info = {"task": cur_task, "query_id": query_id}

        return (
            None,
            format_reward,
            terminated,
            truncated,
            info,
        )


class MathCodeAgent(Agent):

    async def aact(self, inp: AgentInferInput) -> AgentInferOutput:
        """Async version of act. Given an observation, return an action."""
        # Extract information from observation
        obs: MathCodeObs = inp.obs
        query_id = obs.query_id
        prompt_ids = obs.prompt_ids

        # Create LLM request
        llm_req = LLMRequest(
            rid=str(query_id) + "-" + str(uuid.uuid4()),
            input_ids=prompt_ids,
            gconfig=inp.gconfig,
        )

        # Generate response using async LLM client
        llm_resp = await inp.llm_client.agenerate(llm_req)

        # Extract answers from completion
        answer = llm_resp.completion

        return AgentInferOutput(
            action=MathCodeAction(query_id=query_id, answer=answer),
            llm_req=llm_req,
            llm_resp=llm_resp,
        )

    def reset(self):
        """Resets the agent's memory."""
        pass  # Stateless agent, no memory to reset

    async def areset(self):
        """Async version of reset. Resets the agent's memory."""
        pass  # Stateless agent, no memory to reset


class MathCodeSingleStepCollector(RolloutCollector):

    async def arun_episode(
        self,
        llm_client: LLMClient,
        gconfig: GenerationHyperparameters,
        env_option: Optional[Any] = None,
        seed: Optional[int] = None,
    ) -> Trajectory:
        """Async version of run_episode. Run a single episode and return the trajectory."""
        # Reset the environment and the agent's memory.
        obs, _ = self.env.reset(options=env_option, seed=seed)
        await self.agent.areset()

        data = []
        rewards = []
        tik = datetime.now().timestamp()
        ret = 0.0
        ep_len = 0

        done = False
        # Episode loop.
        while not done:
            # Take an action by sending a request to generation server.
            agent_infer_in = AgentInferInput(
                obs=obs, gconfig=gconfig, llm_client=llm_client
            )
            agent_infer_out = await self.agent.aact(agent_infer_in)
            action = agent_infer_out.action

            # Advance one step in the environment.
            nex_obs, reward, terminated, truncated, _ = self.env.step(action)

            # Collect the step data.
            resp = agent_infer_out.llm_resp
            input_len = len(resp.input_tokens)
            output_len = len(resp.output_tokens)

            input_ids = resp.input_tokens + resp.output_tokens
            prompt_mask = [1] * input_len + [0] * output_len
            logprobs = [0.0] * input_len + resp.output_logprobs
            versions = [-1] * input_len + resp.output_versions

            d = dict(
                input_ids=torch.tensor(input_ids, dtype=torch.long),
                prompt_mask=torch.tensor(prompt_mask, dtype=torch.bool),
                logprobs=torch.tensor(logprobs, dtype=torch.float32),
                versions=torch.tensor(versions, dtype=torch.long),
            )
            data.append(d)
            rewards.append(reward)

            ret += float(reward)
            ep_len += 1

            # Prepare information for the next step.
            done = terminated or truncated
            obs = nex_obs

        return Trajectory(
            prompt=env_option,
            data=dict(rewards=torch.tensor(rewards), **pad_sequences_to_tensors(data)),
            stats=TrajStats(
                start_time=tik,
                total_reward=ret,
                episode_length=ep_len,
                info={},
            ),
        )
@@ -0,0 +1,7 @@
from datasets import Dataset


def process_areal_dataset(dataset: Dataset, tokenizer):
    return dataset.map(
        lambda x: tokenizer(x["prompt"], return_attention_mask=False), batched=True
    )
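A minimal usage sketch for the helper above, assuming a dataset with a "prompt" column and any HuggingFace tokenizer; the model name and sample prompts are placeholders, not part of the commit.

# Hypothetical usage sketch; model name and prompts are illustrative only.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")  # placeholder model
raw = Dataset.from_dict({"prompt": ["1 + 1 = ?", "What is 2 * 3?"]})
tokenized = process_areal_dataset(raw, tokenizer)
print(tokenized[0]["input_ids"])  # token ids of the first prompt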
@@ -0,0 +1,46 @@
from datasets import Dataset


def process_gsm8k_rl_dataset(dataset: Dataset, tokenizer, reward_mode):
    def process_example(example, idx):
        # Add query_id column
        example["query_id"] = str(idx)
        example["prompt"] = example["question"]

        # used by the reward function
        example["method"] = reward_mode
        return example

    dataset = dataset.map(
        lambda example, idx: process_example(example, idx),
        with_indices=True,
    )
    return dataset.map(
        lambda x: tokenizer(x["question"], return_attention_mask=False), batched=True
    )


def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
    def process_example(example, idx):
        # Add query_id column
        example["query_id"] = str(idx)
        example["prompt"] = example["question"]
        example["seq"] = example["prompt"] + example["answer"] + tokenizer.eos_token
        return example

    dataset = dataset.map(
        lambda example, idx: process_example(example, idx),
        with_indices=True,
    )

    def _tokenize(example):
        example["prompt"] = tokenizer(example["prompt"], return_attention_mask=False)[
            "input_ids"
        ]
        example["seq"] = tokenizer(example["seq"], return_attention_mask=False)[
            "input_ids"
        ]
        return example

    dataset = dataset.map(lambda x: _tokenize(x), batched=True)
    return dataset
@@ -0,0 +1,553 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

import asyncio
import functools
import math
import os
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedModel,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from arealite.api.cli_args import EngineConfig, FSDPConfig, MicroBatchSpec, TrainingArgs
from arealite.api.engine_api import SPMDWrapper
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient
from arealite.utils import (
    get_state_dict_from_repo_id_or_path,
    recorder_list,
    split_dict_tensor_with_cu_seqlens,
    unpack_sequence,
)
from realhf.api.cli_args import ParallelismConfig
from realhf.base import constants
from realhf.base.pkg_version import is_version_greater_or_equal

if is_version_greater_or_equal("torch", "2.6.0"):
    from torch.distributed.fsdp import (
        CPUOffloadPolicy,
        FSDPModule,
        MixedPrecisionPolicy,
        fully_shard,
    )
elif is_version_greater_or_equal("torch", "2.4.0"):
    from torch.distributed._composable.fsdp import (
        CPUOffloadPolicy,
        FSDPModule,
        MixedPrecisionPolicy,
        fully_shard,
    )
else:
    fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = (
        None,
        None,
        None,
        None,
    )

from torch.distributed.device_mesh import init_device_mesh


def fsdp2_clip_grad_norm_(
    parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None
):
    """torch.nn.utils.clip_grad_norm_ can't run on CPU parameter DTensors."""
    from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    else:
        # prevent generators from being exhausted
        parameters = list(parameters)
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
    total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)
    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
    return total_norm


def create_fsdp_device_mesh(shard_size, world_size):
    if shard_size < 0 or shard_size >= world_size:
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)
        )
    else:
        device_mesh = init_device_mesh(
            "cuda",
            mesh_shape=(world_size // shard_size, shard_size),
            mesh_dim_names=("ddp", "fsdp"),
        )
    return device_mesh


def apply_fsdp2(model, fsdp_kwargs, wrap_policy):
    """model: AutoModelForCausalLM"""
    assert (
        CPUOffloadPolicy is not None
    ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"

    default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", list())
    fsdp_transformer_layer_cls_to_wrap = (
        wrap_policy.transformer_layer_cls_to_wrap if wrap_policy is not None else list()
    )
    if not fsdp_transformer_layer_cls_to_wrap:
        fsdp_transformer_layer_cls_to_wrap = default_transformer_cls_names_to_wrap

    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]

    assert (
        len(fsdp_transformer_layer_cls_to_wrap) > 0
        and fsdp_transformer_layer_cls_to_wrap[0] is not None
    )

    modules = []
    for name, module in model.named_modules():
        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (
            isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings
        ):
            modules.append(module)

    for idx, module in enumerate(modules):
        fully_shard(module, **fsdp_kwargs)
    fully_shard(
        model, **fsdp_kwargs
    )  # fsdp2 will not reshard_after_forward for root module


def fsdp2_load_full_state_dict(
    model: PreTrainedModel,
    full_state: dict,
    cpu_offload=None,
    tie_word_embeddings=False,
):
    """
    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
    parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Args:
        model (`torch.nn.Module`): The model to load the state dict into
        full_state (`dict`): The full state dict to load, can only be on rank 0
    """
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_model_state_dict,
    )

    device = torch.cuda.current_device()
    model = model.to(device=device, non_blocking=True)
    cpu_offload = cpu_offload is not None
    options = StateDictOptions(
        full_state_dict=True,
        cpu_offload=cpu_offload,
        broadcast_from_rank0=True,
        strict=not tie_word_embeddings,
    )
    set_model_state_dict(model, full_state, options=options)

    if tie_word_embeddings:
        model.tie_weights()

    # rotary_emb is not in state_dict, so we need to broadcast it manually
    for name, buf in model.named_buffers():
        dist.broadcast(buf, src=0)

    if cpu_offload:
        model.to("cpu", non_blocking=True)
        for buf in model.buffers():
            buf.data = buf.data.to(device)


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.
    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
            The minimum lr ratio w.r.t the maximum.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
    coef = (1 - min_lr_ratio) * 0.5
    intercept = (1 + min_lr_ratio) * 0.5

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return min_lr_ratio + (1.0 - min_lr_ratio) * (
                float(current_step) / float(max(1, num_warmup_steps))
            )
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(min_lr_ratio, x * coef + intercept)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)


class FSDPEngine(SPMDWrapper):
    """Simplified FSDP engine for transformer models."""

    def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
        super().__init__(args, engine_config)
        assert is_version_greater_or_equal(
            "torch", "2.4.0"
        ), "arealite only supports FSDP2, which requires torch>=2.4.0"

        self.fsdp_config = engine_config.backend.fsdp
        if self.fsdp_config is None:
            self.fsdp_config = FSDPConfig()
        self.optimizer_config = engine_config.optimizer

        self.model = None
        self.optimizer = None
        self.model_config = None
        self.device_mesh = None
        self.cpu_offload = None

        self.world_size = int(os.environ["WORLD_SIZE"])

    def train(self, mode: bool = True):
        """Set the module in training mode."""
        assert self.model is not None
        self.model.train(mode=mode)
        return self

    def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
        """Initialize distributed communication and model."""
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

        dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
        self.model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.engine_config.path,
            trust_remote_code=True,
        )
        with torch.device("cuda"):
            # initialize scratch model from config
            model = AutoModelForCausalLM.from_config(
                self.model_config,
                torch_dtype=dtype,
                attn_implementation="flash_attention_2",
            )

        # Simple auto wrap policy
        # TODO: fix wrap policy
        mixed_precision_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            cast_forward_inputs=True,
        )
        device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
        self.device_mesh = device_mesh
        # sharding_strategy = ShardingStrategy.FULL_SHARD
        self.cpu_offload = (
            CPUOffloadPolicy() if self.fsdp_config.offload_params else None
        )

        fsdp_kwargs = {
            "mesh": device_mesh,
            "mp_policy": mixed_precision_policy,
            "offload_policy": self.cpu_offload,
            "reshard_after_forward": True,
        }

        # Wrap with FSDP2
        apply_fsdp2(model, fsdp_kwargs, self.fsdp_config.wrap_policy)

        self.model = model

        # Set up optimizer
        if self.optimizer_config is not None:
            assert (
                self.optimizer_config.type == "adam"
            ), "Only AdamW optimizer is supported in this engine."
            lr = self.optimizer_config.lr
            weight_decay = self.optimizer_config.weight_decay
            beta1 = self.optimizer_config.beta1
            beta2 = self.optimizer_config.beta2
            eps = self.optimizer_config.eps

            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=(beta1, beta2),
                eps=eps,
            )
            total_train_steps = ft_spec.total_train_steps
            num_warmup_steps = int(
                self.optimizer_config.warmup_steps_proportion * total_train_steps
            )

            if self.optimizer_config.lr_scheduler_type == "cosine":
                self.lr_scheduler = get_cosine_schedule_with_warmup(
                    self.optimizer,
                    num_warmup_steps,
                    total_train_steps,
                    min_lr_ratio=self.optimizer_config.min_lr_ratio,
                )
            elif self.optimizer_config.lr_scheduler_type == "linear":
                self.lr_scheduler = get_linear_schedule_with_warmup(
                    self.optimizer,
                    num_warmup_steps,
                    total_train_steps,
                )
            elif self.optimizer_config.lr_scheduler_type == "constant":
                self.lr_scheduler = get_constant_schedule_with_warmup(
                    self.optimizer,
                    num_warmup_steps,
                )
            else:
                raise ValueError(
                    f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
                )

    def train_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict:
        """Train on a batch using gradient accumulation."""
        # self._initialize_fsdp_train()
        assert self.optimizer is not None
        assert self.optimizer_config is not None
        assert self.lr_scheduler is not None

        self.optimizer.zero_grad()

        mb_inputs = split_dict_tensor_with_cu_seqlens(input_, mb_spec).mbs

        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
        )
        assert total_loss_weight != 0
        dist.all_reduce(total_loss_weight)

        # Process microbatches with gradient accumulation
        for i, mb_input in enumerate(mb_inputs):
            outputs = self.model(**mb_input)

            loss = loss_fn(outputs.logits, mb_input)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight

            # Scale loss for accumulation
            # Revert gradient averaging across dp ranks
            loss_scale *= self.world_size

            loss *= loss_scale
            loss.backward()

        grad_norm = fsdp2_clip_grad_norm_(
            self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
        )
        # Step the optimizer only when the gradient is finite; otherwise drop the update.
        if not torch.isfinite(grad_norm):
            self.optimizer.zero_grad()
            update_successful = False
        else:
            self.optimizer.step()
            update_successful = True

        current_lr = self.lr_scheduler.get_last_lr()[0]
        return dict(
            update_successful=float(update_successful),
            grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
            lr=current_lr,
        )

    def step_lr_scheduler(self):
        assert self.lr_scheduler is not None
        self.lr_scheduler.step()

    @torch.no_grad()
    def eval_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:
        """Evaluate on a batch."""
        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
        )
        assert total_loss_weight != 0

        total_loss = 0.0
        total_weight = 0.0

        for mb_input in mb_splits.mbs:
            outputs = self.model(**mb_input)
            loss = loss_fn(outputs.logits, mb_input)

            # Simple weight calculation (could be improved)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
            total_loss += loss.item() * loss_scale
            total_weight += loss_scale

        return torch.tensor(total_loss / total_weight)

    @torch.no_grad()
    def forward(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        output_seqlens: List[int] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
    ) -> Any | None:
        """Forward pass with optional post-processing."""
        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
        if output_seqlens is None:
            cu_seqlens = input_["cu_seqlens"]
            output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()

        results = []
        for mb_input in mb_splits.mbs:
            outputs = self.model(**mb_input)
            if post_hook:
                result = post_hook(outputs.logits, mb_input)
                results.append(result)
            else:
                results.append(outputs.logits)

        res = aggregate_fn(results)
        output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
        unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
        return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))

    def get_hf_model_state_dict(self) -> Dict[str, torch.Tensor]:
        """Get model state dict for saving."""
        if self.model is None:
            raise RuntimeError("Model not initialized")

        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
            return self.model.state_dict()

    def save_model_to_hf(
        self,
        path: str,
        tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
        base_model_path: Optional[str] = None,
    ):
        """Save model in HuggingFace format."""
        if self.model is None:
            raise RuntimeError("Model not initialized")

        os.makedirs(path, exist_ok=True)

        # FSDP2 checkpoint saving
        from torch.distributed.checkpoint.state_dict import (
            StateDictOptions,
            get_model_state_dict,
        )

        # Get full state dict with FSDP2
        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        state_dict = get_model_state_dict(self.model, options=options)

        # save huggingface model
        if dist.get_rank() == 0:
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path, state_dict=state_dict)
            self.model_config.save_pretrained(path)
            if tokenizer is not None:
                tokenizer.save_pretrained(path)

        dist.barrier()

    def load_model_from_hf(self, path: str):
        """Load model from HuggingFace format."""
        if dist.get_rank() == 0:
            full_state = get_state_dict_from_repo_id_or_path(path)
        else:
            full_state = {}

        fsdp2_load_full_state_dict(
            self.model,
            full_state,
            self.cpu_offload,
            tie_word_embeddings=self.model_config.tie_word_embeddings,
        )

    def save_optimizer_state(self, path: str):
        """Save optimizer state."""
        if self.optimizer is None:
            raise RuntimeError("Optimizer not initialized")

        os.makedirs(path, exist_ok=True)
        torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))

    def load_optimizer_state(self, path: str):
        """Load optimizer state."""
        if self.optimizer is None:
            raise RuntimeError("Optimizer not initialized")

        optimizer_path = os.path.join(path, "optimizer.pt")
        if os.path.exists(optimizer_path):
            self.optimizer.load_state_dict(
                torch.load(optimizer_path, map_location="cpu")
            )
        else:
            raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")

    async def aupdate_weights_to(self, llm_client: LLMClient):
        """Async method to update weights to all healthy servers."""
        path = constants.get_param_realloc_path(self.args)
        self.save_model_to_hf(path)
        tasks = [
            llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
            for server_info in llm_client.get_healthy_servers()
        ]
        await asyncio.gather(*tasks)

    def update_weights_to(self, llm_client: LLMClient):
        """Update the weights to the server by sending requests to the client."""
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.aupdate_weights_to(llm_client))
        finally:
            loop.close()
@@ -0,0 +1,315 @@
import asyncio
import functools
import math
import os
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

from arealite.api.cli_args import (
    EngineConfig,
    MicroBatchSpec,
    ParallelismConfig,
    TrainingArgs,
)
from arealite.api.engine_api import SPMDWrapper
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient
from arealite.utils import (
    get_state_dict_from_repo_id_or_path,
    recorder_list,
    split_dict_tensor_with_cu_seqlens,
    unpack_sequence,
)
from realhf.base import constants


def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.
    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
            The minimum lr ratio w.r.t the maximum.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
    coef = (1 - min_lr_ratio) * 0.5
    intercept = (1 + min_lr_ratio) * 0.5

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, x * coef + intercept)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)


class HFEngine(SPMDWrapper):
    """Simplified HF engine for transformer models."""

    def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
        super().__init__(args, engine_config)

        self.model = None
        self.optimizer = None
        self.model_config = None

        self.weight_update_group_initialized = False

    def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
        """Initialize model in single node."""
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        if dist.get_world_size() > 1:
            raise RuntimeError(
                "Distributed training is not supported in this engine. "
                "Please use FSDP for distributed training."
            )
        torch.cuda.set_device("cuda:0")

        dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
        self.model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.engine_config.path,
            trust_remote_code=True,
        )
        with torch.device("cuda"):
            # initialize scratch model from config
            model = AutoModelForCausalLM.from_config(
                self.model_config,
                torch_dtype=dtype,
                attn_implementation="flash_attention_2",
            )

        model = model.cuda()

        self.model = model

        # Set up optimizer
        optimizer_config = self.engine_config.optimizer
        if optimizer_config is not None:
            assert (
                optimizer_config.type == "adam"
            ), "Only AdamW optimizer is supported in this engine."
            lr = optimizer_config.lr
            weight_decay = optimizer_config.weight_decay
            beta1 = optimizer_config.beta1
            beta2 = optimizer_config.beta2
            eps = optimizer_config.eps

            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
                betas=(beta1, beta2),
                eps=eps,
            )
            total_train_steps = ft_spec.total_train_steps
            num_warmup_steps = int(
                optimizer_config.warmup_steps_proportion * total_train_steps
            )

            self.lr_scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps,
                total_train_steps,
                min_lr_ratio=optimizer_config.min_lr_ratio,
            )

    def train(self, mode: bool = True):
        """Set the module in training mode."""
        return self.model.train(mode)

    def train_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict:
        """Train on a batch using gradient accumulation."""
        assert self.optimizer is not None
        assert self.lr_scheduler is not None

        self.optimizer.zero_grad()
        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
        )
        assert total_loss_weight != 0

        for mb_input in mb_splits.mbs:
            outputs = self.model(**mb_input)
            loss = loss_fn(outputs.logits, mb_input)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
            loss *= loss_scale
            loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.engine_config.optimizer.gradient_clipping,
            norm_type=2.0,
            error_if_nonfinite=False,
            foreach=None,
        )
        current_lr = self.lr_scheduler.get_last_lr()[0]
        # Optimizer step
        self.optimizer.step()

        return {
            "grad_norm": grad_norm,
            "lr": current_lr,
        }

    @torch.no_grad()
    def eval_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:
        """Evaluate on a batch."""
        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
        )
        assert total_loss_weight != 0

        total_loss = 0.0
        total_weight = 0.0

        for mb_input in mb_splits.mbs:
            outputs = self.model(**mb_input)
            loss = loss_fn(outputs.logits, mb_input)

            # Simple weight calculation (could be improved)
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
            total_loss += loss.item() * loss_scale
            total_weight += loss_scale

        return torch.tensor(total_loss / total_weight)

    @torch.no_grad()
    def forward(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        output_seqlens: List[int] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
    ) -> Any | None:
        """Forward pass with optional post-processing."""
        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
        if output_seqlens is None:
            cu_seqlens = input_["cu_seqlens"]
            output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()

        results = []
        for mb_input in mb_splits.mbs:
            outputs = self.model(**mb_input)
            if post_hook:
                result = post_hook(outputs.logits, mb_input)
                results.append(result)
            else:
                results.append(outputs.logits)

        res = aggregate_fn(results)
        output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
        unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
        return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))

    def step_lr_scheduler(self):
        """Step the learning rate scheduler."""
        return self.lr_scheduler.step()

    def save_model_to_hf(
        self,
        path: str,
        tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
        base_model_path: Optional[str] = None,
    ):
        """Save model in HuggingFace format."""
        if self.model is None:
            raise RuntimeError("Model not initialized")

        os.makedirs(path, exist_ok=True)

        state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
        self.model.save_pretrained(path, state_dict=state_dict)
        self.model_config.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)

    def load_model_from_hf(self, path: str):
        """Load model from HuggingFace format."""
        full_state = get_state_dict_from_repo_id_or_path(path)
        self.model.load_state_dict(
            full_state, strict=not self.model_config.tie_word_embeddings
        )
        if self.model_config.tie_word_embeddings:
            self.model.tie_weights()

    def save_optimizer_state(self, path: str):
        """Save optimizer state."""
        if self.optimizer is None:
            raise RuntimeError("Optimizer not initialized")

        os.makedirs(path, exist_ok=True)
        torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))

    def load_optimizer_state(self, path: str):
        """Load optimizer state."""
        if self.optimizer is None:
            raise RuntimeError("Optimizer not initialized")

        optimizer_path = os.path.join(path, "optimizer.pt")
        if os.path.exists(optimizer_path):
            self.optimizer.load_state_dict(
                torch.load(optimizer_path, map_location="cpu")
            )
        else:
            raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")

    async def aupdate_weights_to(self, llm_client: LLMClient):
        path = constants.get_param_realloc_path(self.args)
        self.save_model_to_hf(path)
        tasks = [
            llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
            for server_info in llm_client.get_healthy_servers()
        ]
        await asyncio.gather(*tasks)

    def update_weights_to(self, llm_client: LLMClient):
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.aupdate_weights_to(llm_client))
        finally:
            loop.close()
@@ -0,0 +1,47 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

import re
from functools import lru_cache
from typing import List

from functioncall.code.local_verify import code_verify
from realhf.impl.dataset.math_code_dataset import load_metadata


@lru_cache(maxsize=1)
def _load_metadata(dataset_path: str):
    """Cached version of load_metadata to avoid reloading metadata each time."""
    return load_metadata(dataset_path)


def extract_code(text, min_length=20):
    """Extract code blocks from text."""
    code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
    code_blocks = re.findall(code_pattern, text, re.DOTALL)
    valid_blocks = []
    for block in code_blocks:
        clean_block = block.strip()
        if len(clean_block) < min_length:
            continue
        valid_blocks.append(clean_block)

    if not valid_blocks:
        return None
    # return the last code block
    return valid_blocks[-1]


def code_reward(
    query_id: str,
    prompt: str,
    completion: str,
    prompt_ids: List[int],
    completion_ids: List[int],
    dataset_path: str,
    **kwargs,
) -> float:
    id2info, _ = _load_metadata(dataset_path)
    return code_verify(
        id2info=id2info, generateds=[extract_code(completion)], query_ids=[query_id]
    )[0]
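A quick illustration of extract_code (not part of the commit): blocks shorter than min_length characters are skipped and the last surviving block is returned. The sample completion below is made up.

# Illustrative only: exercises extract_code on a synthetic completion string.
sample = (
    "Here is my solution:\n"
    "```python\nprint('hi')\n```\n"
    "```python\ndef solve(n):\n    return sum(range(n + 1))\n```\n"
)
print(extract_code(sample))
# -> "def solve(n):\n    return sum(range(n + 1))"; the short first block is filtered out.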
@@ -0,0 +1,27 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

from functools import lru_cache
from typing import List

from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_line


@lru_cache(maxsize=1)
def _load_metadata(dataset_path: str):
    """Cached version of load_metadata to avoid reloading metadata each time."""
    return load_metadata(dataset_path)


def math_reward(
    query_id: str,
    prompt: str,
    completion: str,
    prompt_ids: List[int],
    completion_ids: List[int],
    dataset_path: str,
    **kwargs,
) -> float:
    id2info, _ = _load_metadata(dataset_path)
    return parse_line(id2info=id2info, generated=completion, query_id=query_id)
@@ -0,0 +1,83 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

# Modified from verl.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List


def extract_solution(solution_str, method="strict"):
    assert method in ["strict", "flexible"]

    if method == "strict":
        # this also tests the formatting of the model
        solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
        if len(solutions) == 0:
            final_answer = None
        else:
            # take the last solution
            final_answer = solutions[-1].replace(",", "").replace("$", "")
    elif method == "flexible":
        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
        final_answer = None
        if len(answer) == 0:
            # no reward if there is no answer
            pass
        else:
            invalid_str = ["", "."]
            # find the last number that is not '.'
            for final_answer in reversed(answer):
                if final_answer not in invalid_str:
                    break
    return final_answer


def compute_score(
    solution_str, ground_truth, method="strict", format_score=0.0, score=1.0
):
    """The scoring function for GSM8k.

    Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.

    Args:
        solution_str: the solution text
        ground_truth: the ground truth
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_score: the score for the format
        score: the score for the correct answer
    """
    answer = extract_solution(solution_str=solution_str, method=method)
    if answer is None:
        return 0
    else:
        if answer == ground_truth:
            return score
        else:
            return format_score


def gsm8k_reward_fn(
    query_id: str,
    prompt: str,
    completion: str,
    prompt_ids: List[int],
    completion_ids: List[int],
    answer: str,
    method: str,
    **kwargs,
) -> float:
    return compute_score(completion, extract_solution(answer), method=method)
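A small worked example of the strict scoring path (illustrative, not from the commit): the ground truth is the value after "####" in the GSM8K answer, and a completion earns score only when its own extracted "####" value matches; the strings below are invented.

# Illustrative only.
gt = extract_solution("Natalia sold 48 + 24 = 72 clips.\n#### 72")  # -> "72"
print(compute_score("... so the total is #### 72", gt))   # -> 1.0 (score)
print(compute_score("... the answer is #### 71", gt))     # -> 0.0 (format_score)
print(compute_score("no final answer marker", gt))        # -> 0 (nothing extracted)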
@@ -0,0 +1,91 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0

from datetime import datetime
from typing import Any, Callable, Dict, Optional

import torch

from arealite.api.cli_args import (
    GenerationHyperparameters,
    RolloutCollectorConfig,
    TrainingArgs,
)
from arealite.api.io_struct import LLMRequest, Trajectory, TrajStats
from arealite.api.llm_client_api import LLMClient
from arealite.api.rollout_api import RolloutCollector
from realhf.base import logging

logger = logging.getLogger(__file__)


class RlvrCollector(RolloutCollector):
    def __init__(
        self,
        args: TrainingArgs,
        config: RolloutCollectorConfig,
        reward_fn: Callable,
    ):
        super().__init__(args, config, None, None)
        self.reward_fn = reward_fn

    async def arun_episode(
        self,
        llm_client: LLMClient,
        gconfig: GenerationHyperparameters,
        env_option: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = None,
    ) -> Trajectory:
        """Async version of run_episode. Run a single episode and return the trajectory."""
        tik = datetime.now().timestamp()

        prompt_ids = env_option["input_ids"]
        query_id = env_option["query_id"]
        req = LLMRequest(input_ids=prompt_ids, gconfig=gconfig)

        # Use async LLM client
        resp = await llm_client.agenerate(req)

        # Compute the reward for the completion; remaining env options are
        # forwarded to the reward function as keyword arguments.
        reward_kwargs = env_option.copy()
        reward_kwargs.pop("query_id")
        reward_kwargs.pop("prompt")
        reward = self.reward_fn(
            query_id=query_id,
            prompt=req.text,
            completion=resp.completion,
            prompt_ids=prompt_ids,
            completion_ids=resp.output_tokens,
            **reward_kwargs,
        )

        input_len = len(resp.input_tokens)
        output_len = len(resp.output_tokens)

        input_ids = resp.input_tokens + resp.output_tokens
        prompt_mask = [1] * input_len + [0] * output_len
        logprobs = [0.0] * input_len + resp.output_logprobs
        versions = [-1] * input_len + resp.output_versions

        # logger.info(
        #     f"Prompt: {req.text}, reward: {reward}\nCompletion: {resp.completion}"
        # )

        return Trajectory(
            prompt=env_option,
            data=dict(
                # unsqueeze to add an additional batch dimension
                input_ids=torch.tensor(input_ids).unsqueeze(0),
                prompt_mask=torch.tensor(prompt_mask).unsqueeze(0),
                logprobs=torch.tensor(logprobs).unsqueeze(0),
                versions=torch.tensor(versions).unsqueeze(0),
                # reward
                rewards=torch.tensor([reward]),
            ),
            stats=TrajStats(
                start_time=tik,
                total_reward=reward,
                episode_length=1,
                info={},
            ),
        )
@ -0,0 +1,530 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import functools
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from datasets import Dataset
|
||||
|
||||
from arealite import ppo_functional
|
||||
from arealite.api.cli_args import (
|
||||
GRPOTrainerConfig,
|
||||
MicroBatchSpec,
|
||||
TrainerConfig,
|
||||
TrainingArgs,
|
||||
)
|
||||
from arealite.api.engine_api import EngineFactory
|
||||
from arealite.api.io_struct import FinetuneSpec, Trajectory
|
||||
from arealite.api.llm_client_api import LLMClientFactory
|
||||
from arealite.api.trainer_api import Trainer
|
||||
from arealite.system.rollout_controller import RolloutController
|
||||
from arealite.utils import (
|
||||
calc_entropy,
|
||||
close_wandb_tensorboard,
|
||||
compute_varlen_position_indices,
|
||||
concat_padded_tensors,
|
||||
gather_logprobs,
|
||||
init_stats_logging,
|
||||
log_wandb_tensorboard,
|
||||
masked_normalization,
|
||||
record_timing,
|
||||
split_dict_tensor_with_cu_seqlens,
|
||||
to_device,
|
||||
unpad_input,
|
||||
)
|
||||
from realhf.api.core.data_api import load_hf_tokenizer, tabulate_stats
|
||||
from realhf.base import constants, logging, name_resolve, names, stats_tracker, timeutil
|
||||
|
||||
logger = logging.getLogger("GRPO Trainer", "system")
|
||||
|
||||
|
||||
class SpmdGRPOTrainer(Trainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: TrainingArgs,
|
||||
trainer_config: TrainerConfig,
|
||||
train_dataset: Dataset,
|
||||
valid_dataset: Optional[Dataset] = None,
|
||||
rollout_controller: Optional[RolloutController] = None,
|
||||
):
|
||||
super().__init__(
|
||||
args,
|
||||
trainer_config,
|
||||
train_dataset,
|
||||
valid_dataset,
|
||||
rollout_controller,
|
||||
)
|
||||
if self.rollout_controller is None:
|
||||
raise ValueError("GRPO Trainer requires a rollout controller.")
|
||||
|
||||
assert trainer_config.grpo is not None
|
||||
self.config: GRPOTrainerConfig = trainer_config.grpo
|
||||
assert args.rollout is not None
|
||||
assert self.config.actor is not None
|
||||
|
||||
# Create actor model
|
||||
engine_factory = EngineFactory(args)
|
||||
self.actor = engine_factory.make_engine(self.config.actor)
|
||||
|
||||
self.actor_tokenizer = load_hf_tokenizer(self.config.actor.path)
|
||||
self.gconfig = args.rollout.gconfig
|
||||
|
||||
# Create reference model is specified
|
||||
self.ref = None
|
||||
if self.config.ref is not None:
|
||||
self.ref = engine_factory.make_engine(self.config.ref)
|
||||
|
||||
# Create a client to generate responses and update weights
|
||||
client_factory = LLMClientFactory(args)
|
||||
self.llm_client = client_factory.make_client(args.rollout.llm_client)
|
||||
|
||||
# Algorithm related attributes
|
||||
self.kl_ctl = self.config.kl_ctl
|
||||
self.discount = self.config.discount
|
||||
self.gae_lambda = self.config.gae_lambda
|
||||
self.adv_norm = self.config.adv_norm
|
||||
self.max_reward_clip = self.config.max_reward_clip
|
||||
self.group_adv_norm = self.config.group_adv_norm
|
||||
self.group_size = args.rollout.gconfig.n_samples
|
||||
self.max_head_offpolicyness = args.rollout.max_head_offpolicyness
|
||||
self.reward_bias = self.config.reward_bias
|
||||
self.reward_scaling = self.config.reward_scaling
|
||||
self.max_reward_clip = self.config.max_reward_clip
|
||||
|
||||
self.save_ctl = timeutil.EpochStepTimeFreqCtl(
|
||||
freq_epoch=self.args.exp_ctrl.save_freq_epochs,
|
||||
freq_step=self.args.exp_ctrl.save_freq_steps,
|
||||
freq_sec=self.args.exp_ctrl.save_freq_secs,
|
||||
)
|
||||
self.eval_ctl = timeutil.EpochStepTimeFreqCtl(
|
||||
freq_epoch=self.args.exp_ctrl.eval_freq_epochs,
|
||||
freq_step=self.args.exp_ctrl.eval_freq_steps,
|
||||
freq_sec=self.args.exp_ctrl.eval_freq_steps,
|
||||
)
|
||||
self.summary_writer = init_stats_logging(args)
|
||||
|
||||
def train(self, resume_from_checkpoint=None):
|
||||
# TODO: handle recover
|
||||
self.create_train_dataloader()
|
||||
assert self.rollout_controller is not None
|
||||
assert self.train_dataloader is not None
|
||||
|
||||
total_epochs = self.args.exp_ctrl.total_train_epochs
|
||||
steps_per_epoch = len(self.train_dataloader)
|
||||
ft_spec = FinetuneSpec(
|
||||
total_train_epochs=total_epochs,
|
||||
dataset_size=len(self.train_dataset),
|
||||
train_batch_size=self.args.train_dataset.batch_size,
|
||||
)
|
||||
|
||||
# Setting up models.
|
||||
self.actor.init_distributed(None, ft_spec)
|
||||
self.actor.load_model_from_hf(self.config.actor.path)
|
||||
self.actor.eval()
|
||||
if self.ref is not None:
|
||||
self.ref.init_distributed(None, ft_spec)
|
||||
self.ref.load_model_from_hf(self.config.ref.path)
|
||||
self.ref.eval()
|
||||
self.llm_client.wait_until_servers_ready()
|
||||
self.actor.update_weights_to(self.llm_client)
|
||||
|
||||
# Start rollout for asynchronous RL.
|
||||
if self.config.async_training:
|
||||
self.rollout_controller.start_generate_loop()
|
||||
|
||||
# Main RL training loop.
|
||||
total_epochs = self.args.exp_ctrl.total_train_epochs
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
global_step = 0
|
||||
warmup_steps = self.max_head_offpolicyness + 1
|
||||
assert steps_per_epoch >= warmup_steps
|
||||
start_time = time.monotonic()
|
||||
for epoch in range(total_epochs):
|
||||
for step, data in enumerate(self.train_dataloader):
|
||||
timing_stats = {}
|
||||
with record_timing("timeperf/rollout", timing_stats):
|
||||
if self.config.async_training:
|
||||
self.rollout_controller.submit(data)
|
||||
# Submitted data will not actually be sent for rollout.
|
||||
# The rollout controller over-subscribe the data to
|
||||
# ensure that there are enough data being generated.
|
||||
if epoch == 0 and step < warmup_steps:
|
||||
continue
|
||||
# Wait until enough trajectories has been collected.
|
||||
trajs = self.rollout_controller.prepare_batch(
|
||||
batch_size=self.args.train_dataset.batch_size // world_size
|
||||
)
|
||||
else:
|
||||
# Run batched rollout by submitting requests to LLM servers
|
||||
trajs = self.rollout_controller.generate_batch(
|
||||
batch_size=len(data),
|
||||
env_options=data,
|
||||
)
|
||||
|
||||
with record_timing("timeperf/train_step", timing_stats):
|
||||
# Run RL training and update weights.
|
||||
mb_stats = self._train_step(trajs)
|
||||
self.actor.step_lr_scheduler()
|
||||
|
||||
with record_timing("timeperf/sync_weights", timing_stats):
|
||||
# Synchronize weights to the client.
|
||||
self.actor.update_weights_to(self.llm_client)
|
||||
# Update model version
|
||||
name = names.model_version(
|
||||
self.args.experiment_name, self.args.trial_name, "actor"
|
||||
)
|
||||
name_resolve.add(name, str(global_step + 1), replace=True)
|
||||
|
||||
if self.save_ctl.check(
|
||||
epochs=int(step == steps_per_epoch - 1), steps=1
|
||||
):
|
||||
if dist.get_rank() == 0:
|
||||
logger.info("Saving model ...")
|
||||
with record_timing("timeperf/save", timing_stats):
|
||||
save_path = os.path.join(
|
||||
constants.get_save_path(self.args), "actor"
|
||||
)
|
||||
self.actor.save_model_to_hf(
|
||||
save_path,
|
||||
tokenizer=self.actor_tokenizer,
|
||||
base_model_path=self.config.actor.path,
|
||||
)
|
||||
|
||||
assert len(mb_stats) == self.config.ppo_n_minibatches
|
||||
log_step = self.config.ppo_n_minibatches * global_step
|
||||
for i, stats in enumerate(mb_stats):
|
||||
log_wandb_tensorboard(log_step + i, stats, self.summary_writer)
|
||||
log_wandb_tensorboard(log_step, timing_stats, self.summary_writer)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Epoch {epoch+1}/{total_epochs} "
|
||||
f"Step {step+1}/{steps_per_epoch} "
|
||||
f"Train step {global_step + 1}/{total_epochs * steps_per_epoch - warmup_steps} done."
|
||||
)
|
||||
logger.info(
|
||||
f"Detailed time stats: \n{tabulate_stats(timing_stats, floatfmt='.2f')}"
|
||||
)
|
||||
for i, stats in enumerate(mb_stats):
|
||||
logger.info(
|
||||
f"GRPO training stats ({i + 1}/{len(mb_stats)}):\n{tabulate_stats(stats)}"
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}."
|
||||
)
|
||||
if self.config.async_training:
|
||||
self.rollout_controller.stop_generate_loop()
|
||||
|
||||
close_wandb_tensorboard(self.summary_writer)
|
||||
|
||||
def _train_step(self, trajs: List[Trajectory]):
|
||||
rollout = concat_padded_tensors([traj.data for traj in trajs])
|
||||
rollout = to_device(rollout, torch.cuda.current_device())
|
||||
|
||||
# Mark which sequences do not have an EOS token, i.e.,
|
||||
# whose generation was truncated at the configured maximum generation length.
|
||||
batch_tokens = rollout["input_ids"]
|
||||
seq_no_eos_mask = (
|
||||
batch_tokens[:, -1] != self.actor_tokenizer.eos_token_id
|
||||
).logical_and(batch_tokens[:, -1] != self.actor_tokenizer.pad_token_id)
|
||||
|
||||
# Remove padding to use flash-attn
|
||||
attn_mask = rollout["attention_mask"]
|
||||
input_ids, _, cu_seqlens, max_seqlen = unpad_input(
|
||||
rollout["input_ids"], attn_mask
|
||||
)
|
||||
position_ids = compute_varlen_position_indices(input_ids.shape[0], cu_seqlens)
|
||||
|
||||
# Transformer forward input data
|
||||
model_inputs = dict(
|
||||
input_ids=input_ids.unsqueeze(0),
|
||||
attention_mask=None,
|
||||
position_ids=position_ids.unsqueeze(0),
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
use_cache=False,
|
||||
)
|
||||
old_logp, *_ = unpad_input(rollout["logprobs"], attn_mask)
|
||||
prompt_mask, *_ = unpad_input(rollout["prompt_mask"], attn_mask)
|
||||
# Shift logprobs and mask for computing loss.
|
||||
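# The logit at position t predicts token t+1, so the mask and old logprobs are rolled
# left by one to align with the shifted labels.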
loss_mask = prompt_mask.logical_not()
|
||||
loss_mask = torch.roll(loss_mask, shifts=-1)
|
||||
old_logp = torch.roll(old_logp, shifts=-1)
|
||||
|
||||
input_ids = model_inputs["input_ids"].squeeze(0)
|
||||
n_seqs = seq_no_eos_mask.shape[0]
|
||||
assert n_seqs == self.local_train_batch_size * self.group_size, (
|
||||
n_seqs,
|
||||
self.group_size,
|
||||
self.local_train_batch_size,
|
||||
)
|
||||
|
||||
# Run reference model forward
|
||||
def calc_logprobs(logits, input_data):
|
||||
logits = logits.squeeze(0).float()
|
||||
labels = torch.roll(input_data["input_ids"].squeeze(0), shifts=-1)
|
||||
logits /= self.gconfig.temperature
|
||||
logprobs = gather_logprobs(logits, labels)
|
||||
return logprobs.unsqueeze(0)
|
||||
|
||||
if self.ref is not None and self.config.kl_ctl != 0.0:
|
||||
ref_logp = self.ref.forward(
|
||||
model_inputs,
|
||||
mb_spec=self.config.mb_spec,
|
||||
post_hook=calc_logprobs,
|
||||
).squeeze(0)
|
||||
else:
|
||||
ref_logp = torch.zeros_like(input_ids, dtype=torch.float32)
|
||||
|
||||
# Recompute logprobs using the current actor model.
|
||||
prox_logp = None
|
||||
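# With use_decoupled_loss, the recomputed logp acts as the proximal policy term while the
# behavior logp from the inference engine is kept for importance weighting; otherwise the
# recomputed logp simply replaces the (possibly stale) logp returned by the inference engine.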
if self.config.recompute_logprob:
|
||||
_logp = self.actor.forward(
|
||||
model_inputs,
|
||||
mb_spec=self.config.mb_spec,
|
||||
post_hook=calc_logprobs,
|
||||
).squeeze(0)
|
||||
if self.config.use_decoupled_loss:
|
||||
prox_logp = _logp
|
||||
else:
|
||||
# Overwrite the logp returned by the inference engine
|
||||
old_logp = _logp
|
||||
|
||||
# Rewards come from the reward function in the synchronous RLVR pipeline; apply bias, scaling, and clipping.
|
||||
reward_score = rollout["rewards"]
|
||||
reward_score = (reward_score + self.reward_bias) * self.reward_scaling
|
||||
reward_score = torch.clip(reward_score, max=self.max_reward_clip)
|
||||
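# GRPO-style group normalization: the rewards of the `group_size` samples sharing a prompt
# are standardized within their group.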
if self.config.group_reward_norm:
|
||||
for i in range(n_seqs // self.group_size):
|
||||
s = slice(i * self.group_size, (i + 1) * self.group_size)
|
||||
r = reward_score[s]
|
||||
reward_score[s] = (r - r.mean()) / (r.std() + 1e-9)
|
||||
|
||||
# Apply the mask to log probabilities.
|
||||
ref_logp *= loss_mask
|
||||
old_logp *= loss_mask
|
||||
|
||||
# Compute KL-regularized rewards and GAEs.
|
||||
cu_seqlens = model_inputs["cu_seqlens"]
|
||||
kl_rewards, rewards = ppo_functional.get_packed_rewards(
|
||||
kl_ctl=self.kl_ctl,
|
||||
clip_reward_value=self.max_reward_clip,
|
||||
log_probs=old_logp,
|
||||
ref_log_probs=ref_logp,
|
||||
reward_score=reward_score,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_no_eos_mask=seq_no_eos_mask,
|
||||
mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
|
||||
)
|
||||
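# No critic is used: values are all zeros, so GAE reduces to the discounted reward-to-go.
# The extra `n_seqs` zeros act as per-sequence bootstrap values for the packed GAE helper.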
advantages, _ = ppo_functional.get_packed_advantages_and_returns(
|
||||
gamma=self.discount,
|
||||
lam=self.gae_lambda,
|
||||
values=torch.zeros(
|
||||
input_ids.shape[0] + n_seqs,
|
||||
device=input_ids.device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
rewards=rewards,
|
||||
short1cu_seqlens=cu_seqlens,
|
||||
seq_no_eos_mask=seq_no_eos_mask,
|
||||
)
|
||||
|
||||
# Optionally perform advantage normalization.
|
||||
if self.adv_norm:
|
||||
if self.group_adv_norm:
|
||||
n_samples = len(cu_seqlens) - 1
|
||||
assert n_samples % self.group_size == 0
|
||||
adv_list = []
|
||||
for i in range(0, n_samples, self.group_size):
|
||||
adv_list.append(
|
||||
masked_normalization(
|
||||
advantages[cu_seqlens[i] : cu_seqlens[i + self.group_size]],
|
||||
loss_mask[cu_seqlens[i] : cu_seqlens[i + self.group_size]],
|
||||
all_reduce=False,
|
||||
)
|
||||
)
|
||||
advantages = torch.cat(adv_list, 0)
|
||||
else:
|
||||
advantages = masked_normalization(advantages, loss_mask)
|
||||
|
||||
# Prepare data to be split into mini-batches.
|
||||
global_batch = dict(
|
||||
**model_inputs,
|
||||
old_logp=old_logp,
|
||||
advantages=advantages,
|
||||
loss_mask=loss_mask,
|
||||
prox_logp=prox_logp,
|
||||
)
|
||||
input_lens = model_inputs["cu_seqlens"][1:] - model_inputs["cu_seqlens"][:-1]
|
||||
|
||||
all_stats = []
|
||||
with stats_tracker.scope("actor"):
|
||||
########## Logging code starts ##########
|
||||
result_denominators = {
|
||||
"correct_n_seqs": (reward_score > 0).bool(),
|
||||
"incorrect_n_seqs": (reward_score <= 0).bool(),
|
||||
}
|
||||
global_denominators = dict(
|
||||
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
|
||||
n_tokens=torch.ones_like(loss_mask, dtype=torch.bool),
|
||||
n_valid_tokens=loss_mask.bool(),
|
||||
**result_denominators,
|
||||
)
|
||||
stats_tracker.denominator(**global_denominators)
|
||||
stats_tracker.stat(
|
||||
correct_seq_len=input_lens.float(), denominator="correct_n_seqs"
|
||||
)
|
||||
stats_tracker.stat(
|
||||
incorrect_seq_len=input_lens.float(), denominator="incorrect_n_seqs"
|
||||
)
|
||||
|
||||
stats = dict(
|
||||
advantages=advantages,
|
||||
kl_rewards=kl_rewards,
|
||||
final_reward=rewards,
|
||||
)
|
||||
stats_tracker.stat(**stats, denominator="n_valid_tokens")
|
||||
|
||||
prompt_lens = []
|
||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
||||
prompt_lens.append(prompt_mask[s:e].sum())
|
||||
prompt_lens = torch.tensor(prompt_lens, device=reward_score.device)
|
||||
seq_stats = dict(
|
||||
no_eos_ratios=seq_no_eos_mask.float(),
|
||||
task_reward=reward_score,
|
||||
prompt_len=prompt_lens.float(),
|
||||
seq_len=input_lens.float(),
|
||||
)
|
||||
stats_tracker.stat(**seq_stats, denominator="n_seqs")
|
||||
scalars = dict(
|
||||
mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
|
||||
eps_clip=self.config.eps_clip,
|
||||
use_prox_logp=prox_logp is not None,
|
||||
)
|
||||
if self.config.c_clip is not None:
|
||||
scalars["c_clip"] = self.config.c_clip
|
||||
scalars["use_dual_clip"] = 1
|
||||
else:
|
||||
scalars["use_dual_clip"] = 0
|
||||
if self.config.behav_imp_weight_cap is not None:
|
||||
scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap
|
||||
stats_tracker.scalar(**scalars)
|
||||
|
||||
global_stats = stats_tracker.export()
|
||||
for k in global_denominators:
|
||||
global_stats.pop(f"actor/{k}")
|
||||
########## Logging code ends ##########
|
||||
|
||||
mb_inputs = split_dict_tensor_with_cu_seqlens(
|
||||
global_batch,
|
||||
mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
|
||||
)
|
||||
for mb in mb_inputs.mbs:
|
||||
model_inputs = {k: mb[k] for k in model_inputs}
|
||||
train_stat = self.actor.train_batch(
|
||||
mb,
|
||||
loss_fn=functools.partial(
|
||||
grpo_loss_fn,
|
||||
temperature=self.gconfig.temperature,
|
||||
eps_clip=self.config.eps_clip,
|
||||
c_clip=self.config.c_clip,
|
||||
behav_imp_weight_cap=self.config.behav_imp_weight_cap,
|
||||
),
|
||||
mb_spec=self.config.mb_spec,
|
||||
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
|
||||
)
|
||||
stats_tracker.scalar(**train_stat)
|
||||
all_stats.append(stats_tracker.export())
|
||||
all_stats[0].update(global_stats)
|
||||
return all_stats
|
||||
|
||||
|
||||
def grpo_loss_fn(
|
||||
logits: torch.Tensor,
|
||||
input_data: Dict,
|
||||
temperature: float,
|
||||
eps_clip: float,
|
||||
c_clip: float | None,
|
||||
behav_imp_weight_cap: float | None,
|
||||
):
|
||||
"""Loss function for actor step, all inputs should be splitted into
|
||||
pipeline micro batches, returns loss and logging stats."""
|
||||
input_ids = input_data["input_ids"].squeeze(0)
|
||||
cu_seqlens = input_data["cu_seqlens"]
|
||||
old_logp = input_data["old_logp"]
|
||||
advantages = input_data["advantages"]
|
||||
loss_mask = input_data["loss_mask"]
|
||||
prox_logp = input_data["prox_logp"]
|
||||
|
||||
logits = logits.squeeze(0).float()
|
||||
logits /= temperature
|
||||
logprobs = gather_logprobs(logits, torch.roll(input_ids, shifts=-1))
|
||||
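# Standard clipped PPO surrogate: -min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A),
# with optional dual clipping (c_clip) for negative advantages and a cap on the behavior importance weight.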
loss, stat = ppo_functional.actor_loss_fn(
|
||||
logprobs=logprobs,
|
||||
old_logprobs=old_logp,
|
||||
advantages=advantages,
|
||||
eps_clip=eps_clip,
|
||||
loss_mask=loss_mask,
|
||||
c_clip=c_clip,
|
||||
proximal_logprobs=prox_logp,
|
||||
behav_imp_weight_cap=behav_imp_weight_cap,
|
||||
)
|
||||
|
||||
entropy = calc_entropy(logits=logits, cu_seqlens=cu_seqlens)
|
||||
|
||||
# Log training statistics
|
||||
stats_tracker.denominator(
|
||||
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
|
||||
n_valid_tokens=loss_mask.bool(),
|
||||
clipped_tokens=stat["clip_mask"],
|
||||
dual_clipped_tokens=stat["dual_clip_mask"],
|
||||
)
|
||||
|
||||
stats_tracker.stat(
|
||||
importance_weight=stat["importance_weight"],
|
||||
approx_kl=stat["approx_kl"],
|
||||
new_logp=logprobs.detach(),
|
||||
old_logp=old_logp,
|
||||
entropy=entropy.float(),
|
||||
actor_loss=stat["loss"],
|
||||
clip_ratio=stat["clip_mask"].float(),
|
||||
dual_clip_ratio=stat["dual_clip_mask"].float(),
|
||||
denominator="n_valid_tokens",
|
||||
)
|
||||
if "behave_imp_weight" in stat:
|
||||
stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"])
|
||||
stats_tracker.stat(
|
||||
behave_imp_weight=stat["behave_imp_weight"],
|
||||
behave_approx_kl=stat["behave_approx_kl"],
|
||||
denominator="unclipped_behave_tokens",
|
||||
)
|
||||
vocab_min_logits = logits.detach().min(-1).values.float()
|
||||
vocab_max_logits = logits.detach().max(-1).values.float()
|
||||
stats_tracker.stat(
|
||||
vocab_min_logits=vocab_min_logits,
|
||||
vocab_max_logits=vocab_max_logits,
|
||||
denominator="n_tokens",
|
||||
)
|
||||
|
||||
clip_mask = stat["clip_mask"]
|
||||
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
|
||||
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
|
||||
stats_tracker.stat(
|
||||
clipped_new_logp=clipped_new_logp,
|
||||
clipped_old_logp=clipped_old_logp,
|
||||
denominator="clipped_tokens",
|
||||
)
|
||||
return loss
|
|
@ -0,0 +1,269 @@
|
|||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
from datasets import Dataset
|
||||
|
||||
from arealite.api.cli_args import TrainerConfig, TrainingArgs
|
||||
from arealite.api.engine_api import EngineFactory
|
||||
from arealite.api.trainer_api import Trainer
|
||||
from arealite.system.rollout_controller import RolloutController
|
||||
from arealite.utils import (
|
||||
close_wandb_tensorboard,
|
||||
compute_varlen_position_indices,
|
||||
gather_logprobs,
|
||||
init_stats_logging,
|
||||
list_of_dict2dict_of_list,
|
||||
log_wandb_tensorboard,
|
||||
record_timing,
|
||||
)
|
||||
from realhf.api.core.data_api import load_hf_tokenizer, tabulate_stats
|
||||
from realhf.api.core.model_api import FinetuneSpec
|
||||
from realhf.base import logging, stats_tracker, timeutil
|
||||
|
||||
logger = logging.getLogger("SFT Trainer")
|
||||
|
||||
|
||||
def compute_packed_sft_loss(
|
||||
logits: torch.Tensor,
|
||||
input_: Dict[str, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
packed_input_ids: torch.Tensor = input_["input_ids"].squeeze(dim=0)
|
||||
cu_seqlens: torch.Tensor = input_["cu_seqlens"]
|
||||
input_lens: torch.Tensor = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()
|
||||
prompt_mask = input_["prompt_mask"].squeeze(dim=0)
|
||||
logits = logits.squeeze(dim=0).float()
|
||||
|
||||
logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1))
|
||||
logprobs = torch.where(prompt_mask, 0, logprobs)
|
||||
|
||||
loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()
|
||||
|
||||
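# Compute the per-sequence mean log-prob over non-prompt tokens; exponentiating its
# negation below yields the per-sequence perplexity statistic.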
with torch.no_grad():
|
||||
seqlogp = torch.zeros(
|
||||
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
|
||||
)
|
||||
for i in range(cu_seqlens.shape[0] - 1):
|
||||
m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
|
||||
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
|
||||
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
|
||||
cu_seqlens,
|
||||
logprobs.shape,
|
||||
)
|
||||
seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
|
||||
m.numel() - m.count_nonzero()
|
||||
)
|
||||
|
||||
## Logging stats
|
||||
stats_tracker.denominator(
|
||||
n_seqs=torch.ones(
|
||||
cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
|
||||
),
|
||||
n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
|
||||
n_valid_tokens=prompt_mask.logical_not(),
|
||||
prompt_tokens=prompt_mask,
|
||||
)
|
||||
stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
|
||||
stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
|
||||
vocab_min_logits = logits.detach().min(-1).values.float()
|
||||
vocab_max_logits = logits.detach().max(-1).values.float()
|
||||
stats_tracker.stat(
|
||||
vocab_min_logits=vocab_min_logits,
|
||||
vocab_max_logits=vocab_max_logits,
|
||||
denominator="n_tokens",
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class SFTTrainer(Trainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: TrainingArgs,
|
||||
trainer_config: TrainerConfig,
|
||||
train_dataset: Dataset,
|
||||
valid_dataset: Optional[Dataset] = None,
|
||||
rollout_controller: Optional[RolloutController] = None,
|
||||
):
|
||||
super().__init__(
|
||||
args, trainer_config, train_dataset, valid_dataset, rollout_controller
|
||||
)
|
||||
|
||||
self.config = config = trainer_config.sft
|
||||
assert config is not None
|
||||
|
||||
engine_factory = EngineFactory(args)
|
||||
self.model = engine_factory.make_engine(config.model)
|
||||
self.tokenizer = load_hf_tokenizer(config.model.path)
|
||||
|
||||
self.mb_spec = config.mb_spec
|
||||
|
||||
self.save_ctl = timeutil.EpochStepTimeFreqCtl(
|
||||
freq_epoch=self.args.exp_ctrl.save_freq_epochs,
|
||||
freq_step=self.args.exp_ctrl.save_freq_steps,
|
||||
freq_sec=self.args.exp_ctrl.save_freq_secs,
|
||||
)
|
||||
self.eval_ctl = timeutil.EpochStepTimeFreqCtl(
|
||||
freq_epoch=self.args.exp_ctrl.eval_freq_epochs,
|
||||
freq_step=self.args.exp_ctrl.eval_freq_steps,
|
||||
freq_sec=self.args.exp_ctrl.eval_freq_secs,
|
||||
)
|
||||
self.summary_writer = init_stats_logging(args)
|
||||
|
||||
def _tokenize(self, strs: List[str]):
|
||||
# tokenize strings into unpadded tokens with lengths.
|
||||
return self.tokenizer(
|
||||
strs,
|
||||
padding=False,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
max_length=self.mb_spec.max_tokens_per_mb,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
def _get_packed_input(self, data: List[Dict[str, Any]]):
|
||||
data: Dict[str, List[Any]] = list_of_dict2dict_of_list(data)
|
||||
|
||||
tokenized_seqs = data["seq"]
|
||||
tokenized_prompts = data["prompt"]
|
||||
prompt_lens = [len(prompt) for prompt in tokenized_prompts]
|
||||
input_lens = [len(seq) for seq in tokenized_seqs]
|
||||
|
||||
input_lens = torch.tensor(input_lens, dtype=torch.int)
|
||||
input_ids = [torch.tensor(seq, dtype=torch.long) for seq in tokenized_seqs]
|
||||
|
||||
prompt_mask = []
|
||||
for input_len, prompt_len in zip(input_lens, prompt_lens):
|
||||
assert input_len >= prompt_len, (input_len, prompt_len)
|
||||
pm = [1] * prompt_len + [0] * (input_len - prompt_len)
|
||||
prompt_mask.append(torch.tensor(pm, dtype=torch.bool))
|
||||
|
||||
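# cu_seqlens[i] is the start offset of sequence i in the packed 1-D tensor (flash-attn varlen convention).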
cu_seqlens = torch.nn.functional.pad(
|
||||
input_lens.cumsum(0, dtype=torch.int), (1, 0)
|
||||
)
|
||||
max_seqlen = int(torch.max(input_lens).item())
|
||||
packed_input_ids = torch.cat(input_ids, dim=0)
|
||||
prompt_mask = torch.cat(prompt_mask, dim=0)
|
||||
total_seqlen = int(cu_seqlens[-1].item())
|
||||
position_ids = compute_varlen_position_indices(total_seqlen, cu_seqlens)
|
||||
|
||||
return dict(
|
||||
input_ids=packed_input_ids.unsqueeze(0).cuda(),
|
||||
attention_mask=None,
|
||||
position_ids=position_ids.unsqueeze(0).cuda(),
|
||||
prompt_mask=prompt_mask.unsqueeze(0).cuda(),
|
||||
cu_seqlens=cu_seqlens.cuda(),
|
||||
max_seqlen=max_seqlen,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
def train(self, resume_from_checkpoint=None):
|
||||
self.create_train_dataloader()
|
||||
|
||||
total_epochs = self.args.exp_ctrl.total_train_epochs
|
||||
steps_per_epoch = len(self.train_dataloader)
|
||||
ft_spec = FinetuneSpec(
|
||||
total_train_epochs=total_epochs,
|
||||
dataset_size=len(self.train_dataset),
|
||||
train_batch_size=self.args.train_dataset.batch_size,
|
||||
)
|
||||
|
||||
self.model.init_distributed(None, ft_spec)
|
||||
self.model.load_model_from_hf(self.config.model.path)
|
||||
self.model.train()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
|
||||
global_step = 0
|
||||
start_time = time.monotonic()
|
||||
for epoch in range(total_epochs):
|
||||
for step, data in enumerate(self.train_dataloader):
|
||||
timing_stats = {}
|
||||
with record_timing("timeperf/data_processing", timing_stats):
|
||||
packed_input_data = self._get_packed_input(data)
|
||||
|
||||
with record_timing("timeperf/train_step", timing_stats):
|
||||
with stats_tracker.scope("sft"):
|
||||
stats = self.model.train_batch(
|
||||
input_=packed_input_data,
|
||||
loss_fn=compute_packed_sft_loss,
|
||||
loss_weight_fn=lambda x: x["prompt_mask"]
|
||||
.logical_not()
|
||||
.count_nonzero(),
|
||||
mb_spec=self.mb_spec,
|
||||
)
|
||||
self.model.step_lr_scheduler()
|
||||
stats_tracker.scalar(**stats)
|
||||
|
||||
if self.save_ctl.check(
|
||||
epochs=int(step == steps_per_epoch - 1), steps=1
|
||||
):
|
||||
if dist.get_rank() == 0:
|
||||
logger.info("Saving model ...")
|
||||
|
||||
with record_timing("timeperf/save", timing_stats):
|
||||
save_path = self.get_save_checkpoint_path(
|
||||
epoch, step, global_step
|
||||
)
|
||||
self.model.save_model_to_hf(save_path, self.tokenizer)
|
||||
|
||||
if self.eval_ctl.check(
|
||||
epochs=int(step == steps_per_epoch - 1), steps=1
|
||||
):
|
||||
if dist.get_rank() == 0:
|
||||
logger.info("Running evaluation ...")
|
||||
with record_timing("timeperf/eval", timing_stats):
|
||||
self._eval(global_step)
|
||||
|
||||
training_stats = stats_tracker.export()
|
||||
training_stats.update(timing_stats)
|
||||
log_wandb_tensorboard(global_step, training_stats, self.summary_writer)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Epoch {epoch} Step {step} GlobalStep {global_step} done. Detailed time stats:"
|
||||
f"\n{tabulate_stats(timing_stats, floatfmt='.2f')}"
|
||||
)
|
||||
global_step += 1
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}."
|
||||
)
|
||||
|
||||
close_wandb_tensorboard(self.summary_writer)
|
||||
|
||||
def _eval(self, global_step):
|
||||
self.create_valid_dataloader()
|
||||
if self.valid_dataloader is None:
|
||||
return
|
||||
|
||||
self.eval_data_generator = iter(self.valid_dataloader)
|
||||
n_steps = len(self.valid_dataloader)
|
||||
|
||||
losses = []
|
||||
|
||||
start_time = time.monotonic()
|
||||
for step in range(n_steps):
|
||||
data = next(self.eval_data_generator)
|
||||
packed_input_data = self._get_packed_input(data)
|
||||
with stats_tracker.scope("sft-eval"):
|
||||
avg_loss = self.model.eval_batch(
|
||||
input_=packed_input_data,
|
||||
loss_fn=compute_packed_sft_loss,
|
||||
loss_weight_fn=lambda x: x["prompt_mask"]
|
||||
.logical_not()
|
||||
.count_nonzero(),
|
||||
mb_spec=self.mb_spec,
|
||||
)
|
||||
losses.append(avg_loss)
|
||||
val_loss = torch.mean(torch.stack(losses))
|
||||
|
||||
logger.info(
|
||||
f"Global step: {global_step} evaluation time cost {time.monotonic() - start_time:.2f} "
|
||||
f"val_loss={val_loss:.4f}"
|
||||
)
|
|
@ -0,0 +1,228 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from queue import Empty as QueueEmpty
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
# NOTE: the start method of mp should be fork rather than spawn
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from arealite.api.cli_args import RolloutConfig, TrainingArgs
|
||||
from arealite.api.io_struct import Trajectory
|
||||
from arealite.api.llm_client_api import LLMClientFactory
|
||||
from arealite.api.rollout_api import RolloutCollector
|
||||
from arealite.system.rollout_worker import RolloutWorker
|
||||
from realhf.base import datapack, logging, network
|
||||
from realhf.system.push_pull_stream import ZMQJsonPuller, ZMQJsonPusher
|
||||
|
||||
logger = logging.getLogger("Rollout Controller")
|
||||
|
||||
|
||||
class RolloutController:
|
||||
def __init__(
|
||||
self,
|
||||
args: TrainingArgs,
|
||||
config: RolloutConfig,
|
||||
collector: RolloutCollector,
|
||||
):
|
||||
self.args = args
|
||||
self.config = config
|
||||
self.gconfig = config.gconfig
|
||||
self.collector = collector
|
||||
|
||||
# Process-based execution
|
||||
self._exiting = mp.Event()
|
||||
self._lock = mp.Lock()
|
||||
self._buffer: List[List[Trajectory]] = []
|
||||
self._version = 0
|
||||
|
||||
# Worker processes for asynchronous rollout
|
||||
self._worker_processes: List[mp.Process] = []
|
||||
|
||||
self.llm_client = LLMClientFactory(args).make_client(config.llm_client)
|
||||
|
||||
# PushPull communication for data to workers
|
||||
self._data_pusher = None
|
||||
self._data_pusher_port = None
|
||||
self._puller = None
|
||||
self._puller_port = None
|
||||
self._collector_thread = None
|
||||
|
||||
################### User Interfaces Start #################
|
||||
|
||||
def generate_batch(
|
||||
self,
|
||||
batch_size: int,
|
||||
env_options: Optional[List[Any]] = None,
|
||||
seeds: Optional[List[int]] = None,
|
||||
) -> List[Trajectory]:
|
||||
"""Run episodes in batch using the collector directly (for compatibility)."""
|
||||
if env_options is None:
|
||||
env_options = [None] * batch_size
|
||||
else:
|
||||
assert len(env_options) == batch_size
|
||||
if seeds is None:
|
||||
seeds = [None] * batch_size
|
||||
else:
|
||||
assert len(seeds) == batch_size
|
||||
|
||||
async def run_parallel_gen():
|
||||
worker = RolloutWorker(
|
||||
worker_id=0,
|
||||
args=self.args,
|
||||
config=self.config,
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
tasks = [
|
||||
worker._run_grouped_episode_async(None, env_option, seed)
|
||||
for env_option, seed in zip(env_options, seeds)
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return sum([r[1] for r in results], [])
|
||||
|
||||
return asyncio.run(run_parallel_gen())
|
||||
|
||||
def start_generate_loop(self):
|
||||
"""Start worker processes that run generation loops."""
|
||||
logger.info("Starting worker processes...")
|
||||
|
||||
# Start background thread to collect data from workers
|
||||
self._puller_port = network.find_free_port(
|
||||
experiment_name=self.args.experiment_name, trial_name=self.args.trial_name
|
||||
)
|
||||
self._collector_thread = threading.Thread(
|
||||
target=self._collect_from_workers, daemon=True
|
||||
)
|
||||
self._collector_thread.start()
|
||||
|
||||
# Start worker processes
|
||||
self._data_pusher_port = network.find_free_port(
|
||||
experiment_name=self.args.experiment_name, trial_name=self.args.trial_name
|
||||
)
|
||||
self._data_pusher = ZMQJsonPusher(
|
||||
host="localhost", port=self._data_pusher_port, bind=True
|
||||
)
|
||||
logger.info(f"RolloutController sending data on port {self._data_pusher_port}")
|
||||
|
||||
num_workers = self.config.num_workers
|
||||
for worker_id in range(num_workers):
|
||||
process = mp.Process(
|
||||
target=_run_worker_process,
|
||||
args=(
|
||||
worker_id,
|
||||
self.args,
|
||||
self.config,
|
||||
self._puller_port,
|
||||
self._data_pusher_port,
|
||||
),
|
||||
)
|
||||
process.start()
|
||||
self._worker_processes.append(process)
|
||||
logger.info(f"Started worker process {worker_id}")
|
||||
|
||||
def submit(self, data):
|
||||
"""Submit data to worker processes for processing."""
|
||||
if self._data_pusher is None:
|
||||
raise RuntimeError(
|
||||
"Data pusher not initialized. Call start_generate_loop() first."
|
||||
)
|
||||
|
||||
# Convert data to JSON-compatible format
|
||||
assert isinstance(data, list)
|
||||
for d in data:
|
||||
self._data_pusher.push(d)
|
||||
logger.debug(f"Submitted {len(data)} data to workers")
|
||||
|
||||
def prepare_batch(self, batch_size: int) -> List[Trajectory]:
|
||||
"""Prepare and wait for a batch of trajectories."""
|
||||
buf_size = -1
|
||||
while buf_size < batch_size:
|
||||
with self._lock:
|
||||
buf_size = len(self._buffer)
|
||||
time.sleep(0.1)
|
||||
with self._lock:
|
||||
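# Sort by mean episode start time so the oldest trajectory groups are consumed first.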
self._buffer = sorted(
|
||||
self._buffer, key=lambda x: np.mean([xx.stats.start_time for xx in x])
|
||||
)
|
||||
data, self._buffer = self._buffer[:batch_size], self._buffer[batch_size:]
|
||||
return datapack.flat2d(data)
|
||||
|
||||
def stop_generate_loop(self):
|
||||
"""Stop worker processes and cleanup."""
|
||||
logger.info("Stopping worker processes...")
|
||||
self._exiting.set()
|
||||
|
||||
# Stop worker processes gracefully first, then forcefully if needed
|
||||
for i, process in enumerate(self._worker_processes):
|
||||
if process.is_alive():
|
||||
logger.info(f"Terminating worker process {i}...")
|
||||
try:
|
||||
process.terminate()
|
||||
process.join(timeout=1.0)
|
||||
except Exception:
|
||||
process.kill()
|
||||
self._worker_processes.clear()
|
||||
|
||||
if self._collector_thread is not None:
|
||||
# Wait for the thread to finish (with optional timeout)
|
||||
self._collector_thread.join(timeout=1.0)
|
||||
|
||||
# Close communication channels
|
||||
if self._puller:
|
||||
self._puller.close()
|
||||
if self._data_pusher:
|
||||
self._data_pusher.close()
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
################## User Interfaces End ##################
|
||||
|
||||
def _collect_from_workers(self):
|
||||
"""Background thread to collect trajectories from workers."""
|
||||
# Bind the puller to the port allocated in start_generate_loop().
|
||||
self._puller = ZMQJsonPuller(host="localhost", port=self._puller_port)
|
||||
logger.info(f"RolloutController listening on port {self._puller_port}")
|
||||
|
||||
while not self._exiting.is_set():
|
||||
try:
|
||||
# Pull data from workers
|
||||
data = self._puller.pull(timeout_ms=100)
|
||||
# Convert back to Trajectory objects
|
||||
trajs = [
|
||||
Trajectory.from_json_compatible(traj_data)
|
||||
for traj_data in data["trajs"]
|
||||
]
|
||||
# Add to buffer
|
||||
with self._lock:
|
||||
self._buffer.append(trajs)
|
||||
logger.debug(
|
||||
f"Received {len(trajs)} trajectories from worker {data['worker_id']}"
|
||||
)
|
||||
except QueueEmpty:
|
||||
# No data available, continue
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
except Exception as e:
|
||||
if not self._exiting.is_set():
|
||||
logger.error(f"Error in collector thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
break
|
||||
|
||||
|
||||
def _run_worker_process(worker_id: int, args, config, puller_port, data_pusher_port):
|
||||
worker = RolloutWorker(
|
||||
worker_id=worker_id,
|
||||
args=args,
|
||||
config=config,
|
||||
pusher_host="localhost",
|
||||
pusher_port=puller_port,
|
||||
data_puller_host="localhost",
|
||||
data_puller_port=data_pusher_port,
|
||||
)
|
||||
logger.info(f"Worker {worker_id} starting generation loop...")
|
||||
worker.run_generation_loop()
|
|
@ -0,0 +1,234 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
|
||||
from arealite.api.cli_args import RolloutConfig, TrainingArgs
|
||||
from arealite.api.io_struct import Trajectory
|
||||
from arealite.api.llm_client_api import LLMClient, LLMClientFactory
|
||||
from arealite.api.rollout_api import RolloutCollectorFactory
|
||||
from realhf.base import logging, name_resolve, names
|
||||
from realhf.base.monitor import RolloutStat
|
||||
from realhf.system.push_pull_stream import ZMQJsonPuller, ZMQJsonPusher
|
||||
|
||||
logger = logging.getLogger("RolloutWorker")
|
||||
|
||||
ROLLOUT_POLL_WAIT_TIME = 0.4
|
||||
|
||||
|
||||
class RolloutWorker:
|
||||
"""Standalone rollout worker that runs continuous generation loop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: int,
|
||||
args: TrainingArgs,
|
||||
config: RolloutConfig,
|
||||
llm_client: LLMClient | None = None,
|
||||
pusher_host: Optional[str] = "localhost",
|
||||
pusher_port: Optional[int] = 5555,
|
||||
data_puller_host: Optional[str] = "localhost",
|
||||
data_puller_port: Optional[int] = 5556,
|
||||
):
|
||||
self.worker_id = worker_id
|
||||
self.args = args
|
||||
self.config = config
|
||||
self.gconfig = config.gconfig
|
||||
|
||||
# For staleness control
|
||||
self.train_batch_size = args.train_dataset.batch_size
|
||||
self.max_concurrent_rollouts = (
|
||||
config.max_concurrent_rollouts or self.train_batch_size
|
||||
)
|
||||
|
||||
self.pusher_host = pusher_host
|
||||
self.pusher_port = pusher_port
|
||||
self.data_puller_host = data_puller_host
|
||||
self.data_puller_port = data_puller_port
|
||||
|
||||
self._shutdown = False
|
||||
self.pusher = None
|
||||
self.data_puller = None
|
||||
|
||||
if llm_client is None:
|
||||
llm_client = LLMClientFactory(args).make_client(config.llm_client)
|
||||
self.llm_client = llm_client
|
||||
|
||||
def _cleanup(self):
|
||||
"""Clean up resources."""
|
||||
if self.pusher:
|
||||
self.pusher.close()
|
||||
if self.data_puller:
|
||||
self.data_puller.close()
|
||||
|
||||
def run_generation_loop(self):
|
||||
"""Run the continuous generation loop like the original _generate_loop."""
|
||||
try:
|
||||
asyncio.run(self._generate_loop())
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
async def _run_grouped_episode_async(
|
||||
self, rid: Optional[int], data: Any, seed: Optional[int] = None
|
||||
):
|
||||
"""Run grouped episode asynchronously."""
|
||||
tasks = []
|
||||
for _ in range(self.gconfig.n_samples):
|
||||
# Create collector
|
||||
factory = RolloutCollectorFactory(self.args)
|
||||
collector = factory.make_collector(self.config.collector)
|
||||
tasks += [
|
||||
collector.arun_episode(
|
||||
llm_client=self.llm_client,
|
||||
gconfig=self.gconfig.new(n_samples=1),
|
||||
env_option=data,
|
||||
seed=seed,
|
||||
)
|
||||
]
|
||||
trajs = await asyncio.gather(*tasks)
|
||||
return rid, trajs
|
||||
|
||||
def _get_model_version(self) -> int:
|
||||
name = names.model_version(
|
||||
self.args.experiment_name,
|
||||
self.args.trial_name,
|
||||
"actor",
|
||||
)
|
||||
try:
|
||||
return int(name_resolve.get(name))
|
||||
except name_resolve.NameEntryNotFoundError:
|
||||
return 0
|
||||
|
||||
async def _generate_loop(self):
|
||||
"""Main generation loop - similar to original RolloutController._generate_loop."""
|
||||
data = None
|
||||
|
||||
# Communication with main process
|
||||
self.pusher = ZMQJsonPusher(host=self.pusher_host, port=self.pusher_port)
|
||||
self.data_puller = ZMQJsonPuller(
|
||||
host=self.data_puller_host,
|
||||
port=self.data_puller_port,
|
||||
bind=False,
|
||||
)
|
||||
|
||||
rollout_stat = RolloutStat()
|
||||
rollout_tasks: Dict[int, asyncio.Task] = {}
|
||||
rid = 0
|
||||
|
||||
try:
|
||||
while not self._shutdown:
|
||||
# Load next data from controller
|
||||
if data is None:
|
||||
try:
|
||||
data = self.data_puller.pull(timeout_ms=50)
|
||||
logger.debug(f"Get data from puller: {data}")
|
||||
except queue.Empty:
|
||||
logger.debug(f"No data from puller stream.")
|
||||
|
||||
# Check capacity
|
||||
if dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
|
||||
cannot_rollout_reason = []
|
||||
capacity = max(1, self.max_concurrent_rollouts // world_size)
|
||||
can_rollout = len(rollout_tasks) < capacity
|
||||
if not can_rollout:
|
||||
cannot_rollout_reason.append(
|
||||
f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}"
|
||||
)
|
||||
|
||||
# Staleness control
|
||||
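# A new rollout may start only if (accepted + running) // train_batch_size <= latest_weight_version + max_head_offpolicyness.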
version = self._get_model_version()
|
||||
ofp = self.config.max_head_offpolicyness
|
||||
sample_cnt = rollout_stat.accepted + rollout_stat.running
|
||||
expected_version = sample_cnt // self.train_batch_size
|
||||
not_staled = expected_version <= ofp + version
|
||||
can_rollout &= not_staled
|
||||
if not not_staled:
|
||||
cannot_rollout_reason.append(
|
||||
f"Staled: expected version ({expected_version}) = "
|
||||
f"global sample cnt ({sample_cnt}) // batch size ({self.train_batch_size}), "
|
||||
f"current latest version {version}, "
|
||||
f"offpolicyness {self.config.max_head_offpolicyness}."
|
||||
)
|
||||
|
||||
if not can_rollout:
|
||||
logger.debug(
|
||||
f"Worker {self.worker_id}: Cannot submit new rollouts. "
|
||||
+ "\n".join(cannot_rollout_reason)
|
||||
)
|
||||
|
||||
# Create new rollout task
|
||||
if can_rollout and data is not None:
|
||||
task = asyncio.create_task(
|
||||
self._run_grouped_episode_async(rid, data)
|
||||
)
|
||||
rollout_tasks[rid] = task
|
||||
|
||||
rollout_stat.submitted += 1
|
||||
rollout_stat.running += 1
|
||||
logger.debug(
|
||||
f"Worker {self.worker_id}: Submit rollout rid {rid}. "
|
||||
f"Submit: {rollout_stat.submitted}, "
|
||||
f"running: {rollout_stat.running}, "
|
||||
f"accepted: {rollout_stat.accepted}."
|
||||
)
|
||||
|
||||
rid += 1
|
||||
data = None
|
||||
|
||||
# Wait for rollout completion
|
||||
tasks = list(rollout_tasks.values())
|
||||
done = []
|
||||
if tasks:
|
||||
done, _ = await asyncio.wait(
|
||||
tasks,
|
||||
timeout=ROLLOUT_POLL_WAIT_TIME,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
else:
|
||||
await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME)
|
||||
|
||||
# Collect done results
|
||||
for task in done:
|
||||
task_rid, trajs = await task
|
||||
trajs: List[Trajectory]
|
||||
rollout_tasks.pop(task_rid)
|
||||
rollout_stat.running -= 1
|
||||
|
||||
# Filter data according to episodic return
|
||||
ret = np.mean([traj.stats.total_reward for traj in trajs])
|
||||
accepted = ret >= self.config.filter_reward_lb
|
||||
accepted &= ret <= self.config.filter_reward_ub
|
||||
|
||||
if accepted:
|
||||
# Push trajectories to main process
|
||||
trajectory_data = {
|
||||
"worker_id": self.worker_id,
|
||||
"trajs": [traj.to_json_compatible() for traj in trajs],
|
||||
}
|
||||
self.pusher.push(trajectory_data)
|
||||
rollout_stat.accepted += 1
|
||||
|
||||
logger.debug(
|
||||
f"Worker {self.worker_id}: Finish rollout {task_rid}. "
|
||||
f"Submit: {rollout_stat.submitted}, "
|
||||
f"running: {rollout_stat.running}, "
|
||||
f"accepted: {rollout_stat.accepted}."
|
||||
)
|
||||
finally:
|
||||
# Cancel remaining tasks
|
||||
for task in rollout_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import time
|
||||
|
||||
from arealite.api.io_struct import LLMRequest, LLMResponse, LLMServerInfo
|
||||
from arealite.api.llm_client_api import LLMClient
|
||||
from realhf.base import logging, pkg_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if pkg_version.is_available("sglang"):
|
||||
if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"):
|
||||
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids"
|
||||
else:
|
||||
SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"
|
||||
|
||||
|
||||
class SGLangClient(LLMClient):
|
||||
"""SGLang implementation of LLMClient."""
|
||||
|
||||
async def agenerate(self, req: LLMRequest) -> LLMResponse:
|
||||
"""Async version of generate using aiohttp."""
|
||||
|
||||
# Ensure the request has prompt text; decode from input_ids if only tokens were provided.
|
||||
if not req.text:
|
||||
assert req.input_ids is not None
|
||||
req.text = self.tokenizer.decode(req.input_ids)
|
||||
|
||||
# Prepare request payload
|
||||
gconfig = req.gconfig
|
||||
stop_token_ids = gconfig.stop_token_ids
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids:
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
if self.tokenizer.pad_token_id not in stop_token_ids:
|
||||
stop_token_ids.append(self.tokenizer.pad_token_id)
|
||||
|
||||
assert gconfig.n_samples == 1
|
||||
sample_params = {
|
||||
"top_p": gconfig.top_p,
|
||||
"top_k": gconfig.top_k,
|
||||
"max_new_tokens": gconfig.max_new_tokens,
|
||||
"temperature": 0.0 if gconfig.greedy else gconfig.temperature,
|
||||
"stop_token_ids": stop_token_ids,
|
||||
}
|
||||
|
||||
payload = {
|
||||
"rid": req.rid,
|
||||
"text": req.text,
|
||||
"sampling_params": sample_params,
|
||||
"return_logprob": True,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
# Make request
|
||||
start_time = time.perf_counter()
|
||||
accumulated_output_tokens = []
|
||||
accumulated_output_logprobs = []
|
||||
accumulated_versions = []
|
||||
|
||||
# Deal with rollout interruption
|
||||
completion = ""
|
||||
stop_reason = "length"
|
||||
|
||||
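# If the server interrupts generation (e.g., to update weights), finish_reason is not "stop";
# resubmit the prompt plus the accumulated completion until generation stops or the token budget is exhausted.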
while (
|
||||
stop_reason != "stop"
|
||||
and len(accumulated_output_tokens) < gconfig.max_new_tokens
|
||||
):
|
||||
# loop until the generation is complete
|
||||
response, server_info = await self.arequest_with_retry(
|
||||
endpoint="/generate",
|
||||
payload=payload,
|
||||
method="POST",
|
||||
max_retries=3,
|
||||
timeout=self.client_config.request_timeout,
|
||||
)
|
||||
result = await response.json()
|
||||
|
||||
# Parse response
|
||||
completion += result["text"]
|
||||
meta_info = result["meta_info"]
|
||||
output_tokens = [x[1] for x in meta_info["output_token_logprobs"]]
|
||||
output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]]
|
||||
|
||||
# Update accumulated outputs
|
||||
accumulated_output_tokens.extend(output_tokens)
|
||||
accumulated_output_logprobs.extend(output_logprobs)
|
||||
accumulated_versions.extend([server_info.version] * len(output_tokens))
|
||||
|
||||
# Check if generation is complete
|
||||
finish_reason = meta_info["finish_reason"]
|
||||
stop_reason = finish_reason["type"]
|
||||
|
||||
payload["text"] += completion
|
||||
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
return LLMResponse(
|
||||
completion=completion,
|
||||
input_tokens=req.input_ids,
|
||||
output_tokens=accumulated_output_tokens,
|
||||
output_logprobs=accumulated_output_logprobs,
|
||||
output_versions=accumulated_versions,
|
||||
stop_reason=stop_reason,
|
||||
latency=latency,
|
||||
ttft=latency, # Simplified for non-streaming
|
||||
)
|
||||
|
||||
async def aupdate_weights_from_disk(self, server_info: LLMServerInfo, path: str):
|
||||
server_url = f"http://{server_info.host}:{server_info.port}"
|
||||
response, _ = await self.arequest_with_retry(
|
||||
endpoint="/update_weights_from_disk",
|
||||
payload=dict(model_path=path, allow_interrupt=True),
|
||||
method="POST",
|
||||
max_retries=3,
|
||||
timeout=self.client_config.request_timeout,
|
||||
target_server=server_info,
|
||||
)
|
||||
res = await response.json()
|
||||
assert res["success"]
|
||||
if "num_paused_requests" in res:
|
||||
logger.info(
|
||||
f"{res['num_paused_requests']} requests are interrupted "
|
||||
f"during updating weights for server {server_url}"
|
||||
)
|
||||
self.registry.update_heartbeat(
|
||||
server_info.server_id, "healthy", version=server_info.version + 1
|
||||
)
|
||||
|
||||
async def ainit_weight_update_group(self, server_info, group_meta):
|
||||
payload = dict(
|
||||
master_address=group_meta.master_address,
|
||||
master_port=group_meta.master_port,
|
||||
rank_offset=group_meta.rank_offset,
|
||||
world_size=group_meta.world_size,
|
||||
group_name=group_meta.group_name,
|
||||
backend=group_meta.backend,
|
||||
)
|
||||
response, _ = await self.arequest_with_retry(
|
||||
endpoint="/init_weights_update_group",
|
||||
payload=payload,
|
||||
method="POST",
|
||||
max_retries=3,
|
||||
timeout=self.client_config.request_timeout,
|
||||
target_server=server_info,
|
||||
)
|
||||
res = await response.json()
|
||||
assert res["success"], res["message"]
|
||||
|
||||
async def aupdate_weights_from_distributed(self, server_info, weight_meta):
|
||||
payload = dict(
|
||||
name=weight_meta.param_name,
|
||||
dtype=weight_meta.dtype,
|
||||
shape=weight_meta.shape,
|
||||
)
|
||||
response, _ = await self.arequest_with_retry(
|
||||
endpoint="/update_weights_from_distributed",
|
||||
payload=payload,
|
||||
method="POST",
|
||||
max_retries=3,
|
||||
timeout=self.client_config.request_timeout,
|
||||
target_server=server_info,
|
||||
)
|
||||
res = await response.json()
|
||||
assert res["success"], res["message"]
|
|
@ -0,0 +1,166 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from arealite.api.cli_args import LLMServiceConfig, SGLangConfig
|
||||
from arealite.api.io_struct import AllocationMode, LLMServerInfo
|
||||
from arealite.api.llm_server_api import LLMServer
|
||||
from realhf.base import gpu_utils, logging, network, pkg_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def apply_sglang_patch():
|
||||
"""Apply SGLang patch if available."""
|
||||
p = Path(os.path.dirname(__file__))
|
||||
patch_path = str(
|
||||
p.parent.parent.parent
|
||||
/ "patch"
|
||||
/ "sglang"
|
||||
/ f"v{pkg_version.get_version('sglang')}.patch"
|
||||
)
|
||||
|
||||
target_path = ""
|
||||
try:
|
||||
sglang_meta = subprocess.check_output(
|
||||
"python3 -m pip show sglang", shell=True
|
||||
).decode("ascii")
|
||||
for line in sglang_meta.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("Editable project location: "):
|
||||
target_path = str(Path(line.split(": ")[1]).parent)
|
||||
|
||||
if target_path and Path(patch_path).exists():
|
||||
proc = subprocess.Popen(
|
||||
["git", "apply", patch_path],
|
||||
cwd=target_path,
|
||||
stderr=sys.stdout,
|
||||
stdout=sys.stdout,
|
||||
)
|
||||
proc.wait()
|
||||
logger.info(f"Applied SGLang patch at {target_path}")
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
|
||||
class SGLangServer(LLMServer):
|
||||
"""SGLang implementation of LLMServer."""
|
||||
|
||||
def __init__(self, args, service_config: LLMServiceConfig):
|
||||
super().__init__(args, service_config)
|
||||
self.server_info: LLMServerInfo | None = None
|
||||
self.base_gpu_id = 0
|
||||
self.config = args.rollout.sglang
|
||||
|
||||
self.alloc_mode = AllocationMode.from_str(args.allocation_mode)
|
||||
|
||||
def _resolve_base_gpu_id(self):
|
||||
# Determine GPU configuration
|
||||
import ray
|
||||
|
||||
tp_size = self.alloc_mode.gen_tp_size
|
||||
pp_size = self.alloc_mode.gen_pp_size
|
||||
mp_size = tp_size * pp_size
|
||||
if ray.is_initialized():
|
||||
self.base_gpu_id = 0
|
||||
elif "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
if len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1:
|
||||
self.base_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"])
|
||||
elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == mp_size:
|
||||
self.base_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0])
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown how to resolve cuda visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}, "
|
||||
f"setting base_gpu_id to 0."
|
||||
)
|
||||
self.base_gpu_id = 0
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
|
||||
map(str, range(gpu_utils.gpu_count()))
|
||||
)
|
||||
elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
# torchrun
|
||||
self.base_gpu_id = int(os.environ["RANK"]) % gpu_utils.gpu_count()
|
||||
elif gpu_utils.gpu_count() == mp_size:
|
||||
self.base_gpu_id = 0
|
||||
else:
|
||||
logger.warning("Unknown GPU configuration, setting base_gpu_id to 0. ")
|
||||
self.base_gpu_id = 0
|
||||
|
||||
def launch_server(self) -> LLMServerInfo | None:
|
||||
# Apply SGLang patch
|
||||
apply_sglang_patch()
|
||||
self._resolve_base_gpu_id()
|
||||
# Get host and ports
|
||||
host_ip = network.gethostip()
|
||||
host = "localhost" if not self.config.enable_metrics else host_ip
|
||||
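# Two free ports are needed: one for the HTTP server and one for SGLang's distributed/NCCL init address.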
ports = network.find_multiple_free_ports(
|
||||
2,
|
||||
low=10000,
|
||||
high=60000,
|
||||
experiment_name=self.registry.expr_name,
|
||||
trial_name=self.registry.trial_name,
|
||||
)
|
||||
server_port = ports[0]
|
||||
nccl_port = ports[1]
|
||||
# Build command
|
||||
tp_size = self.alloc_mode.gen_tp_size
|
||||
cmd = SGLangConfig.build_cmd(
|
||||
sglang_config=self.config,
|
||||
model_path=self.args.rollout.model_path,
|
||||
tp_size=tp_size,
|
||||
base_gpu_id=self.base_gpu_id,
|
||||
dist_init_addr=f"{host}:{nccl_port}",
|
||||
served_model_name=self.service_config.served_model_name,
|
||||
skip_tokenizer_init=False,
|
||||
)
|
||||
# Launch process
|
||||
full_command = f"{cmd} --port {server_port}"
|
||||
full_command = full_command.replace("\\\n", " ").replace("\\", " ")
|
||||
self.process = subprocess.Popen(
|
||||
full_command.split(),
|
||||
text=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stdout,
|
||||
)
|
||||
# Create server info
|
||||
self.server_info = LLMServerInfo(
|
||||
server_id=self.server_id,
|
||||
host=host,
|
||||
port=server_port,
|
||||
status="starting",
|
||||
version=0,
|
||||
)
|
||||
return self.server_info
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""Check if the SGLang server is healthy."""
|
||||
if not self.server_info or not self.process:
|
||||
return False
|
||||
|
||||
# Check if process is still running
|
||||
if self.process.poll() is not None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Check server endpoint
|
||||
base_url = f"http://{self.server_info.host}:{self.server_info.port}"
|
||||
response = requests.get(
|
||||
f"{base_url}/metrics",
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
return False
|
||||
# Update server load
|
||||
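# /metrics exposes Prometheus-style gauges; sglang:num_running_reqs is used as the load signal.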
for line in response.text.split("\n"):
|
||||
if line.startswith("sglang:num_running_reqs"):
|
||||
self.load = float(line.split(" ")[1])
|
||||
break
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
return False
|