[Fix] [lite] Merge from the internal repo to fix GRPO bugs and refactor the train engine (#181)

* PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing

* PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .

* PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngine thread and SGLang configuration

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/355?tab=commit

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* fix
* .

* PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment utilities

Merge branch fw/lite-fix2 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/357?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* fix destroy process group

* PullRequest: 358 [lite] Support GRPO training locally with the GSM8k dataset

Merge branch fw/lite-fix3 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/358

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* fix loss mask
* fix
* .

* PullRequest: 368 [lite] Refactor train engine after merging contributions from GitHub

Merge branch fw/lite-train-engine of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/368

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .

* PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation

Merge branch fw/lite-fix0716 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/371

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

---------

Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Wei Fu, 2025-07-16 17:26:49 +08:00, committed by GitHub
parent 0283cfa124
commit 29e164a69d
19 changed files with 790 additions and 906 deletions


@@ -18,7 +18,7 @@ from arealite.utils.fs import get_user_tmp
 class MicroBatchSpec:
     """Specification for splitting micro-batches during training."""

-    n_mbs: int = field(
+    n_mbs: Optional[int] = field(
         default=1,
         metadata={
             "help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
@@ -161,7 +161,7 @@ class FSDPEngineConfig:
 @dataclass
-class HFEngineConfig:
+class DeepSpeedAutoTPEngineConfig:
     autotp_size: Optional[int] = field(
         default=1,
         metadata={"help": "DeepSpeed AutoTP size"},
@@ -201,7 +201,88 @@ class TrainEngineConfig:
     )
     backend: str = ""
     fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
-    hf: HFEngineConfig = field(default_factory=HFEngineConfig)
+    ds_auto_tp: DeepSpeedAutoTPEngineConfig = field(
default_factory=DeepSpeedAutoTPEngineConfig
)
@dataclass
class PPOActorConfig(TrainEngineConfig):
# Core PPO/GRPO Parameters
group_size: int = field(
default=1, metadata={"help": "Number of sequences in each group"}
)
group_adv_norm: bool = field(
default=False,
metadata={
"help": "Normalize advantages within each prompt group rather than globally"
},
)
ppo_n_minibatches: int = field(
default=4, metadata={"help": "Number of minibatches for each PPO update"}
)
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
c_clip: Optional[float] = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
},
)
temperature: float = field(
default=1.0, metadata={"help": "Temperature during generation."}
)
# Reward
group_reward_norm: bool = field(
default=False,
metadata={
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
},
)
reward_scaling: float = field(
default=1.0, metadata={"help": "Reward scaling factor"}
)
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
reward_clip: float = field(
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
)
mask_no_eos_with_zero: bool = field(
default=False,
metadata={
"help": "Mask truncated generations (no EOS token) and exclude from training"
},
)
# Advantage Estimation
discount: float = field(
default=1.0, metadata={"help": "Discount factor for future rewards"}
)
gae_lambda: float = field(
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: bool = field(
default=True, metadata={"help": "Enable advantage normalization"}
)
# KL Control
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
# Asynchronous RL
recompute_logprob: bool = field(
default=False,
metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
)
behav_imp_weight_cap: Optional[float] = field(
default=None,
metadata={
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
},
)
@dataclass
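For orientation, a minimal usage sketch of the PPOActorConfig added above. Only the field names come from this diff; the values are illustrative, and any required fields inherited from TrainEngineConfig (model path, experiment/trial names, etc.) are omitted:

    # Hypothetical GRPO-style configuration with the decoupled loss enabled.
    actor_cfg = PPOActorConfig(
        group_size=8,                # sequences sampled per prompt
        group_reward_norm=True,      # normalize rewards within each group
        eps_clip=0.2,
        kl_ctl=0.0,
        recompute_logprob=True,      # must be True when use_decoupled_loss is set
        use_decoupled_loss=True,
        behav_imp_weight_cap=5.0,    # must be > 1.0 per the field's help text
    )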


@@ -25,10 +25,10 @@ class Scheduling:
     cpu: int
     gpu: int
     mem: int
-    nodelist: str = None
-    exclude: str = None
-    partition: str = None
-    container_image: str = None
+    nodelist: Optional[str] = None
+    exclude: Optional[str] = None
+    partition: Optional[str] = None
+    container_image: Optional[str] = None
     env_vars: Dict[str, str] = field(default_factory=dict)
     # time utils from "https://slurm.schedmd.com/sbatch.html"
     time_limit: Optional[str] = None  # see "--time" option for format
@@ -105,7 +105,7 @@ class TrainEngine(abc.ABC):
     def forward(
         self,
         input_: TensorDict,
-        output_seqlens: List[List[int]] | None = None,
+        output_seqlens: List[int] | None = None,
         post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
     ) -> Any | None:


@@ -71,7 +71,7 @@ class AllocationType(enum.Enum):
 @dataclass
 class AllocationMode:
     type_: AllocationType
-    parallel_strat: None | Dict[str, Dict[str, int]]
+    parallel_strat: Dict[str, Dict[str, int]]

     @property
     def gen_tp_size(self) -> int:
@@ -115,7 +115,7 @@ class AllocationMode:
         raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")

     @staticmethod
-    def extract_3d_alloc(allocation_mode: str) -> Dict | None:
+    def extract_parallelism_strategy(allocation_mode: str) -> Dict:
         for x, y, z in itertools.permutations(["d", "t", "p"]):
             pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
             m = re.match(pattern, allocation_mode)
@@ -130,29 +130,28 @@ class AllocationMode:
                         z: c,
                     }
                 }
+        raise ValueError(
+            f"Unknown how to resolve parallelism strategy: {allocation_mode}"
+        )

     @staticmethod
-    def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
+    def extract_decoupled_alloc(allocation_mode: str) -> Dict:
         pattern = re.compile(
             r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
         )
         m = pattern.match(allocation_mode)
         if not m:
-            return
+            raise ValueError(
+                f"Unknown how to resolve decoupled allocation: {allocation_mode}"
+            )
         if m.group(1):
             gen_alloc = m.group(1)
             other_alloc = m.group(2)
         else:
             gen_alloc = m.group(4)
             other_alloc = m.group(3)
-        gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
-        if not gen_alloc:
-            return
-        other_alloc = AllocationMode.extract_3d_alloc(
-            other_alloc
-        ) or AllocationMode.extract_key_value_alloc(other_alloc)
-        if not other_alloc:
-            return
+        gen_alloc = AllocationMode.extract_parallelism_strategy(gen_alloc)
+        other_alloc = AllocationMode.extract_parallelism_strategy(other_alloc)
         other_alloc.update({"gen": gen_alloc["*"]})
         return other_alloc
@@ -171,7 +170,7 @@ class SaveLoadMeta:
     path: str
     weight_format: str
     with_optim: bool
-    tokenizer: PreTrainedTokenizerFast | None
+    tokenizer: Optional[PreTrainedTokenizerFast]
     base_model_path: str | None
     naive_distributed: bool = False
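As a reading aid, a hedged sketch of the renamed parser's behavior. The allocation strings are invented and the exact dict layout beyond the "*" and "gen" keys is inferred from the gen_alloc["*"] lookup above, not from tests:

    # "d2t4p1" -> {"*": {"d": 2, "t": 4, "p": 1}} (inferred shape)
    spec = AllocationMode.extract_parallelism_strategy("d2t4p1")

    # A decoupled allocation pairs a generation backend with the train backend;
    # the generation part lands under the "gen" key (inferred):
    # "sglang.d4t2p1+d2t2p2" -> {"*": {"d": 2, "t": 2, "p": 2}, "gen": {"d": 4, "t": 2, "p": 1}}
    decoupled = AllocationMode.extract_decoupled_alloc("sglang.d4t2p1+d2t2p2")

    # Unparseable strings now raise ValueError instead of silently returning None.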


@@ -0,0 +1,128 @@
import os
import torch
import torch.distributed as dist
from safetensors.torch import save_file
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.save_load import (
get_state_dict_from_repo_id_or_path,
is_existing_local_path,
)
from realhf.base import constants, logging
logger = logging.getLogger("DeepSpeedAutoTPEngine")
class DeepSpeedAutoTPEngine(BaseHFEngine):
def __init__(self, config: TrainEngineConfig):
super().__init__(config)
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
"""Initialize distributed communication and model."""
assert (
addr is None
), "DeepSpeedAutoTPEngine does not support remote initialization."
import deepspeed
self.create_process_group()
world_size = int(os.environ.get("WORLD_SIZE"))
deepspeed.init_distributed(
dist_backend="nccl",
world_size=world_size,
timeout=constants.NCCL_DEFAULT_TIMEOUT,
)
self.create_device_model()
# NOTE: the device context manager does not work here.
self.model = deepspeed.tp_model_init(
self.model,
tp_size=self.config.ds_auto_tp.autotp_size,
dtype=getattr(torch, self.config.dtype),
).to(self.device)
self.create_optimizer(ft_spec)
self.initialized = True
def _check_autotp(self):
tp_size = self.config.ds_auto_tp.autotp_size
config = self.model_config
num_attention_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
return (
num_attention_heads % tp_size == 0
and num_key_value_heads % tp_size == 0
and hidden_size % tp_size == 0
and intermediate_size % tp_size == 0
)
def save(self, meta: SaveLoadMeta):
if meta.weight_format != "naive_distributed":
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if self.model is None:
raise RuntimeError("Model not initialized")
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
os.makedirs(meta.path, exist_ok=True)
self.model_config.save_pretrained(
meta.path,
)
if meta.tokenizer is not None:
meta.tokenizer.save_pretrained(
meta.path,
)
state_dict = self.model.state_dict()
if hasattr(self.model, "module"):
state_dict = {
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
for k, v in state_dict.items()
}
else:
state_dict = {k: v.cpu() for k, v in state_dict.items()}
# Only support store parameters from model partitions respectively
gathered_state_dicts = None
if rank == 0:
gathered_state_dicts = [None for _ in range(world_size)]
dist.gather_object(
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
)
if rank == 0:
for i, state_dict in enumerate(gathered_state_dicts):
save_file(state_dict, f"{meta.path}/rank_{i:02d}_model.safetensors")
if meta.with_optim:
self.save_optimizer_state(meta.path)
def load(self, meta: SaveLoadMeta):
if meta.weight_format != "naive_distributed":
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
rank = dist.get_rank()
# Only support load full model parameters from huggingface
# and load model partition locally
if rank == 0 or is_existing_local_path(meta.path):
path = f"{meta.path}/rank_{rank:02d}_model.safetensors"
full_state = get_state_dict_from_repo_id_or_path(meta.path)
if hasattr(self.model, "module") and not hasattr(full_state):
full_state = {
f"module.{k}" if not k.startswith("module.") else k: v
for k, v in full_state.items()
}
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()
if meta.with_optim:
self.load_optimizer_state(meta.path)
def upload_weights(self, meta: WeightUpdateMeta):
raise ValueError(f"update weight not implemented {meta.type}")
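_check_autotp above only verifies divisibility of the model dimensions by the AutoTP degree; a small illustrative check with made-up Llama-like numbers (not taken from this repo):

    # 32 attention heads, 8 KV heads, hidden 4096, intermediate 11008
    tp_size = 4
    ok = all(dim % tp_size == 0 for dim in (32, 8, 4096, 11008))
    # ok is True for tp_size=4; for tp_size=3 it would be False, so AutoTP sharding is invalid.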


@@ -0,0 +1,360 @@
import gc
import os
import time
from typing import Any, Callable, Dict, List
import torch
import torch.distributed as dist
from tensordict import TensorDict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PretrainedConfig,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, TrainEngine
from arealite.utils.data import (
MicroBatchList,
amend_position_ids,
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.model import disable_dropout_in_model
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import constants, logging
logger = logging.getLogger("Base HF Engine")
class BaseHFEngine(TrainEngine):
def __init__(self, config: TrainEngineConfig):
self.config = config
self.optimizer_config = config.optimizer
self.model: torch.nn.Module
self.optimizer: torch.optim.Optimizer
self.tokenizer: PreTrainedTokenizerFast
# huggingface model config
self.model_config: PretrainedConfig
# initialization
self.initialized = False
self.own_global_group = False
self._parallelism_group: dist.ProcessGroup
self.weight_update_group_initialized = False
self.world_size = int(os.environ["WORLD_SIZE"])
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
return self
@property
def parallelism_group(self) -> dist.ProcessGroup:
assert self.initialized
return self._parallelism_group
def create_process_group(self):
if not dist.is_initialized():
# TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
dist.init_process_group(
backend="nccl",
timeout=constants.NCCL_DEFAULT_TIMEOUT,
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
)
self.own_global_group = True
self._parallelism_group = dist.new_group()
def create_device_model(self):
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: VLM cannot directly load state dict using this
# random initialized model, so otherwise we call
# from_pretrained rather than loading weights into this random model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
self.model = model
def create_optimizer(self, ft_spec: FinetuneSpec):
if self.optimizer_config is None:
return
assert self.model is not None
# Set up optimizer
tik = time.perf_counter()
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
def destroy(self):
"""Destroy the engine and release GPU memory."""
del self.optimizer
del self.model
gc.collect()
torch.cuda.empty_cache()
gc.collect()
dist.destroy_process_group(self.parallelism_group)
if self.own_global_group:
dist.destroy_process_group()
self.initialized = False
def save_optimizer_state(self, path: str):
# Save FSDP sharded state dict on each rank
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
state_dict = self.optimizer.state_dict()
torch.save(state_dict, shard_path)
dist.barrier()
def load_optimizer_state(self, path: str):
# Load FSDP sharded state dict
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
optimizer_state_dict = torch.load(shard_path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
dist.barrier()
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models requires
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
for mb in mb_list.padded_mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
def train_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> Dict[str, float]:
"""Train on a batch using gradient accumulation."""
input_ = input_.to(self.device)
assert self.optimizer is not None
assert self.optimizer_config is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_list = self.prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
dist.all_reduce(total_loss_weight)
# Process microbatches with gradient accumulation
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
# Scale loss for accumulation
# Revert gradient averaging across dp ranks
loss_scale *= self.world_size
loss *= loss_scale
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.optimizer_config.gradient_clipping,
norm_type=2.0,
error_if_nonfinite=False,
foreach=None,
)
if not torch.isfinite(grad_norm):
self.optimizer.zero_grad()
update_successful = False
else:
self.optimizer.step()
update_successful = True
current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
lr=current_lr,
)
@torch.no_grad()
def eval_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
loss_weight_fn: Callable[[TensorDict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
input_ = input_.to(self.device)
mb_list = self.prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: TensorDict,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
input_ = input_.to(self.device)
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self.prepare_mb_list(input_)
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
if post_hook:
result = post_hook(logits, mb_input)
results.append(result)
else:
results.append(logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)
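A note on the loss scaling in train_batch above: each micro-batch loss is weighted by its share of the all-reduced total loss weight, then multiplied by the world size so that the later gradient averaging across data-parallel ranks cancels out. A toy arithmetic sketch (numbers invented):

    world_size = 2
    local_weights = [30.0, 10.0]   # loss_weight_fn(mb) for this rank's micro-batches
    total_weight = 80.0            # sum over all ranks after dist.all_reduce
    scales = [w / total_weight * world_size for w in local_weights]
    # scales == [0.75, 0.25]; after backward and cross-rank averaging, every
    # weighted token contributes 1 / total_weight to the accumulated gradient.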


@@ -1,42 +1,20 @@
-import gc
 import os
 import time
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional
+from typing import Callable, Dict, Optional

 import torch
 import torch.distributed as dist
-import transformers
 from tensordict import TensorDict
 from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
     get_model_state_dict,
 )
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    get_constant_schedule_with_warmup,
-    get_linear_schedule_with_warmup,
-)
+from transformers import PreTrainedTokenizerFast

 from arealite.api.cli_args import TrainEngineConfig
-from arealite.api.engine_api import (
-    FinetuneSpec,
-    SaveLoadMeta,
-    TrainEngine,
-    WeightUpdateMeta,
-)
-from arealite.utils.data import (
-    MicroBatchList,
-    amend_position_ids,
-    pack_tensor_dict,
-    pad_and_stack_tensors_along_first_dim,
-    pad_mb_list,
-    reorder_list,
-    split_padded_tensor_dict_into_mb_list,
-    unpack_sequence,
-    unsqueeze_mb_list,
-)
+from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
+from arealite.engine.base_hf_engine import BaseHFEngine
 from arealite.utils.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
@@ -44,108 +22,35 @@ from arealite.utils.fsdp import (
     create_fsdp_device_mesh,
     fsdp2_clip_grad_norm_,
     fsdp2_load_full_state_dict,
-    get_cosine_schedule_with_warmup,
 )
-from arealite.utils.model import disable_dropout_in_model
 from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
-from realhf.api.core.data_api import load_hf_tokenizer
-from realhf.base import constants, logging, name_resolve, names, pkg_version
+from realhf.base import logging, name_resolve, names, pkg_version

 logger = logging.getLogger("FSDPEngine")

-class FSDPEngine(TrainEngine):
+class FSDPEngine(BaseHFEngine):
     def __init__(self, config: TrainEngineConfig):
-        self.config = config
+        super().__init__(config)
self.optimizer_config = config.optimizer
self.model = None
self.optimizer = None
self.tokenizer = None
# huggingface model config
self.model_config = None
        # FSDP options
        self.mixed_precision_policy = None
        self.device_mesh = None
        self.cpu_offload = None
# initialization
self.initialized = False
self.own_global_group = False
self._parallelism_group = None
self.weight_update_group_initialized = False
# TODO: Handle the case when WORLD_SIZE is not set in launcher
self.world_size = int(os.environ["WORLD_SIZE"])
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
return self
@property
def parallelism_group(self) -> dist.ProcessGroup:
assert self.initialized
return self._parallelism_group
     def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
         # Initialize distributed enviroments and load model.
         assert addr is None, "FSDPEngine does not support remote initialization."
         assert pkg_version.is_version_greater_or_equal(
             "torch", "2.4.0"
         ), f"arealite only supports FSDP2, which requires torch>=2.4.0"
+        self.create_process_group()
+        self.create_device_model()
-        """Initialize distributed communication and model."""
-        if not dist.is_initialized():
# TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
dist.init_process_group(
backend="nccl",
timeout=constants.NCCL_DEFAULT_TIMEOUT,
device_id=torch.device(int(os.environ["LOCAL_RANK"])),
)
self.own_global_group = True
self._parallelism_group = dist.new_group()
# TODO: Handle the condition when LOCAL_RANK is not set in launcher
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: VLM cannot directly load state dict using this
# random initialized model, so otherwise we call
# from_pretrained rather than loading weights into this random model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
-        # Wrap with FSDP2
         # Simple auto wrap policy
         self.mixed_precision_policy = MixedPrecisionPolicy(
-            param_dtype=dtype,
+            param_dtype=getattr(torch, self.config.dtype),
             reduce_dtype=torch.float32,
             cast_forward_inputs=True,
         )
@@ -154,82 +59,19 @@ class FSDPEngine(TrainEngine):
         self.cpu_offload = (
             CPUOffloadPolicy() if self.config.fsdp.offload_params else None
         )
         fsdp_kwargs = {
             "mesh": self.device_mesh,
             "mp_policy": self.mixed_precision_policy,
             "offload_policy": self.cpu_offload,
             "reshard_after_forward": True,
         }
+        # Wrap with FSDP2
         tik = time.perf_counter()
-        apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
+        apply_fsdp2(self.model, fsdp_kwargs, self.config.fsdp.wrap_policy)
         logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
-        self.model = model
-        # Set up optimizer
if self.optimizer_config is not None:
tik = time.perf_counter()
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
+        self.create_optimizer(ft_spec)
         self.initialized = True
def destroy(self):
"""Destroy the engine and release GPU memory."""
self.model = None
self.optimizer = None
gc.collect()
torch.cuda.empty_cache()
gc.collect()
dist.destroy_process_group(self.parallelism_group)
if self.own_global_group:
dist.destroy_process_group()
self.initialized = False
    def save(self, meta: SaveLoadMeta):
        if meta.weight_format == "hf":
            self._save_model_to_hf(meta.path, meta.tokenizer)
@@ -240,7 +82,7 @@ class FSDPEngine(TrainEngine):
             raise ValueError(f"Unknown weight format {meta.weight_format}. ")
         if meta.with_optim:
-            self._save_optimizer_state(meta.path)
+            self.save_optimizer_state(meta.path)

     def load(self, meta: SaveLoadMeta):
         if meta.weight_format == "hf":
@@ -252,34 +94,10 @@ class FSDPEngine(TrainEngine):
             raise ValueError(f"Unknown weight format {meta.weight_format}. ")
         if meta.with_optim:
-            self._load_optimizer_state(meta.path)
+            self.load_optimizer_state(meta.path)
def _save_optimizer_state(self, path: str):
# Save FSDP sharded state dict on each rank
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
state_dict = self.optimizer.state_dict()
torch.save(state_dict, shard_path)
dist.barrier()
def _load_optimizer_state(self, path: str):
# Load FSDP sharded state dict
assert self.optimizer is not None
assert dist.is_initialized()
rank = dist.get_rank()
shard_path = os.path.join(
path, f"optim_world_size_{self.world_size}_rank_{rank}.pt"
)
optimizer_state_dict = torch.load(shard_path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
dist.barrier()
     def _save_model_to_hf(
-        self, path: str, tokenizer: Optional[transformers.PreTrainedTokenizerFast]
+        self, path: str, tokenizer: Optional[PreTrainedTokenizerFast]
     ):
         """Save model in HuggingFace format."""
         if self.model is None:
@@ -345,35 +163,11 @@ class FSDPEngine(TrainEngine):
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
self.lr_scheduler.step()
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models requires
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
     def train_batch(
         self,
         input_: TensorDict,
-        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
-        loss_weight_fn: Callable[[Dict], float],
+        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
+        loss_weight_fn: Callable[[TensorDict], float],
     ) -> Dict[str, float]:
         """Train on a batch using gradient accumulation."""
         input_ = input_.to(self.device)
@@ -382,7 +176,7 @@ class FSDPEngine(TrainEngine):
         assert self.lr_scheduler is not None
         self.optimizer.zero_grad()
-        mb_list = self._prepare_mb_list(input_)
+        mb_list = self.prepare_mb_list(input_)
         total_loss_weight = torch.tensor(
             sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
@@ -394,7 +188,6 @@ class FSDPEngine(TrainEngine):
         for i, (pad_length, padded_mb_input, mb_input) in enumerate(
             zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
         ):
-            self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
             outputs = self.model(**padded_mb_input)
             logits = outputs.logits.squeeze(0)
@@ -409,6 +202,7 @@ class FSDPEngine(TrainEngine):
             loss *= loss_scale
             loss.backward()

+        # NOTE: grad norm clip function is different
         grad_norm = fsdp2_clip_grad_norm_(
             self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
         )
@@ -427,72 +221,3 @@ class FSDPEngine(TrainEngine):
            grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
            lr=current_lr,
        )
@torch.no_grad()
def eval_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
input_ = input_.to(self.device)
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: TensorDict,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
input_ = input_.to(self.device)
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self._prepare_mb_list(input_)
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
if post_hook:
result = post_hook(logits, mb_input)
results.append(result)
else:
results.append(logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)


@@ -1,467 +0,0 @@
import gc
import os
import time
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import transformers
from safetensors.torch import save_file
from tensordict import TensorDict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import (
FinetuneSpec,
SaveLoadMeta,
TrainEngine,
WeightUpdateMeta,
)
from arealite.utils.data import (
MicroBatchList,
amend_position_ids,
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_packed_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.save_load import (
get_state_dict_from_repo_id_or_path,
is_existing_local_path,
)
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names
logger = logging.getLogger("HFEngine")
class HFEngine(TrainEngine):
def __init__(self, config: TrainEngineConfig):
self.config = config
self.optimizer_config = config.optimizer
self.model = None
self.optimizer = None
self.tokenizer = None
# huggingface model config
self.model_config = None
# initialization
self.initialized = False
self.weight_update_group_initialized = False
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
return self
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
"""Initialize distributed communication and model."""
assert addr is None, "HFEngine does not support remote initialization."
world_size = int(os.environ.get("WORLD_SIZE", 0))
if not dist.is_initialized() and world_size > 1:
try:
import deepspeed
except ImportError:
print(
"Warning: deepspeed is not installed. Some functionality may be disabled."
)
deepspeed.init_distributed(dist_backend="nccl", world_size=world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
self.device = torch.device(f"cuda:{local_rank}")
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
self.model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
).to(f"cuda:{local_rank}")
if not self.config.init_from_scratch:
# Load model from a initial checkpoint path,
# which should only be a huggingface checkpoint.
load_meta = SaveLoadMeta(
path=self.config.path,
weight_format="hf",
with_optim=False,
tokenizer=None,
base_model_path=self.config.path,
naive_distributed=False,
)
self.load(load_meta)
if world_size > 1:
if self._check_autotp():
self.model = deepspeed.tp_model_init(
self.model, tp_size=self.config.hf.autotp_size, dtype=dtype
)
else:
raise RuntimeError("DeepSpeed AutoTP configuration error in HFEngine. ")
# Set up optimizer
if self.optimizer_config is not None:
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=lr,
weight_decay=weight_decay,
betas=(beta1, beta2),
eps=eps,
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
self.initialized = True
def _check_autotp(self):
tp_size = self.config.hf.autotp_size
config = self.model_config
num_attention_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
return (
num_attention_heads % tp_size == 0
and num_key_value_heads % tp_size == 0
and hidden_size % tp_size == 0
and intermediate_size % tp_size == 0
)
def destroy(self):
"""Destroy the engine and release GPU memory."""
self.model = None
self.optimizer = None
gc.collect()
torch.cuda.empty_cache()
gc.collect()
self.initialized = False
def save(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._save_model_to_hf(meta.path, meta.tokenizer, meta.naive_distributed)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for HF
raise NotImplementedError("DCP format saving is not implemented yet. ")
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._save_optimizer_state(meta.path)
def load(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._load_model_from_hf(meta.path, meta.naive_distributed)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for HF
raise NotImplementedError("DCP format loading is not implemented yet. ")
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._load_optimizer_state(meta.path)
def _save_optimizer_state(self, path: str):
assert self.optimizer is not None
os.makedirs(path, exist_ok=True)
torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt"))
def _load_optimizer_state(self, path: str):
assert self.optimizer is not None
path = os.path.join(path, "optim.pt")
optimizer_state_dict = torch.load(path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
def _save_model_to_hf(
self,
path: str,
tokenizer: Optional[transformers.PreTrainedTokenizerFast],
naive_distributed: bool,
):
"""Save model in HuggingFace format."""
if self.model is None:
raise RuntimeError("Model not initialized")
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
os.makedirs(path, exist_ok=True)
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
if world_size > 1:
dist.barrier()
state_dict = self.model.state_dict()
if hasattr(self.model, "module"):
state_dict = {
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
for k, v in state_dict.items()
}
else:
state_dict = {k: v.cpu() for k, v in state_dict.items()}
if world_size > 1 and naive_distributed:
# Only support store parameters from model partitions respectively
gathered_state_dicts = None
if rank == 0:
gathered_state_dicts = [None for _ in range(world_size)]
dist.gather_object(
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
)
if rank == 0:
for i, state_dict in enumerate(gathered_state_dicts):
save_file(state_dict, f"{path}/rank_{i:02d}_model.safetensors")
else:
self.model.save_pretrained(path, state_dict=state_dict)
if world_size > 1:
dist.barrier()
def _load_model_from_hf(self, path: str, naive_distributed: bool):
"""Load model from HuggingFace format."""
rank = dist.get_rank()
# Only support load full model parameters from huggingface
# and load model partition locally
if rank == 0 or is_existing_local_path(path):
if naive_distributed:
path = f"{path}/rank_{rank:02d}_model.safetensors"
full_state = get_state_dict_from_repo_id_or_path(path)
if hasattr(self.model, "module") and not hasattr(full_state):
full_state = {
f"module.{k}" if not k.startswith("module.") else k: v
for k, v in full_state.items()
}
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()
def upload_weights(self, meta: WeightUpdateMeta):
if meta.type == "nccl":
if not self.weight_update_group_initialized:
self._init_distributed_weight_update(meta)
self._update_weights_from_distributed()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer, meta.naive_distributed)
update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
meta.model_version,
)
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
else:
raise ValueError(f"Unknown weight update type {meta.type}")
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
raise NotImplementedError(
"Distributed weight update is not implemented for HFEngine yet. "
)
def _update_weights_from_distributed(self):
raise NotImplementedError(
"Distributed weight update is not implemented for HFEngine yet. "
)
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
return self.lr_scheduler.step()
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
packed_input = pack_tensor_dict(input_)
mb_list = split_packed_tensor_dict_into_mb_list(
packed_input,
self.config.mb_spec,
)
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models requires
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
return mb_list
def train_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict[str, float]:
"""Train on a batch using gradient accumulation."""
input_ = input_.to(self.device)
assert self.optimizer is not None
assert self.optimizer_config is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
# Process microbatches with gradient accumulation
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
loss *= loss_scale
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.optimizer_config.gradient_clipping,
norm_type=2.0,
error_if_nonfinite=False,
foreach=None,
)
if not torch.isfinite(grad_norm):
self.optimizer.zero_grad()
update_successful = False
else:
self.optimizer.step()
update_successful = True
current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
lr=current_lr,
)
@torch.no_grad()
def eval_batch(
self,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
total_loss += loss.item() * loss_scale
total_weight += loss_scale
return torch.tensor(total_loss / total_weight)
@torch.no_grad()
def forward(
self,
input_: TensorDict,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self._prepare_mb_list(input_)
if output_seqlens is None:
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
if post_hook:
result = post_hook(logits, mb_input)
results.append(result)
else:
results.append(logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)


@@ -114,7 +114,9 @@ class PPOActor:
             values = torch.zeros_like(rewards)
         else:
             values = data["values"]
-        advantages_reversed = []
+        advantages_reversed = [
+            torch.zeros(bs, dtype=torch.float32, device=values.device)
+        ]
         lastgaelam = 0
         for t in reversed(range(max_seqlen - 1)):
             nextvalues = values[:, t + 1]
@@ -123,9 +125,6 @@ class PPOActor:
             delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
             lastgaelam = delta + self.discount * self.gae_lambda * lastgaelam
             advantages_reversed.append(lastgaelam)
-        advantages_reversed.append(
-            torch.zeros(bs, dtype=torch.float32, device=values.device)
-        )
         advantages = torch.stack(advantages_reversed[::-1], dim=1)

         # Optionally perform advantage normalization.
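The change above fixes where the terminal zero advantage lands: seeding advantages_reversed with the zero tensor places it at the last position after the [::-1] reversal, instead of shifting every timestep by one. An independent sketch of the recursion A_t = delta_t + discount * gae_lambda * A_{t+1} with A_{T-1} = 0 (shapes simplified relative to the repo's code):

    import torch

    def gae_1d(rewards, values, gamma=1.0, lam=1.0):
        T = rewards.shape[-1]
        adv = torch.zeros_like(rewards)
        last = 0.0
        for t in reversed(range(T - 1)):
            delta = rewards[..., t] + gamma * values[..., t + 1] - values[..., t]
            last = delta + gamma * lam * last
            adv[..., t] = last
        return adv  # adv[..., T-1] stays zero, matching the seeded zero above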


@@ -1,5 +1,3 @@
-from typing import Dict
 import torch
 import torch.utils.data
 from tensordict import TensorDict
@@ -44,9 +42,7 @@ class FSDPLMEngine(FSDPEngine):
         return self.lm_engine.evaluate_lm(data)

-def compute_packed_sft_loss(
-    logits: torch.Tensor, input_: Dict[str, torch.Tensor]
-) -> torch.Tensor:
+def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.Tensor:
     packed_input_ids: torch.Tensor = input_["input_ids"]
     cu_seqlens: torch.Tensor = input_["cu_seqlens"]
     loss_mask = input_["loss_mask"].bool()
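compute_packed_sft_loss now takes a TensorDict directly. As an aside, an independent sketch of a masked next-token loss on a packed sequence; this is not the repo's implementation, and a real packed version would also use cu_seqlens so the one-token shift never crosses sequence boundaries:

    import torch
    import torch.nn.functional as F

    def masked_lm_loss(logits, packed_input_ids, loss_mask):
        # logits: [total_len, vocab]; packed_input_ids, loss_mask: [total_len]
        logprobs = F.log_softmax(logits[:-1].float(), dim=-1)
        tok_logp = logprobs.gather(-1, packed_input_ids[1:].unsqueeze(-1)).squeeze(-1)
        mask = loss_mask[1:].float()
        return -(tok_logp * mask).sum() / mask.sum().clamp(min=1.0)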


@@ -7,10 +7,12 @@ import traceback
 from concurrent.futures import ProcessPoolExecutor
 from datetime import datetime
 from queue import Empty, Full, Queue
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List
+from typing import TYPE_CHECKING, Any, Callable, Dict, List

+import aiohttp
 import requests
 import torch.distributed as dist
+import uvloop
 from tensordict import TensorDict
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -23,8 +25,8 @@ from arealite.api.io_struct import (
     RolloutStat,
     WeightUpdateMeta,
 )
-from arealite.utils.http import arequest_with_retry
-from arealite.utils.padding import concat_padded_tensors
+from arealite.utils.data import concat_padded_tensors
+from arealite.utils.http import arequest_with_retry, get_default_connector
 from realhf.base import logging, name_resolve, names, pkg_version

 if TYPE_CHECKING:
@@ -37,7 +39,7 @@ if pkg_version.is_available("sglang"):
 else:
     SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids"

-ROLLOUT_POLL_WAIT_TIME = 0.1
+ROLLOUT_POLL_WAIT_TIME = 0.05
 RID_CACHE_SIZE = 128
@ -98,6 +100,7 @@ class RemoteSGLangEngine(InferenceEngine):
def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None): def initialize(self, addr: str | None, ft_spec: FinetuneSpec = None):
self.rollout_tasks: Dict[str, asyncio.Task] = {} self.rollout_tasks: Dict[str, asyncio.Task] = {}
self.executor = ProcessPoolExecutor(max_workers=1) self.executor = ProcessPoolExecutor(max_workers=1)
self.rollout_thread = threading.Thread(target=self._rollout_thread) self.rollout_thread = threading.Thread(target=self._rollout_thread)
self.rollout_thread.start() self.rollout_thread.start()
@ -118,32 +121,39 @@ class RemoteSGLangEngine(InferenceEngine):
def _rollout_thread(self): def _rollout_thread(self):
"""Thread that runs the rollout loop.""" """Thread that runs the rollout loop."""
try: try:
asyncio.run(self._rollout_thread_async()) uvloop.run(self._rollout_thread_async())
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
async def _rollout_thread_async(self): async def _rollout_thread_async(self):
pending_data = []
rollout_tasks = self.rollout_tasks rollout_tasks = self.rollout_tasks
rid = 0 rid = 0
# NOTE: session is not thread-safe, but we only submit requests in the sub-thread.
self.session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.config.request_timeout,
sock_connect=self.config.request_timeout,
connect=self.config.request_timeout,
),
read_bufsize=1024 * 1024 * 10,
connector=get_default_connector(),
)
try: try:
while not self.exiting.is_set(): while not self.exiting.is_set():
# Load next data from controller
while True:
try:
data, workflow = self.input_queue.get_nowait()
logger.debug(f"Get data from puller: {data}")
pending_data.append(data)
except Empty:
logger.debug(f"No data from puller stream.")
break
# Check capacity # Check capacity
capacity = self.get_capacity() capacity = self.get_capacity()
# Create new rollout task # Create new rollout task
while capacity > 0 and pending_data and not self.paused.is_set(): while (
capacity > 0
and not self.paused.is_set()
and self.input_queue.qsize() > 0
):
data, workflow = self.input_queue.get_nowait()
logger.debug(f"Get data from puller: {data}")
task = asyncio.create_task( task = asyncio.create_task(
workflow.arun_episode(self, pending_data.pop(0)), name=str(rid) workflow.arun_episode(self, data), name=str(rid)
) )
with self.lock: with self.lock:
rollout_tasks[str(rid)] = task rollout_tasks[str(rid)] = task
@ -158,7 +168,6 @@ class RemoteSGLangEngine(InferenceEngine):
) )
capacity -= 1 capacity -= 1
rid += 1 rid += 1
# Wait for rollout completion # Wait for rollout completion
with self.lock: with self.lock:
tasks = list(rollout_tasks.values()) tasks = list(rollout_tasks.values())
@ -169,11 +178,6 @@ class RemoteSGLangEngine(InferenceEngine):
timeout=ROLLOUT_POLL_WAIT_TIME, timeout=ROLLOUT_POLL_WAIT_TIME,
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
if not done:
await asyncio.sleep(1)
else:
await asyncio.sleep(1)
# Collect done results # Collect done results
for task in done: for task in done:
traj = await task traj = await task
@ -199,6 +203,7 @@ class RemoteSGLangEngine(InferenceEngine):
f"running: {self.rollout_stat.running}, " f"running: {self.rollout_stat.running}, "
f"accepted: {self.rollout_stat.accepted}." f"accepted: {self.rollout_stat.accepted}."
) )
await asyncio.sleep(1)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
finally: finally:
@ -213,10 +218,11 @@ class RemoteSGLangEngine(InferenceEngine):
pass pass
def choose_server(self) -> str: def choose_server(self) -> str:
if self.config.schedule_policy == "round_robin": with self.lock:
server = self.addresses[self.server_idx] if self.config.schedule_policy == "round_robin":
self.server_idx = (self.server_idx + 1) % len(self.addresses) server = self.addresses[self.server_idx]
return server self.server_idx = (self.server_idx + 1) % len(self.addresses)
return server
raise NotImplementedError("Only round-robin scheduling is implemented.") raise NotImplementedError("Only round-robin scheduling is implemented.")
async def agenerate(self, req: LLMRequest) -> LLMResponse: async def agenerate(self, req: LLMRequest) -> LLMResponse:
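
Editor's note: the lock added to choose_server makes the read and increment of server_idx atomic, so concurrent callers can no longer be handed the same slot. A stand-alone sketch of the same round-robin pattern, using a hypothetical class rather than the engine's actual API:

    import threading

    class RoundRobinPicker:
        def __init__(self, addresses):
            self.addresses = addresses
            self.server_idx = 0
            self.lock = threading.Lock()

        def choose_server(self) -> str:
            # Read-modify-write under one lock so two threads never get the same index.
            with self.lock:
                server = self.addresses[self.server_idx]
                self.server_idx = (self.server_idx + 1) % len(self.addresses)
                return server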
@@ -253,7 +259,6 @@ class RemoteSGLangEngine(InferenceEngine):
         accumulated_versions = []

         # Deal with rollout interruption
-        completions = ""
         stop_reason = "length"

         if req.rid in self.rid_to_address:
@@ -273,11 +278,12 @@ class RemoteSGLangEngine(InferenceEngine):
         ):
             # loop until the generation is complete
             result = await arequest_with_retry(
-                addr=self.choose_server(),
+                session=self.session,
+                addr=server_addr,
                 endpoint="/generate",
                 payload=payload,
                 method="POST",
-                max_retries=3,
+                max_retries=self.config.request_retries,
                 timeout=self.config.request_timeout,
             )
@@ -297,10 +303,7 @@ class RemoteSGLangEngine(InferenceEngine):
             stop_reason = finish_reason["type"]

             payload["input_ids"] += result[SGLANG_TOKEN_OUTPUT_IDENTIFIER]
-            sample_params["max_new_tokens"] = min(
-                sample_params["max_new_tokens"],
-                gconfig.max_new_tokens - len(output_tokens),
-            )
+            sample_params["max_new_tokens"] -= len(output_tokens)

         latency = time.perf_counter() - start_time
@@ -408,18 +411,24 @@ class RemoteSGLangEngine(InferenceEngine):
     def prepare_batch(
         self,
-        data_generator: Iterator,
         dataloader: StatefulDataLoader,
         workflow: "RolloutWorkflow",
     ):
+        if not hasattr(self, "data_generator"):
+            self.data_generator = iter(dataloader)
         assert dataloader.batch_size is not None
         while True:
-            if self.get_capacity() + dataloader.batch_size > 0:
+            # Submit at least two batches to allow maximum overlap
+            if (
+                self.get_capacity() + dataloader.batch_size > 0
+                and self.input_queue.qsize() + dataloader.batch_size
+                < self.input_queue.maxsize
+            ):
                 try:
-                    data = next(data_generator)
+                    data = next(self.data_generator)
                 except StopIteration:
-                    data_generator = iter(dataloader)
-                    data = next(data_generator)
+                    self.data_generator = iter(dataloader)
+                    data = next(self.data_generator)
                 for item in data:
                     self.submit(item, workflow=workflow)
             try:
@@ -435,10 +444,12 @@ class RemoteSGLangEngine(InferenceEngine):

 async def aupdate_weights_from_disk(
-    addr, path: str, request_retries: int, request_timeout: float
+    session, addr, path: str, request_retries: int, request_timeout: float
 ):
+    tik = time.time()
     res = await arequest_with_retry(
         addr=addr,
+        session=session,
         endpoint="/update_weights_from_disk",
         payload=dict(model_path=str(path), allow_interrupt=True),
         method="POST",
@@ -472,9 +483,19 @@ def update_weights_from_disk(
         logger.info(
             f"Begin update weights from {path}, responded in {(load_timestamp - save_timestamp):.2f}s"
         )
+        session = aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(
+                total=request_timeout,
+                sock_connect=request_timeout,
+                connect=request_timeout,
+            ),
+            read_bufsize=1024 * 1024 * 10,
+            connector=get_default_connector(),
+        )
         jobs = [
             aupdate_weights_from_disk(
-                addr,
+                session=session,
+                addr=addr,
                 path=path,
                 request_retries=request_retries,
                 request_timeout=request_timeout,
@@ -482,8 +503,9 @@ def update_weights_from_disk(
             for addr in addresses
         ]
         await asyncio.gather(*jobs)
+        await session.close()
         logger.info(
             f"Loading weights done in {(datetime.now().timestamp() - load_timestamp):.2f}s"
         )
-    return asyncio.run(_fn())
+    return uvloop.run(_fn())
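
Editor's note: taken together, these hunks fan the weight-update request out to every server over one shared aiohttp session and drive the event loop with uvloop instead of asyncio.run. A rough sketch of the same fan-out pattern under those assumptions; the helper name and payload here are illustrative, not the engine's API:

    import asyncio
    import aiohttp
    import uvloop

    async def broadcast_post(addresses, endpoint, payload, timeout=3600):
        timeo = aiohttp.ClientTimeout(total=timeout, sock_connect=timeout, connect=timeout)
        session = aiohttp.ClientSession(timeout=timeo)
        try:
            async def post_one(addr):
                async with session.post(f"http://{addr}{endpoint}", json=payload) as resp:
                    resp.raise_for_status()
                    return await resp.json()

            # One request per server; gather waits for every reply before the session closes.
            return await asyncio.gather(*(post_one(a) for a in addresses))
        finally:
            await session.close()

    # results = uvloop.run(broadcast_post(addrs, "/update_weights_from_disk", {"model_path": "/ckpt"}))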

View File

@@ -5,6 +5,7 @@ import time
 import uuid

 import pytest
+import requests
 import torch
 from tensordict import TensorDict
@@ -28,6 +29,17 @@ HOST = network.gethostip()
 RUN_SERVER_TIMEOUT = 180


+def check_server_health(base_url):
+    try:
+        response = requests.get(
+            f"{base_url}/metrics",
+            timeout=30,
+        )
+        return response.status_code == 200
+    except requests.exceptions.RequestException as e:
+        return False
+
+
 @pytest.fixture(scope="module")
 def sglang_server():
     from realhf.base import seeding

View File

@@ -67,7 +67,6 @@ async def test_local_sglang_generate():
         gconfig=GenerationHyperparameters(max_new_tokens=16),
     )
     resp = await engine.agenerate(req)
-    print(resp.completions)

     assert isinstance(resp, LLMResponse)
     assert resp.input_tokens == req.input_ids
@@ -76,9 +75,6 @@ async def test_local_sglang_generate():
         == len(resp.output_tokens)
         == len(resp.output_versions)
     )
-    assert isinstance(resp.completions, str)
-
-    time.sleep(5)
     engine.destroy()

View File

@@ -13,7 +13,6 @@ from transformers import AutoTokenizer
 from arealite.api.cli_args import MicroBatchSpec, OptimizerConfig, TrainEngineConfig
 from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
-from arealite.engine.fsdp_engine import FSDPEngine

 VOCAB_SIZE = 100
 MODEL_PATH = "/storage/testing/models/Qwen__Qwen3-1.7B/"
@@ -53,10 +52,10 @@ def mock_input(
 def get_engine(engine_type: str, model_path: str):
+    from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine
     from arealite.engine.fsdp_engine import FSDPEngine
-    from arealite.engine.hf_engine import HFEngine

-    engine_cls = {"hf": HFEngine, "fsdp": FSDPEngine}[engine_type]
+    engine_cls = {"auto_tp": DeepSpeedAutoTPEngine, "fsdp": FSDPEngine}[engine_type]
     engine_config = TrainEngineConfig(
         experiment_name=f"test-{engine_type}-engine",
@@ -75,7 +74,7 @@ def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
     return torch.mean(logits)


-@pytest.fixture(scope="module", params=["fsdp", "hf"])
+@pytest.fixture(scope="module", params=["fsdp", "auto_tp"])
 def engine(request):
     os.environ.update(
         {
@@ -136,6 +135,12 @@ def test_train_batch(engine, mock_input):
 @torch.no_grad()
 def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
+    from arealite.engine.autotp_engine import DeepSpeedAutoTPEngine
+
+    if isinstance(engine, DeepSpeedAutoTPEngine):
+        print("AutoTP engine does not support HF save/load for now.")
+        return
+
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
     path = tmp_path_factory.mktemp("hf_engine_test")
     save_load_meta = SaveLoadMeta(

View File

@@ -92,9 +92,7 @@ def unpad_input(
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(
-        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
-    )
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return (
         rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
         indices,
@@ -116,16 +114,12 @@ def concat_padded_tensors(
     if not tensor_dicts:
         return TensorDict()

-    # Find max sequence length across all dictionaries
-    lens = []
-    for tensor_dict in tensor_dicts:
-        for key, tensor in tensor_dict.items():
-            if key != "attention_mask" and len(tensor.shape) == 2:
-                lens.append(tensor.shape[1])
-                break
-    max_length = max(lens)
-    attn_mask = torch.arange(max_length).unsqueeze(0) < torch.tensor(lens).unsqueeze(1)
+    batch_sizes = [tuple(d.batch_size) for d in tensor_dicts]
+    new_batch_size = [sum(x[0] for x in batch_sizes), *batch_sizes[0][1:]]
+
+    # Find max sequence length across all dictionaries
+    assert all("attention_mask" in td for td in tensor_dicts)
+    max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])

     result = {}
     # Process each key
     for key in tensor_dicts[0].keys():
@@ -154,9 +148,7 @@ def concat_padded_tensors(
             tensors_to_concat.append(tensor)

         result[key] = torch.cat(tensors_to_concat, dim=0)
-    if "attention_mask" not in result:
-        result["attention_mask"] = attn_mask
-    return TensorDict(result, batch_size=[len(lens)])
+    return TensorDict(result, batch_size=new_batch_size)


 def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
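
Editor's note: the rewritten concat_padded_tensors now derives the target width from the attention masks and sums the per-dict batch sizes rather than rebuilding the mask itself. A simplified two-input sketch of that behavior, assuming right-padded integer or float tensors of shape [bs, seqlen]; the helper name is illustrative:

    import torch
    import torch.nn.functional as F
    from tensordict import TensorDict

    def concat_two_padded(a: TensorDict, b: TensorDict, pad_value=0) -> TensorDict:
        # Target width comes from the attention masks, as in the helper above.
        max_len = max(a["attention_mask"].shape[1], b["attention_mask"].shape[1])
        out = {}
        for key in a.keys():
            ta, tb = a[key], b[key]
            if ta.dim() >= 2:
                ta = F.pad(ta, (0, max_len - ta.shape[1]), value=pad_value)
                tb = F.pad(tb, (0, max_len - tb.shape[1]), value=pad_value)
            out[key] = torch.cat([ta, tb], dim=0)
        # The result's batch size is the sum of the inputs' batch sizes.
        return TensorDict(out, batch_size=[a.batch_size[0] + b.batch_size[0]])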
@@ -231,7 +223,7 @@ def pack_tensor_dict(data: TensorDict):
     cu_seqlens = torch.cumsum(lens, dim=0)
     cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
-    total_length = cu_seqlens[-1].item()
+    total_length = int(cu_seqlens[-1].item())

     # Pack tensors
     packed_data = {}
     for key, value in data.items():
@@ -246,7 +238,7 @@ def pack_tensor_dict(data: TensorDict):
             and value.shape[1] == seq_len
         ):
             packed_tensor = torch.empty(
-                total_length, *value.shape[2:], dtype=value.dtype, device=value.device
+                (total_length, *value.shape[2:]), dtype=value.dtype, device=value.device
             )
             # Fill the packed tensor with values from the original tensor
             for i in range(bs):
@@ -363,7 +355,7 @@ def split_padded_tensor_dict_into_mb_list(
         mbs=results,
         mb_spec=mb_spec,
         forward_indices=forward_indices,
-        backward_indices=backward_indices,
+        backward_indices=backward_indices.tolist(),
         group_lens=group_lens,
     )

View File

@@ -121,19 +121,24 @@
 def ppo_actor_loss_fn(
     logprobs: torch.Tensor,
+    proximal_logprobs: torch.Tensor,
     old_logprobs: torch.Tensor,
     advantages: torch.Tensor,
     eps_clip: float,
     loss_mask: torch.Tensor,
     c_clip: Optional[float] = None,
-    proximal_logprobs: Optional[torch.Tensor] = None,
     behav_imp_weight_cap: Optional[float] = None,
 ) -> Tuple[torch.Tensor, Dict]:
-    denorm_logprobs = (
-        proximal_logprobs if proximal_logprobs is not None else old_logprobs
-    )
+    """
+    When decoupled loss is disabled:
+    1. if recompute logp, both old_logprobs and proximal_logprobs are recomputed logp;
+    2. if no recomputation, both old_logp and proximal_logprobs are produced by the inference backend.
+    When decoupled loss is enabled, proximal_logprobs is the recomputed logp,
+    old_logprobs is produced by the inference engine.
+    """
     loss_mask_count = loss_mask.count_nonzero() or 1
-    ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
+    ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0)
     clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
     pg_loss1 = -advantages * ratio
     pg_loss2 = -advantages * clipped_ratio
@@ -146,17 +151,16 @@ def ppo_actor_loss_fn(
             pg_loss = torch.min(pg_loss, pg_loss3)
     else:
         dual_clip_mask = torch.zeros_like(clip_mask)
-    if proximal_logprobs is not None:
-        behav_kl = proximal_logprobs - old_logprobs
-        behav_imp_weight = behav_kl.exp()
-        behav_mask = (
-            (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
-            if behav_imp_weight_cap is not None
-            else loss_mask
-        )
-        behav_kl = torch.where(behav_mask, behav_kl, 0.0)
-        behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
-        pg_loss = pg_loss * behav_imp_weight
+    behav_kl = proximal_logprobs - old_logprobs
+    behav_imp_weight = behav_kl.exp()
+    behav_mask = (
+        (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask)
+        if behav_imp_weight_cap is not None
+        else loss_mask
+    )
+    behav_kl = torch.where(behav_mask, behav_kl, 0.0)
+    behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
+    pg_loss = pg_loss * behav_imp_weight
     logging_loss = pg_loss.detach()
     pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
     clip_mask.logical_and_(loss_mask)
@@ -164,7 +168,7 @@ def ppo_actor_loss_fn(
     stat = dict(
         loss=logging_loss,
         importance_weight=ratio.detach(),
-        approx_kl=(logprobs - denorm_logprobs).detach(),
+        approx_kl=(logprobs - proximal_logprobs).detach(),
         clip_mask=clip_mask,
         dual_clip_mask=dual_clip_mask,
     )
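
Editor's note: with the decoupled loss the policy ratio is always taken against the recomputed (proximal) log-probabilities, and the behavior importance weight exp(proximal - old) corrects for the gap to the inference engine's log-probabilities. A toy numerical sketch of that computation, with made-up values for a single four-token sequence:

    import torch

    logprobs          = torch.tensor([-1.00, -0.90, -1.20, -0.70])  # current policy
    proximal_logprobs = torch.tensor([-1.05, -0.95, -1.10, -0.80])  # recomputed before the update
    old_logprobs      = torch.tensor([-1.10, -1.00, -1.00, -0.85])  # from the inference engine
    advantages        = torch.tensor([ 0.50, -0.20,  0.30,  0.10])
    loss_mask         = torch.tensor([True, True, True, False])
    eps_clip = 0.2

    ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0)
    clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
    pg_loss = torch.max(-advantages * ratio, -advantages * clipped)

    # Importance weight between the proximal policy and the behavior policy
    # that actually generated the tokens.
    behav_imp_weight = torch.where(loss_mask, (proximal_logprobs - old_logprobs).exp(), 0.0)
    pg_loss = pg_loss * behav_imp_weight
    loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask.count_nonzero()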

View File

@@ -3,45 +3,74 @@ from typing import Any, Dict, Optional

 import aiohttp

+from realhf.base import logging
+
 DEFAULT_RETRIES = 1
 DEFAULT_REQUEST_TIMEOUT = 3600

+logger = logging.getLogger(__file__)
+
+
+def get_default_connector():
+    return aiohttp.TCPConnector(limit=0, use_dns_cache=False, force_close=True)
+

 async def arequest_with_retry(
     addr: str,
     endpoint: str,
     payload: Optional[Dict[str, Any]] = None,
+    session: aiohttp.ClientSession | None = None,
     method: str = "POST",
     max_retries: Optional[int] = None,
     timeout: Optional[float] = None,
     retry_delay: float = 1.0,
-) -> aiohttp.ClientResponse:
+    verbose=False,
+) -> Dict:
     timeout = timeout or DEFAULT_REQUEST_TIMEOUT
     last_exception = None
     max_retries = max_retries or DEFAULT_RETRIES
     base_url = f"http://{addr}"
     url = f"{base_url}{endpoint}"
+    timeo = aiohttp.ClientTimeout(
+        total=timeout,
+        sock_connect=timeout,
+        connect=timeout,
+    )
+    if session is None:
+        _session = aiohttp.ClientSession(
+            timeout=timeo,
+            read_bufsize=1024 * 1024 * 10,
+            connector=get_default_connector(),
+        )
+    else:
+        _session = session
     for attempt in range(max_retries):
         try:
-            async with aiohttp.ClientSession(
-                timeout=aiohttp.ClientTimeout(
-                    total=timeout,
-                    sock_connect=timeout,
-                )
-            ) as session:
-                if method.upper() == "GET":
-                    response = await session.get(url)
-                elif method.upper() == "POST":
-                    response = await session.post(url, json=payload)
-                elif method.upper() == "PUT":
-                    response = await session.put(url, json=payload)
-                elif method.upper() == "DELETE":
-                    response = await session.delete(url)
-                else:
-                    raise ValueError(f"Unsupported HTTP method: {method}")
+            if verbose:
+                logger.info("enter client session, start sending requests")
+            if method.upper() == "GET":
+                ctx = _session.get(url, timeout=timeo)
+            elif method.upper() == "POST":
+                ctx = _session.post(url, json=payload, timeout=timeo)
+            elif method.upper() == "PUT":
+                ctx = _session.put(url, json=payload, timeout=timeo)
+            elif method.upper() == "DELETE":
+                ctx = _session.delete(url, timeout=timeo)
+            else:
+                raise ValueError(f"Unsupported HTTP method: {method}")
+            async with ctx as response:
+                if verbose:
+                    logger.info("http requests return")
                 response.raise_for_status()
-                return await response.json()
+                res = await response.json()
+                if verbose:
+                    logger.info("get http result")
+                if session is None:
+                    await _session.close()
+                return res
         except (
             aiohttp.ClientError,
             aiohttp.ClientResponseError,
@@ -51,6 +80,10 @@ async def arequest_with_retry(
             if attempt < max_retries - 1:
                 await asyncio.sleep(retry_delay)
                 continue
+    if session is None:
+        await _session.close()
     raise RuntimeError(
-        f"Failed after {max_retries} retries each. " f"Last error: {last_exception}"
+        f"Failed after {max_retries} retries each. "
+        f"Payload: {payload}. Addr: {addr}. Endpoint: {endpoint}. "
+        f"Last error: {last_exception}"
     )
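
Editor's note: the refactored helper now returns the decoded JSON body and can reuse a caller-provided session; when none is passed it builds and closes a throwaway one per call. A hedged usage sketch under those assumptions; the endpoint, function name, and values here are illustrative only:

    import aiohttp
    import uvloop

    from arealite.utils.http import arequest_with_retry, get_default_connector

    async def query_all(addresses):
        # Reuse one session for many requests instead of paying connection setup each time.
        session = aiohttp.ClientSession(connector=get_default_connector())
        try:
            results = []
            for addr in addresses:
                res = await arequest_with_retry(
                    addr=addr,
                    endpoint="/health",
                    method="GET",
                    session=session,
                    max_retries=3,
                    timeout=30,
                )
                results.append(res)
            return results
        finally:
            await session.close()

    # replies = uvloop.run(query_all(["localhost:30000"]))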

View File

@@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizerFast
 from arealite.api.cli_args import GenerationHyperparameters
 from arealite.api.io_struct import LLMRequest
 from arealite.api.workflow_api import RolloutWorkflow
-from arealite.utils.padding import concat_padded_tensors
+from arealite.utils.data import concat_padded_tensors


 class RLVRWorkflow(RolloutWorkflow):
@@ -61,7 +61,7 @@ class RLVRWorkflow(RolloutWorkflow):
                 versions=torch.tensor(versions).unsqueeze(0),
                 attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
                 # reward
-                rewards=torch.tensor([reward]),
+                rewards=torch.tensor([float(reward)]),
             )
             results.append(TensorDict(res, batch_size=[1]))

View File

@@ -152,11 +152,7 @@ def main_grpo():
         with stats_tracker.record_timing("rollout"):
             if config.async_training:
-                batch = rollout.prepare_batch(
-                    data_generator,
-                    train_dataloader,
-                    workflow=workflow,
-                )
+                batch = rollout.prepare_batch(train_dataloader, workflow=workflow)
             else:
                 try:
                     data = next(data_generator)
@@ -210,8 +206,9 @@ def main_grpo():
                 actor.upload_weights(meta)
                 if dist.get_rank() == 0:
                     future.result()
-                rollout.set_version(global_step + 1)
                 dist.barrier()
+                torch.cuda.synchronize()
+                rollout.set_version(global_step + 1)

         with stats_tracker.record_timing("save"):
             saver.save(actor, epoch, step, global_step)

View File

@@ -1360,7 +1360,9 @@ class RayNameResolveRepository:
 def make_repository(args: "NameResolveConfig"):
     if args.type == "nfs":
-        return NfsNameRecordRepository(args.nfs_record_root)
+        repo = NfsNameRecordRepository(args.nfs_record_root)
+        os.makedirs(repo.record_root, exist_ok=True)
+        return repo
     elif args.type == "etcd3":
         host, port = args.etcd3_addr.split(":")
         return Etcd3NameRecordRepository(host=host, port=int(port))