PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine

Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/354

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


This commit is contained in:
博惟 2025-07-14 11:19:40 +08:00 committed by 晓雷
parent 434d2f5064
commit d8038b2669
5 changed files with 135 additions and 60 deletions

View File

@@ -12,7 +12,6 @@ from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf
from arealite.utils.fs import get_user_tmp
from realhf.api.cli_args import OptimizerConfig
@dataclass
@@ -84,6 +83,61 @@ class GenerationHyperparameters:
# Train Engine Configs
@dataclass
class OptimizerConfig:
"""Configuration for model optimization during training.
Note:
Set type to "empty" for models that won't be trained.
"""
type: str = field(
default="adam",
metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
)
lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
min_lr_ratio: float = field(
default=0.0,
metadata={
"help": "Minimum learning rate ratio after annealing",
},
)
lr_scheduler_type: str = field(
default="constant",
metadata={
"help": "Learning rate scheduler type",
"choices": ["linear", "cosine", "constant"],
},
)
warmup_steps_proportion: float = field(
default=0.001,
metadata={
"help": "Proportion of training steps for warmup",
},
)
offload: bool = field(
default=False, metadata={"help": "Enable optimizer state offloading"}
)
initial_loss_scale: float = field(
default=2**32, metadata={"help": "Initial loss scaling factor"}
)
min_loss_scale: float = field(
default=1.0, metadata={"help": "Minimum loss scaling factor"}
)
loss_scale_window: float = field(
default=5, metadata={"help": "Window size for loss scaling adjustment"}
)
hysteresis: int = field(
default=2, metadata={"help": "Hysteresis (scaling factor) for loss scaling"}
)
gradient_clipping: float = field(
default=1.0, metadata={"help": "Gradient clipping threshold"}
)
@dataclass
class FSDPWrapPolicy:
transformer_layer_cls_to_wrap: Optional[List[str]] = field(
@@ -127,10 +181,11 @@ class TrainEngineConfig:
mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
# Training Backend Configuration
disable_dropout: bool = field(default=False)
gradient_checkpointing: bool = field(
default=True, metadata={"help": "Enable gradient checkpointing"}
)
bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"})
dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
optimizer: Optional[OptimizerConfig] = field(
default=None, metadata={"help": "Optimizer configuration"}
)

View File

@@ -1,6 +1,7 @@
import gc
import os
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -32,7 +33,7 @@ from arealite.utils.data import (
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_packed_tensor_dict_into_mb_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
@@ -45,6 +46,7 @@ from arealite.utils.fsdp import (
fsdp2_load_full_state_dict,
get_cosine_schedule_with_warmup,
)
from arealite.utils.model import disable_dropout_in_model
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version
@@ -95,19 +97,38 @@ class FSDPEngine(TrainEngine):
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
self.device = torch.device(int(os.environ["LOCAL_RANK"]))
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: VLMs cannot directly load a state dict into this
# randomly initialized model, so in the non-scratch case we call
# from_pretrained rather than loading weights into a random model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
@@ -116,7 +137,7 @@ class FSDPEngine(TrainEngine):
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
param_dtype=dtype,
reduce_dtype=torch.float32,
cast_forward_inputs=True,
)
@@ -134,23 +155,14 @@ class FSDPEngine(TrainEngine):
}
# Wrap with FSDP2
tik = time.perf_counter()
apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
self.model = model
if not self.config.init_from_scratch:
# Load model from a initial checkpoint path,
# which should only be a huggingface checkpoint.
load_meta = SaveLoadMeta(
path=self.config.path,
weight_format="hf",
with_optim=False,
tokenizer=None,
base_model_path=self.config.path,
)
self.load(load_meta)
# Set up optimizer
if self.optimizer_config is not None:
tik = time.perf_counter()
assert (
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
@@ -194,6 +206,7 @@ class FSDPEngine(TrainEngine):
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
logger.info(f"Create optimizer time: {time.perf_counter() - tik}")
self.initialized = True
@@ -328,15 +341,19 @@ class FSDPEngine(TrainEngine):
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
packed_input = pack_tensor_dict(input_)
mb_list = split_packed_tensor_dict_into_mb_list(
packed_input,
self.config.mb_spec,
mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models require
# the packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
return mb_list
def train_batch(
@@ -361,9 +378,10 @@ class FSDPEngine(TrainEngine):
dist.all_reduce(total_loss_weight)
# Process microbatches with gradient accumulation
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
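As I read the FSDP2 API, set_is_last_backward marks which microbatch's backward pass is the final one so that gradient reduction is finalized only there. A rough sketch of the accumulation pattern this hunk introduces (the loss below is a placeholder, not AReaL's actual objective):

    n_mbs = len(mb_list.mbs)
    for i, padded_mb_input in enumerate(mb_list.padded_mbs):
        # Only the last microbatch's backward should finalize FSDP2's gradient reduction.
        self.model.set_is_last_backward(i == n_mbs - 1)
        logits = self.model(**padded_mb_input).logits.squeeze(0)
        loss = logits.float().mean()    # placeholder loss
        loss.backward()
    self.optimizer.step()               # assuming the engine created self.optimizer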

View File

@@ -8,7 +8,7 @@ from arealite.utils.data import (
pad_and_stack_tensors_along_first_dim,
pad_sequences_to_tensors,
reorder_list,
split_packed_tensor_dict_into_mb_list,
split_padded_tensor_dict_into_mb_list,
unpack_sequence,
)
@@ -45,7 +45,8 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
packed_data = pack_tensor_dict(mock_padded_data)
original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
split_result = split_padded_tensor_dict_into_mb_list(mock_padded_data, mb_spec)
split_result.mbs = [pack_tensor_dict(mb) for mb in split_result.mbs]
reordered_lens = [original_lens[i] for i in split_result.forward_indices]
# assert microbatch split result does not violate requirements

View File

@@ -110,11 +110,11 @@ def pad_input(hidden_states, indices, batch, seqlen):
def concat_padded_tensors(
tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0
) -> Dict[str, torch.Tensor]:
tensor_dicts: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
"""Concatenate and pad tensors from multiple padded tensor dictionaries."""
if not tensor_dicts:
return {}
return TensorDict()
# Find max sequence length across all dictionaries
lens = []
@@ -156,7 +156,7 @@
result[key] = torch.cat(tensors_to_concat, dim=0)
if "attention_mask" not in result:
result["attention_mask"] = attn_mask
return result
return TensorDict(result, batch_size=[len(lens)])
def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
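With the return type switched to TensorDict, callers of concat_padded_tensors always get a tensordict back, with shorter batches padded up to the longest sequence length before concatenation along the batch dimension. A small illustration of that contract as I read it (shapes are made up; import path inferred from the imports above):

    import torch
    from tensordict import TensorDict
    from arealite.utils.data import concat_padded_tensors

    a = TensorDict(
        {"input_ids": torch.zeros(2, 5, dtype=torch.long),
         "attention_mask": torch.ones(2, 5, dtype=torch.long)},
        batch_size=[2],
    )
    b = TensorDict(
        {"input_ids": torch.zeros(3, 8, dtype=torch.long),
         "attention_mask": torch.ones(3, 8, dtype=torch.long)},
        batch_size=[3],
    )
    out = concat_padded_tensors([a, b])
    assert out["input_ids"].shape == (5, 8)   # 2 + 3 sequences, padded to length 8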
@@ -290,13 +290,13 @@ class MicroBatchList:
DEFAULT_MAX_TOKENS_PER_MB = int(1e12)
def split_packed_tensor_dict_into_mb_list(
def split_padded_tensor_dict_into_mb_list(
data: TensorDict, mb_spec: MicroBatchSpec, group: Optional[dist.ProcessGroup] = None
) -> MicroBatchList:
"""Split a packed tensordict into micro-batches based on the cumulative sequence lengths.
"""Split a padded tensordict into micro-batches based on the attention mask.
Args:
data (TensorDict): Dictionary containing packed tensors with "cu_seqlens" key.
data (TensorDict): Dictionary containing padded tensors.
mb_spec (MicroBatchSpec): Specification for micro-batch splitting.
group (Optional[dist.ProcessGroup]): Process group for distributed synchronization.
@@ -304,24 +304,21 @@ def split_packed_tensor_dict_into_mb_list(
MicroBatchList: A structure containing the split micro-batches and metadata.
"""
assert (
"cu_seqlens" in data
), "Input data must be packed and contain 'cu_seqlens' key."
"attention_mask" in data
), "Input data must be padded and contain 'attention_mask' key."
if mb_spec.max_tokens_per_mb is None:
mb_spec = MicroBatchSpec.new(
mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB
)
cu_seqlens = data["cu_seqlens"]
bs = cu_seqlens.shape[0] - 1
total_lens = int(cu_seqlens[-1])
input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy()
bs = data["attention_mask"].shape[0]
max_seqlen = data["attention_mask"].shape[1]
input_lens = data["attention_mask"].sum(1).long().cpu().numpy()
# check tensor shape, split only 1d tensors with length "total_lens"
to_split = {}
not_to_split = {}
for key, value in data.items():
if key == "cu_seqlens" or key == "max_seqlen":
continue
if not torch.is_tensor(value) or value.numel() != total_lens:
if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
not_to_split[key] = value
else:
to_split[key] = value
@@ -331,6 +328,7 @@ def split_packed_tensor_dict_into_mb_list(
splitted_lens = [
[input_lens[i] for i in group_index] for group_index in group_indices
]
group_n_seqs = [len(x) for x in splitted_lens]
group_lens = [sum(x) for x in splitted_lens]
forward_indices = datapack.flat2d(group_indices)
@@ -340,12 +338,16 @@ def split_packed_tensor_dict_into_mb_list(
def _split(tensor):
"""Split and pad a tensor based on forward indices and lens."""
# Unpack the sequence
unpacked = unpack_sequence(tensor, cu_seqlens=cu_seqlens)
unpacked = [tensor[i] for i in range(bs)]
# Reorder according to forward indices
reordered = reorder_list(unpacked, forward_indices)
reordered = torch.cat(reordered)
reordered = torch.stack(reordered)
# Unpack again according to split lens
splitted = unpack_sequence(reordered, lens=group_lens)
splitted = []
offset = 0
for _n_seqs in group_n_seqs:
splitted.append(reordered[offset : offset + _n_seqs])
offset += _n_seqs
return splitted
to_split = dict_map(to_split, lambda x: _split(x))
@@ -355,16 +357,7 @@ def split_packed_tensor_dict_into_mb_list(
# organize splitted micro batches
assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens))
for i, (mb, lens) in enumerate(zip(mbs, splitted_lens)):
max_seqlen = max(lens)
lens = torch.tensor(lens, device="cuda")
batch_cu_seqlens = torch.nn.functional.pad(
lens.cumsum(0, dtype=torch.int), (1, 0)
)
results.append(
TensorDict(
**mb, **not_to_split, max_seqlen=max_seqlen, cu_seqlens=batch_cu_seqlens
)
)
results.append(TensorDict(**mb, **not_to_split))
return MicroBatchList(
data=data,
mbs=results,
@@ -433,7 +426,7 @@ def pad_mb_list(
# NOTE: GPU page size is 2MB
# Take hidden size 4096 with bf16 dtype as an example,
# the batch size of a page is 256
pad_to_length = (l + 255) // 256 * 256
pad_to_length = (int(l) + 255) // 256 * 256
padded_mb, pad_len = pad_packed_tensor_dict(
mb, pad_to_length, pad_value=pad_value
)

arealite/utils/model.py (new file, +8 lines)
View File

@@ -0,0 +1,8 @@
import torch
# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0
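The helper is meant to be applied to the freshly constructed HF model before FSDP wrapping, as the engine change above does. A quick usage sketch (the model name is arbitrary):

    import torch
    from transformers import AutoModelForCausalLM
    from arealite.utils.model import disable_dropout_in_model

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
    disable_dropout_in_model(model)
    assert all(m.p == 0 for m in model.modules() if isinstance(m, torch.nn.Dropout))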