mirror of https://github.com/inclusionAI/AReaL
PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
Merge branch fw/lite-fix1 of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/354 Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
This commit is contained in:
parent 434d2f5064
commit d8038b2669
@@ -12,7 +12,6 @@ from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf

from arealite.utils.fs import get_user_tmp
from realhf.api.cli_args import OptimizerConfig


@dataclass
@@ -84,6 +83,61 @@ class GenerationHyperparameters:

# Train Engine Configs
@dataclass
class OptimizerConfig:
    """Configuration for model optimization during training.

    Note:
        Set type to "empty" for models that won't be trained.
    """

    type: str = field(
        default="adam",
        metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
    )
    lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
    weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
    beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
    beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
    eps: float = field(default=1e-5, metadata={"help": "Adam epsilon parameter"})
    min_lr_ratio: float = field(
        default=0.0,
        metadata={
            "help": "Minimum learning rate ratio after annealing",
        },
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={
            "help": "Learning rate scheduler type",
            "choices": ["linear", "cosine", "constant"],
        },
    )
    warmup_steps_proportion: float = field(
        default=0.001,
        metadata={
            "help": "Proportion of training steps for warmup",
        },
    )
    offload: bool = field(
        default=False, metadata={"help": "Enable optimizer state offloading"}
    )
    initial_loss_scale: float = field(
        default=2**32, metadata={"help": "Initial loss scaling factor"}
    )
    min_loss_scale: float = field(
        default=1.0, metadata={"help": "Minimum loss scaling factor"}
    )
    loss_scale_window: float = field(
        default=5, metadata={"help": "Window size for loss scaling adjustment"}
    )
    hysteresis: int = field(
        default=2, metadata={"help": "Hysteresis (scaling factor) for loss scaling"}
    )
    gradient_clipping: float = field(
        default=1.0, metadata={"help": "Gradient clipping threshold"}
    )


@dataclass
class FSDPWrapPolicy:
    transformer_layer_cls_to_wrap: Optional[List[str]] = field(
@@ -127,10 +181,11 @@ class TrainEngineConfig:
    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)

    # Training Backend Configuration
    disable_dropout: bool = field(default=False)
    gradient_checkpointing: bool = field(
        default=True, metadata={"help": "Enable gradient checkpointing"}
    )
    bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"})
    dtype: str = field(default="float16", metadata={"help": "Parameter dtype."})
    optimizer: Optional[OptimizerConfig] = field(
        default=None, metadata={"help": "Optimizer configuration"}
    )
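For illustration, a minimal sketch of overriding a few fields of the OptimizerConfig introduced above. Only the field names and defaults come from the dataclass in this diff; the import path arealite.api.cli_args is an assumption based on the surrounding hunks.

from dataclasses import asdict

from arealite.api.cli_args import OptimizerConfig  # assumed module path

# Override a few optimizer fields; everything else keeps the defaults shown above.
opt_cfg = OptimizerConfig(
    lr=1e-5,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_steps_proportion=0.03,
)
print(asdict(opt_cfg)["lr_scheduler_type"])  # cosine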
@@ -1,6 +1,7 @@
import gc
import os
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import torch
@@ -32,7 +33,7 @@ from arealite.utils.data import (
    pad_and_stack_tensors_along_first_dim,
    pad_mb_list,
    reorder_list,
    split_packed_tensor_dict_into_mb_list,
    split_padded_tensor_dict_into_mb_list,
    unpack_sequence,
    unsqueeze_mb_list,
)
@@ -45,6 +46,7 @@ from arealite.utils.fsdp import (
    fsdp2_load_full_state_dict,
    get_cosine_schedule_with_warmup,
)
from arealite.utils.model import disable_dropout_in_model
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version
@@ -95,19 +97,38 @@ class FSDPEngine(TrainEngine):
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        self.device = torch.device(int(os.environ["LOCAL_RANK"]))

        dtype = torch.bfloat16 if self.config.bf16 else torch.float16
        dtype = getattr(torch, self.config.dtype)
        self.model_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=self.config.path,
            trust_remote_code=True,
        )
        self.tokenizer = load_hf_tokenizer(self.config.path)
        tik = time.perf_counter()
        with torch.device("cuda"):
            # initialize scratch model from config
            model = AutoModelForCausalLM.from_config(
                self.model_config,
                torch_dtype=dtype,
                attn_implementation=self.config.attn_impl,
            if self.config.init_from_scratch:
                # initialize scratch model from config
                # NOTE: VLMs cannot directly load a state dict into this
                # randomly initialized model, so in that case we call
                # from_pretrained rather than loading weights into the random model.
                model = AutoModelForCausalLM.from_config(
                    self.model_config,
                    torch_dtype=dtype,
                    attn_implementation=self.config.attn_impl,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path=self.config.path,
                    trust_remote_code=True,
                    torch_dtype=dtype,
                    attn_implementation=self.config.attn_impl,
                )
            if self.config.disable_dropout:
                disable_dropout_in_model(model)
            if self.config.gradient_checkpointing:
                model.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )
        logger.info(f"Model creation and loading time: {time.perf_counter() - tik}")

        if self.config.gradient_checkpointing:
            model.gradient_checkpointing_enable(
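A small illustration of the dtype change above: the engine now resolves the parameter dtype from the config string via getattr instead of switching on the bf16 flag. The names below are illustrative only.

import torch

# Resolve a dtype string (as stored in TrainEngineConfig.dtype) to a torch dtype.
for name in ("float16", "bfloat16", "float32"):
    dtype = getattr(torch, name)
    assert isinstance(dtype, torch.dtype)
    print(name, "->", dtype)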
@@ -116,7 +137,7 @@ class FSDPEngine(TrainEngine):

        # Simple auto wrap policy
        self.mixed_precision_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            param_dtype=dtype,
            reduce_dtype=torch.float32,
            cast_forward_inputs=True,
        )
@@ -134,23 +155,14 @@ class FSDPEngine(TrainEngine):
        }

        # Wrap with FSDP2
        tik = time.perf_counter()
        apply_fsdp2(model, fsdp_kwargs, self.config.fsdp.wrap_policy)
        logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")
        self.model = model

        if not self.config.init_from_scratch:
            # Load the model from an initial checkpoint path,
            # which should only be a huggingface checkpoint.
            load_meta = SaveLoadMeta(
                path=self.config.path,
                weight_format="hf",
                with_optim=False,
                tokenizer=None,
                base_model_path=self.config.path,
            )
            self.load(load_meta)

        # Set up optimizer
        if self.optimizer_config is not None:
            tik = time.perf_counter()
            assert (
                self.optimizer_config.type == "adam"
            ), "Only AdamW optimizer is supported in this engine."
@@ -194,6 +206,7 @@ class FSDPEngine(TrainEngine):
                raise ValueError(
                    f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
                )
            logger.info(f"Create optimizer time: {time.perf_counter() - tik}")

        self.initialized = True
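For context, a plain-PyTorch sketch of what a "constant" learning-rate schedule with warmup (the default lr_scheduler_type above) typically looks like. This is illustrative only and not the engine's actual scheduler; the model, optimizer, and warmup_steps below are made up for the example.

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.05)
warmup_steps = 10

def lr_lambda(step: int) -> float:
    # Linearly ramp up during warmup, then hold the learning rate constant.
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return 1.0

scheduler = LambdaLR(optimizer, lr_lambda)
for _ in range(3):
    optimizer.step()
    scheduler.step()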
@@ -328,15 +341,19 @@ class FSDPEngine(TrainEngine):
        if isinstance(input_, dict):
            input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
        input_ = amend_position_ids(input_)
        packed_input = pack_tensor_dict(input_)
        mb_list = split_packed_tensor_dict_into_mb_list(
            packed_input,
            self.config.mb_spec,
        mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
        logger.info(
            f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
        )
        mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
        mb_list = pad_mb_list(mb_list, pad_value=0.0)
        # NOTE: We unsqueeze here because huggingface transformer models require
        # packed input to be of shape [1, total_seqlen].
        mb_list = unsqueeze_mb_list(mb_list)
        # FIXME: the resulting max_seqlen is a tensor rather than an integer
        for mb in mb_list.mbs:
            mb["max_seqlen"] = int(mb["max_seqlen"])
            mb["use_cache"] = False
        return mb_list

    def train_batch(
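A quick note on the FIXME handled above: after padding, max_seqlen comes back as a 0-d tensor, while downstream attention kernels generally expect a plain Python int, hence the int(...) cast. Minimal illustration, with an invented value:

import torch

max_seqlen = torch.tensor(512)            # what the tensordict holds after padding
assert not isinstance(max_seqlen, int)    # it is a 0-d tensor
assert isinstance(int(max_seqlen), int)   # the cast applied above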
@@ -361,9 +378,10 @@ class FSDPEngine(TrainEngine):
        dist.all_reduce(total_loss_weight)

        # Process microbatches with gradient accumulation
        for pad_length, padded_mb_input, mb_input in zip(
            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
        for i, (pad_length, padded_mb_input, mb_input) in enumerate(
            zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
        ):
            self.model.set_is_last_backward(i == len(mb_list.mbs) - 1)
            outputs = self.model(**padded_mb_input)

            logits = outputs.logits.squeeze(0)
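A plain-PyTorch sketch of the loop structure above: micro-batches are iterated with an index so the final one can be flagged. With FSDP2, set_is_last_backward on the last micro-batch lets the wrapper defer gradient reduction until the final backward; the sketch below only mirrors the loop shape with an ordinary model, and the toy model and data are illustrative.

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
micro_batches = [torch.randn(4, 8) for _ in range(3)]

optimizer.zero_grad()
for i, mb in enumerate(micro_batches):
    is_last = i == len(micro_batches) - 1  # analogous to set_is_last_backward
    loss = model(mb).mean() / len(micro_batches)
    loss.backward()
optimizer.step()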
@@ -8,7 +8,7 @@ from arealite.utils.data import (
    pad_and_stack_tensors_along_first_dim,
    pad_sequences_to_tensors,
    reorder_list,
    split_packed_tensor_dict_into_mb_list,
    split_padded_tensor_dict_into_mb_list,
    unpack_sequence,
)
@@ -45,7 +45,8 @@ def test_micro_batch_split(mock_padded_data, n_mbs, max_tokens_per_mb):
    packed_data = pack_tensor_dict(mock_padded_data)
    original_lens = packed_data["cu_seqlens"][1:] - packed_data["cu_seqlens"][:-1]
    assert torch.allclose(original_lens, mock_padded_data["attention_mask"].sum(1))
    split_result = split_packed_tensor_dict_into_mb_list(packed_data, mb_spec)
    split_result = split_padded_tensor_dict_into_mb_list(mock_padded_data, mb_spec)
    split_result.mbs = [pack_tensor_dict(mb) for mb in split_result.mbs]
    reordered_lens = [original_lens[i] for i in split_result.forward_indices]

    # assert microbatch split result does not violate requirements
@@ -110,11 +110,11 @@ def pad_input(hidden_states, indices, batch, seqlen):


def concat_padded_tensors(
    tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0
) -> Dict[str, torch.Tensor]:
    tensor_dicts: List[TensorDict], pad_value: float = 0.0
) -> TensorDict:
    """Concatenate and pad tensors from multiple padded tensor dictionaries."""
    if not tensor_dicts:
        return {}
        return TensorDict()

    # Find max sequence length across all dictionaries
    lens = []
@@ -156,7 +156,7 @@ def concat_padded_tensors(
        result[key] = torch.cat(tensors_to_concat, dim=0)
    if "attention_mask" not in result:
        result["attention_mask"] = attn_mask
    return result
    return TensorDict(result, batch_size=[len(lens)])


def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]:
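For reference, a minimal sketch of the return-type change above: the result dict is now wrapped in a TensorDict with an explicit batch size, as concat_padded_tensors does after this hunk. The tensors below are illustrative, and the tensordict package is assumed to be installed.

import torch
from tensordict import TensorDict

result = {
    "input_ids": torch.zeros(2, 5, dtype=torch.long),
    "attention_mask": torch.ones(2, 5, dtype=torch.bool),
}
td = TensorDict(result, batch_size=[2])
print(td.batch_size)  # torch.Size([2])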
@@ -290,13 +290,13 @@ class MicroBatchList:
DEFAULT_MAX_TOKENS_PER_MB = int(1e12)


def split_packed_tensor_dict_into_mb_list(
def split_padded_tensor_dict_into_mb_list(
    data: TensorDict, mb_spec: MicroBatchSpec, group: Optional[dist.ProcessGroup] = None
) -> MicroBatchList:
    """Split a packed tensordict into micro-batches based on the cumulative sequence lengths.
    """Split a padded tensordict into micro-batches based on the attention mask.

    Args:
        data (TensorDict): Dictionary containing packed tensors with "cu_seqlens" key.
        data (TensorDict): Dictionary containing padded tensors.
        mb_spec (MicroBatchSpec): Specification for micro-batch splitting.
        group (Optional[dist.ProcessGroup]): Process group for distributed synchronization.
@@ -304,24 +304,21 @@ def split_packed_tensor_dict_into_mb_list(
        MicroBatchList: A structure containing the split micro-batches and metadata.
    """
    assert (
        "cu_seqlens" in data
    ), "Input data must be packed and contain 'cu_seqlens' key."
        "attention_mask" in data
    ), "Input data must be padded and contain 'attention_mask' key."
    if mb_spec.max_tokens_per_mb is None:
        mb_spec = MicroBatchSpec.new(
            mb_spec, max_tokens_per_mb=DEFAULT_MAX_TOKENS_PER_MB
        )
    cu_seqlens = data["cu_seqlens"]
    bs = cu_seqlens.shape[0] - 1
    total_lens = int(cu_seqlens[-1])
    input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy()
    bs = data["attention_mask"].shape[0]
    max_seqlen = data["attention_mask"].shape[1]
    input_lens = data["attention_mask"].sum(1).long().cpu().numpy()

    # check tensor shape, split only 1d tensors with length "total_lens"
    to_split = {}
    not_to_split = {}
    for key, value in data.items():
        if key == "cu_seqlens" or key == "max_seqlen":
            continue
        if not torch.is_tensor(value) or value.numel() != total_lens:
        if not torch.is_tensor(value) or value.numel() != bs * max_seqlen:
            not_to_split[key] = value
        else:
            to_split[key] = value
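A small, self-contained illustration of the new padded-input bookkeeping above: batch size, maximum sequence length, and per-sequence lengths are all derived from the attention mask rather than from cu_seqlens. The mask values are illustrative.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
bs, max_seqlen = attention_mask.shape
input_lens = attention_mask.sum(1).long()
print(bs, max_seqlen, input_lens.tolist())  # 2 4 [3, 2]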
@@ -331,6 +328,7 @@ def split_packed_tensor_dict_into_mb_list(
    splitted_lens = [
        [input_lens[i] for i in group_index] for group_index in group_indices
    ]
    group_n_seqs = [len(x) for x in splitted_lens]
    group_lens = [sum(x) for x in splitted_lens]

    forward_indices = datapack.flat2d(group_indices)
@@ -340,12 +338,16 @@ def split_packed_tensor_dict_into_mb_list(
    def _split(tensor):
        """Split and pad a tensor based on forward indices and lens."""
        # Unpack the sequence
        unpacked = unpack_sequence(tensor, cu_seqlens=cu_seqlens)
        unpacked = [tensor[i] for i in range(bs)]
        # Reorder according to forward indices
        reordered = reorder_list(unpacked, forward_indices)
        reordered = torch.cat(reordered)
        reordered = torch.stack(reordered)
        # Unpack again according to split lens
        splitted = unpack_sequence(reordered, lens=group_lens)
        splitted = []
        offset = 0
        for _n_seqs in group_n_seqs:
            splitted.append(reordered[offset : offset + _n_seqs])
            offset += _n_seqs
        return splitted

    to_split = dict_map(to_split, lambda x: _split(x))
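A stand-alone sketch of the rewritten _split helper above: rows of a padded tensor are reordered by forward_indices, stacked, and then regrouped by the number of sequences assigned to each micro-batch. All values below are illustrative.

import torch

tensor = torch.arange(12).reshape(4, 3)   # 4 padded sequences of length 3
forward_indices = [2, 0, 3, 1]            # balanced ordering of the sequences
group_n_seqs = [1, 3]                     # sequence count per micro-batch

reordered = torch.stack([tensor[i] for i in forward_indices])
splitted, offset = [], 0
for n in group_n_seqs:
    splitted.append(reordered[offset : offset + n])
    offset += n
print([x.shape[0] for x in splitted])  # [1, 3]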
@@ -355,16 +357,7 @@ def split_packed_tensor_dict_into_mb_list(
    # organize splitted micro batches
    assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens))
    for i, (mb, lens) in enumerate(zip(mbs, splitted_lens)):
        max_seqlen = max(lens)
        lens = torch.tensor(lens, device="cuda")
        batch_cu_seqlens = torch.nn.functional.pad(
            lens.cumsum(0, dtype=torch.int), (1, 0)
        )
        results.append(
            TensorDict(
                **mb, **not_to_split, max_seqlen=max_seqlen, cu_seqlens=batch_cu_seqlens
            )
        )
        results.append(TensorDict(**mb, **not_to_split))
    return MicroBatchList(
        data=data,
        mbs=results,
@@ -433,7 +426,7 @@ def pad_mb_list(
        # NOTE: GPU page size is 2MB
        # Take hidden size 4096 with bf16 dtype as an example,
        # the batch size of a page is 256
        pad_to_length = (l + 255) // 256 * 256
        pad_to_length = (int(l) + 255) // 256 * 256
        padded_mb, pad_len = pad_packed_tensor_dict(
            mb, pad_to_length, pad_value=pad_value
        )
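The comment above explains the padding target: each packed micro-batch is padded up to a multiple of 256 tokens (one 2 MB GPU page at hidden size 4096 in bf16). A quick check of the rounding expression, with made-up lengths:

for l in (1, 256, 1000, 4097):
    print(l, "->", (int(l) + 255) // 256 * 256)
# 1 -> 256, 256 -> 256, 1000 -> 1024, 4097 -> 4352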
@@ -0,0 +1,8 @@
import torch


# Copied from trl
def disable_dropout_in_model(model: torch.nn.Module) -> None:
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0
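A small usage check for the new helper above. The function body is copied from the hunk; the toy module is illustrative. After the call, every nn.Dropout in the module tree has p == 0, so dropout no longer perturbs training-mode forward passes.

import torch


def disable_dropout_in_model(model: torch.nn.Module) -> None:
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0


toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.1))
disable_dropout_in_model(toy)
assert toy[1].p == 0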