AReaL/realhf/impl/model/nn/real_llm_base.py

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import contextlib
import dataclasses
import functools
import itertools
import json
import os
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
import transformers

import realhf.base.constants as constants
import realhf.base.logging as logging
import realhf.impl.model.parallelism.tensor_parallel.mappings as tensor_parallel
from realhf.api.core import model_api
from realhf.impl.model.modules import (
    CausalSelfAttentionLayer,
    GemmaRMSNorm,
    LayerNormMLP,
    LayerNormMoELayer,
    LlamaLayerNormMLP,
    LlamaRMSNorm,
    OffsetParallelPositionalEmbedding,
    OffsetPositionalEmbedding,
)
from realhf.impl.model.parallelism.tensor_parallel.modules import (
    ColumnParallelLinear,
    ParallelEmbedding,
    gather_from_sequence_parallel_region,
    gather_from_tensor_model_parallel_region,
    parallel_lm_logits,
    scatter_to_tensor_model_parallel_region,
)
from realhf.impl.model.utils.functional import compute_varlen_position_indices

logger = logging.getLogger("ReaLModelBase")


@dataclasses.dataclass
class PipeTransferData:
    """Data structure for transferring data between stages.

    Each pipeline stage has exactly one PipeTransferData as the input and the output,
    no matter how many layers are in this stage.

    Attributes:
        pp_input: The input to the current stage. Usually hidden states
            with shape [bs, seq_len, hidden_dim].
        pp_output: The output of the current stage, also the input to the next stage.
            Usually hidden states with shape [bs, seq_len, hidden_dim].
        cu_seqlens: The cumulative sequence lengths of packed input_ids.
            Used by flash_attn_varlen_func. Will not be used during generation.
            It's configuration-like data that must be transferred from the first stage
            to the last. Shape [bs + 1].
        max_seqlen: The maximum sequence length of packed input_ids.
            Used by flash_attn_varlen_func. Will not be used during generation.
            It's configuration-like data that must be transferred from the first stage
            to the last.
        store_kv_cache: Whether to store the key and value cache for generation.
        attention_mask: The attention mask of the input, the same as in HuggingFace
            transformers. Used by torch_attn_func to check that the outputs of
            PyTorch attention and flash attention are the same. Only for debugging.
            Shape [bs, seq_len].
    """

    pp_input: torch.Tensor = None
    pp_output: torch.Tensor = None
    # The following are "configuration"-like data that should be passed across all stages.
    cu_seqlens: torch.Tensor = None
    max_seqlen: int = None
    store_kv_cache: bool = False
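
# Illustrative sketch (not part of the original module): how `cu_seqlens` and
# `max_seqlen` describe packed sequences. Suppose three prompts of lengths 3, 5,
# and 2 are packed into a single flat tensor of 10 tokens. The metadata carried
# across pipeline stages would then be:
#
#     cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)  # shape [bs + 1]
#     max_seqlen = 5
#     x = PipeTransferData(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
#
# Sequence i occupies positions cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.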


@dataclasses.dataclass
class PipeCacheData:
    """Data structure for caching data locally that will not be transferred.

    Each layer has exactly one PipeCacheData as the input.
    If a pipeline stage has multiple layers, a list of PipeCacheData should be passed
    as the input. The cached tensors will be changed in-place.

    Attributes:
        packed_input_ids: The input token ids. Used only at the first stage.
            Can be packed with shape [total_seq_len] or unpacked with shape [bs, seq].
        packed_position_ids: Input position IDs, with the same shape as
            packed_input_ids. Used only at the first stage.
            If None, they are resolved automatically.
        k_cache: Key cache used for generation, shape [bs, max_seq, n_kv_heads, head_dim].
            Note that this is the cache for a specific layer, not for all layers.
        v_cache: Value cache used for generation, shape [bs, max_seq, n_kv_heads, head_dim].
            Note that this is the cache for a specific layer, not for all layers.
        cache_seqlens: The sequence lengths of the cached tokens. Used for generation. Shape [bs].
    """

    # Only cached in the first stage.
    packed_input_ids: torch.Tensor = None
    packed_position_ids: torch.Tensor = None
    # Cached in each transformer layer.
    k_cache: torch.Tensor = None
    v_cache: torch.Tensor = None
    cache_seqlens: torch.Tensor = None
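
# Illustrative sketch (not part of the original module): a pipeline stage that owns
# several layers receives one PipeCacheData per layer. On the first stage, the first
# entry belongs to the embedding layer and carries the packed input ids; each
# transformer block's entry starts empty and has its k_cache / v_cache filled in-place
# when x.store_kv_cache is set. `packed_input_ids` and `num_layers_in_stage` below
# are hypothetical names.
#
#     ys = [PipeCacheData(packed_input_ids=packed_input_ids)] + [
#         PipeCacheData() for _ in range(num_layers_in_stage - 1)
#     ]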


class ReaLModelBlock(nn.Module):
    def __init__(
        self,
        config: model_api.ReaLModelConfig,
        layer_index: int,
        output_layernorm: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        if dtype is None:
            dtype = torch.float16
        self.config = config
        self.layer_index = layer_index
        self.attn = CausalSelfAttentionLayer(
            hidden_dim=config.hidden_dim,
            n_kv_heads=config.n_kv_heads,
            n_q_heads=config.n_q_heads,
            head_dim=config.head_dim,
            resid_pdrop=config.resid_pdrop,
            attn_pdrop=config.attn_pdrop,
            layer_index=layer_index,
            layer_norm_epsilon=config.layer_norm_epsilon,
            scale_attn_by_inverse_layer_idx=config.scale_attn_by_inverse_layer_idx,
            scale_attn_weights=config.scale_attn_weights,
            layer_norm_type=config.layer_norm_type,
            use_attention_bias=config.use_attention_bias,
            use_attn_proj_bias=config.use_attn_proj_bias,
            do_layernorm_before=config.do_layernorm_before,
            qk_layernorm=config.qk_layernorm,
            apply_rotary=config.apply_rotary,
            rotary_base=config.rotary_base,
            rotary_interleaved=config.rotary_interleaved,
            rotary_scaling=config.rotary_scaling,
            rotary_scaling_type=config.rotary_scaling_type,
            rotary_special_impl=config.rotary_special_impl,
            dtype=dtype,
            device=device,
        )
        if config.mlp_type is None:
            self.mlp = LayerNormMLP(
                hidden_dim=config.hidden_dim,
                intermediate_dim=config.intermediate_dim,
                resid_pdrop=config.resid_pdrop,
                use_bias=config.use_mlp_bias,
                do_layernorm_before=config.do_layernorm_before,
                activation_function=config.activation_function,
                layer_norm_epsilon=config.layer_norm_epsilon,
                dtype=dtype,
                device=device,
            )
        elif config.mlp_type == "llama":
            self.mlp = LlamaLayerNormMLP(
                hidden_dim=config.hidden_dim,
                intermediate_dim=config.intermediate_dim,
                activation_function=config.activation_function,
                layer_norm_epsilon=config.layer_norm_epsilon,
                layer_norm_type=config.layer_norm_type,
                use_bias=config.use_mlp_bias,
                dtype=dtype,
                device=device,
            )
        elif config.mlp_type == "moe":
            self.mlp = LayerNormMoELayer(
                config=config,
                layer_idx=layer_index,
                dtype=dtype,
                device=device,
            )
        else:
            raise NotImplementedError(f"Unknown MLP type: {config.mlp_type}")
        self.output_layernorm = output_layernorm
        if output_layernorm:
            if config.layer_norm_type is None:
                layer_norm_fn = nn.LayerNorm
            elif config.layer_norm_type == "rms":
                layer_norm_fn = LlamaRMSNorm
            elif config.layer_norm_type == "gemma":
                layer_norm_fn = GemmaRMSNorm
            else:
                raise NotImplementedError(
                    f"Unknown layer norm type: {config.layer_norm_type}"
                )
            self.ln_f = layer_norm_fn(
                config.hidden_dim,
                eps=config.layer_norm_epsilon,
                dtype=dtype,
                device=device,
            )

    def forward(self, x: PipeTransferData, y: PipeCacheData) -> PipeTransferData:
        pp_input = x.pp_input
        cu_seqlens = x.cu_seqlens
        k_cache = y.k_cache
        v_cache = y.v_cache
        cache_seqlens = y.cache_seqlens
        max_seqlen = x.max_seqlen
        if constants.gradient_checkpointing():
            pp_output, k, v = torch.utils.checkpoint.checkpoint(
                self._forward,
                pp_input,
                cu_seqlens,
                k_cache,
                v_cache,
                cache_seqlens,
                max_seqlen,
                use_reentrant=True,
            )
        else:
            pp_output, k, v = self._forward(
                pp_input,
                cu_seqlens,
                k_cache,
                v_cache,
                cache_seqlens,
                max_seqlen,
            )
        x.pp_output = pp_output
        if x.store_kv_cache:
            if y.k_cache is None:
                y.k_cache = k.detach()
            if y.v_cache is None:
                y.v_cache = v.detach()
            if y.cache_seqlens is None and x.cu_seqlens is not None:
                y.cache_seqlens = x.cu_seqlens[1:] - x.cu_seqlens[:-1]
        return x

    def _forward(
        self,
        pp_input: torch.Tensor,
        cu_seqlens: torch.Tensor,
        k_cache: Optional[torch.Tensor],
        v_cache: Optional[torch.Tensor],
        cache_seqlens: Optional[torch.Tensor],
        max_seqlen: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        h = pp_input
        attn_out, k, v = self.attn(
            hidden_states=h,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens,
        )
        h = h + attn_out
        # For opt-350m
        if not self.config.do_layernorm_before:
            h = self.attn.c_attn.ln(h)
        h = self.mlp(h) + h
        # For opt-350m
        if not self.config.do_layernorm_before:
            h = self.mlp.ln(h)
        if self.output_layernorm:
            h = self.ln_f(h)
        return h, k, v
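
# Illustrative sketch (not part of the original module): how a block consumes the two
# data structures above. `x` is shared by every layer in the stage, while each layer
# gets its own `y`; the block writes its output back into `x.pp_output` and, when
# `x.store_kv_cache` is set, fills `y.k_cache` / `y.v_cache` in-place. Names such as
# `block` and `hidden` are hypothetical.
#
#     x = PipeTransferData(
#         pp_input=hidden, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
#         store_kv_cache=True,
#     )
#     y = PipeCacheData()
#     x = block(x, y)          # ReaLModelBlock.forward
#     hidden = x.pp_output     # feeds the next layer as its pp_input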


class VocabPositionEmbedding(nn.Module):
    def __init__(
        self,
        config: model_api.ReaLModelConfig,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        self.n_positions = config.n_positions
        self.hidden_dim = config.hidden_dim
        tensor_parallel = constants.tensor_parallel_world_size() > 1
        if tensor_parallel:
            embed_cls = ParallelEmbedding
        else:
            embed_cls = nn.Embedding
        self.wte = embed_cls(
            config.vocab_size, config.hidden_dim, dtype=dtype, device=device
        )
        self.apply_abs_pos_embed = not config.apply_rotary
        if self.apply_abs_pos_embed:
            p_embed_cls = (
                OffsetParallelPositionalEmbedding
                if tensor_parallel
                else OffsetPositionalEmbedding
            )
            self.wpe = p_embed_cls(
                config.n_positions,
                config.hidden_dim,
                offset=config.abs_position_embedding_offset,
                dtype=dtype,
                device=device,
            )
        self.normalize_embed = config.normalize_embed
        self.embed_drop = nn.Dropout(config.embd_pdrop)

    def forward(self, x: PipeTransferData, y: PipeCacheData) -> PipeTransferData:
        # Set position ids.
        # if y.packed_position_ids is not None:
        #     raise ValueError("In our use cases, position_ids must be None.")
        y.packed_position_ids = compute_varlen_position_indices(
            total_seqlen=y.packed_input_ids.shape[0],
            cu_seqlens=x.cu_seqlens,
            seqlen_offsets=y.cache_seqlens,
        )
        if x.max_seqlen > self.n_positions:
            raise ValueError(
                f"max_seqlen ({x.max_seqlen}) must be <= n_positions ({self.n_positions})."
            )
        assert y.packed_position_ids.shape == y.packed_input_ids.shape, (
            y.packed_position_ids.shape,
            y.packed_input_ids.shape,
            x.cu_seqlens,
        )
        x.pp_output = self._forward(y.packed_input_ids, y.packed_position_ids)
        return x

    def _forward(
        self, input_ids: torch.LongTensor, position_ids: torch.LongTensor
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        if self.apply_abs_pos_embed:
            inputs_embeds = inputs_embeds + self.wpe(position_ids)
        if self.normalize_embed:
            normalizer = torch.tensor(self.hidden_dim**0.5, dtype=inputs_embeds.dtype)
            inputs_embeds = inputs_embeds * normalizer
        if constants.sequence_parallel():
            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(
                inputs_embeds
            )
            # `scatter_to_sequence_parallel_region` returns a view, which prevents
            # the original tensor from being garbage collected. Clone to facilitate GC.
            # Has a small runtime cost (~0.5%).
            inputs_embeds = inputs_embeds.clone()
            # with tensor_parallel.get_cuda_rng_tracker().fork():
            x = self.embed_drop(inputs_embeds)
        else:
            x = self.embed_drop(inputs_embeds)
        return x
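
# Illustrative sketch (not part of the original module): the position ids resolved by
# `compute_varlen_position_indices` restart at 0 for every packed sequence and are
# shifted by `cache_seqlens` during incremental decoding. With the hypothetical
# cu_seqlens = [0, 3, 8, 10] from the earlier sketch:
#
#     # prefill:  packed_position_ids == [0, 1, 2,  0, 1, 2, 3, 4,  0, 1]
#     # decoding: with cache_seqlens = [3, 5, 2] and one new token per sequence,
#     #           packed_position_ids == [3, 5, 2]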


class OutputHead(nn.Linear):
    def __init__(
        self, *args, norm_head: bool = False, norm_softmax: bool = False, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self._norm_head = norm_head
        self._norm_softmax = norm_softmax

    def forward(self, x: PipeTransferData, y: PipeCacheData) -> PipeTransferData:
        x.pp_output = self._forward(x.pp_input)
        return x

    def _forward(self, x: torch.Tensor):
        if self._norm_head and self.out_features != 1:
            unnormed_head = nn.functional.linear(
                torch.eye(
                    self.in_features, dtype=self.weight.dtype, device=self.weight.device
                ),
                self.weight,
            ).transpose(1, 0)
            head_norm = unnormed_head.norm(dim=0, keepdim=True, p=2)
            normed_head = unnormed_head / (head_norm + 1e-7)
            logits = nn.functional.linear(x, normed_head, None)
        else:
            logits = super().forward(x)
        if self._norm_softmax and self.out_features != 1:
            logits = logits / (torch.std(logits, dim=-1, keepdim=True) + 1e-6)
        return logits
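
# Illustrative note (not part of the original module): for this non-parallel head, the
# identity-matrix multiplication above just reproduces the weight itself
# (F.linear(torch.eye(in_features), W).transpose(1, 0) == W), so the `norm_head`
# branch is equivalent to the shorter sketch below, which scales each weight column
# by the inverse of its L2 norm before computing the logits:
#
#     head_norm = self.weight.norm(dim=0, keepdim=True, p=2)  # [1, in_features]
#     logits = nn.functional.linear(x, self.weight / (head_norm + 1e-7), None)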


class SequenceParallelCriticHead(nn.Linear):
    def forward(self, x: PipeTransferData, y: PipeCacheData) -> PipeTransferData:
        all_gather_buffer = tensor_parallel.gather_from_sequence_parallel_region(
            x.pp_input
        )
        x.pp_output = nn.functional.linear(all_gather_buffer, self.weight, self.bias)
        return x

    def _forward(self, x: torch.Tensor):
        x = tensor_parallel.gather_from_sequence_parallel_region(x)
        return super().forward(x)


class ParallelActorHead(ColumnParallelLinear):
    def __init__(
        self, *args, norm_head: bool = False, norm_softmax: bool = False, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self._norm_head = norm_head
        self._norm_softmax = norm_softmax

    def forward(self, x: PipeTransferData, y: PipeCacheData) -> PipeTransferData:
        x.pp_output = self._forward(x.pp_input)
        return x

    def _forward(self, x: torch.Tensor):
        weight = self.weight
        if self._norm_head:
            from realhf.impl.model.parallelism.tensor_parallel.mappings import (
                gather_from_sequence_parallel_region,
            )

            # HACK: This is a terrible implementation for Bailing's head norming,
            # because it basically eliminates TP for the LM head.
            whole_weight = gather_from_sequence_parallel_region(self.weight)
            unnormed_head = nn.functional.linear(
                torch.eye(
                    self.input_size, dtype=self.weight.dtype, device=self.weight.device
                ),
                whole_weight,
            ).transpose(1, 0)
            head_norm = unnormed_head.norm(dim=0, keepdim=True, p=2)
            normed_head = unnormed_head / (head_norm + 1e-7)
            weight = tensor_parallel.scatter_to_sequence_parallel_region(normed_head)
        output = parallel_lm_logits(
            x,
            weight,
            parallel_output=True,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            bias=self.bias,
        )
        if self._norm_softmax:
            whole_output = gather_from_tensor_model_parallel_region(output)
            whole_output = whole_output / (
                torch.std(whole_output, dim=-1, keepdim=True) + 1e-6
            )
            output = scatter_to_tensor_model_parallel_region(whole_output)
        return output
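
# Illustrative note (not part of the original module): with `parallel_output=True`,
# `parallel_lm_logits` leaves the logits sharded along the vocab dimension, so each
# tensor-parallel rank holds a [..., vocab_size // tp_world_size] slice (this is why
# the `norm_softmax` branch gathers before normalizing). For debugging, the full
# logits could be assembled with the helper already imported at the top of this file:
#
#     full_logits = gather_from_tensor_model_parallel_region(output)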


class ReaLModelParamKeys:
    """The keys of parameters in ReaLModel, used for parameter reallocation.

    **IMPORTANT**: The returned keys are **ordered**. They should have
    the same order as we iterate layer indices and call
    layer.state_dict().
    """

    @staticmethod
    def embed(config: model_api.ReaLModelConfig) -> List[str]:
        keys = ["0.wte.weight"]
        if not config.apply_rotary:
            keys += ["0.wpe.weight"]
        return keys

    @staticmethod
    def tblock(config: model_api.ReaLModelConfig, idx: int) -> List[str]:
        # NOTE: `idx` is the index of transformer blocks,
        # i.e., 0 for the first block, which is the second layer of the model
        # (layer 0 is the embedding layer).
        # NOTE: The order matters; we should not change the order of keys.
        keys = [f"{idx + 1}.attn.c_attn.ln.weight"]
        if config.layer_norm_type is None:
            keys += [f"{idx + 1}.attn.c_attn.ln.bias"]
        keys += [f"{idx + 1}.attn.c_attn.q_attn.weight"]
        if config.use_attention_bias:
            keys += [f"{idx + 1}.attn.c_attn.q_attn.bias"]
        keys += [f"{idx + 1}.attn.c_attn.k_attn.weight"]
        if config.use_attention_bias:
            keys += [f"{idx + 1}.attn.c_attn.k_attn.bias"]
        keys += [f"{idx + 1}.attn.c_attn.v_attn.weight"]
        if config.use_attention_bias:
            keys += [f"{idx + 1}.attn.c_attn.v_attn.bias"]
        keys += [f"{idx + 1}.attn.c_proj.weight"]
        if config.use_attn_proj_bias:
            keys += [f"{idx + 1}.attn.c_proj.bias"]
        if config.qk_layernorm:
            keys += [f"{idx + 1}.attn.q_ln.weight"]
            keys += [f"{idx + 1}.attn.k_ln.weight"]
            if config.layer_norm_type is None:
                keys += [f"{idx + 1}.attn.q_ln.bias"]
                keys += [f"{idx + 1}.attn.k_ln.bias"]
        keys += [f"{idx + 1}.mlp.ln.weight"]
        if config.layer_norm_type is None:
            keys += [f"{idx + 1}.mlp.ln.bias"]
        if config.mlp_type is None:
            if config.use_mlp_bias:
                keys += [
                    f"{idx + 1}.mlp.c_fc.weight",
                    f"{idx + 1}.mlp.c_fc.bias",
                    f"{idx + 1}.mlp.c_proj.weight",
                    f"{idx + 1}.mlp.c_proj.bias",
                ]
            else:
                keys += [
                    f"{idx + 1}.mlp.c_fc.weight",
                    f"{idx + 1}.mlp.c_proj.weight",
                ]
        elif config.mlp_type == "llama":
            weights_key = [
                f"{idx + 1}.mlp.gate_proj.weight",
                f"{idx + 1}.mlp.up_proj.weight",
                f"{idx + 1}.mlp.down_proj.weight",
            ]
            if config.use_mlp_bias:
                keys += list(
                    itertools.chain.from_iterable(
                        zip(
                            weights_key,
                            [
                                f"{idx + 1}.mlp.gate_proj.bias",
                                f"{idx + 1}.mlp.up_proj.bias",
                                f"{idx + 1}.mlp.down_proj.bias",
                            ],
                        )
                    )
                )
            else:
                keys += weights_key
        elif config.mlp_type == "moe":
            num_experts = config.moe.num_experts
            keys += [
                f"{idx + 1}.mlp.router.weight",
            ]
            for j in range(num_experts):
                weights_key = [
                    f"{idx + 1}.mlp.experts.local_experts.{j}.gate_proj.weight",
                    f"{idx + 1}.mlp.experts.local_experts.{j}.up_proj.weight",
                    f"{idx + 1}.mlp.experts.local_experts.{j}.down_proj.weight",
                ]
                if config.use_mlp_bias:
                    keys += list(
                        itertools.chain.from_iterable(
                            zip(
                                weights_key,
                                [
                                    f"{idx + 1}.mlp.experts.local_experts.{j}.gate_proj.bias",
                                    f"{idx + 1}.mlp.experts.local_experts.{j}.up_proj.bias",
                                    f"{idx + 1}.mlp.experts.local_experts.{j}.down_proj.bias",
                                ],
                            )
                        )
                    )
                else:
                    keys += weights_key
        else:
            raise NotImplementedError()
        if idx == config.n_layers - 1:
            keys += [f"{idx + 1}.ln_f.weight"]
            if config.layer_norm_type is None:
                keys += [f"{idx + 1}.ln_f.bias"]
        return keys

    @staticmethod
    def head(config: model_api.ReaLModelConfig) -> List[str]:
        return [f"{config.n_layers + 1}.weight"]


def keys_from_layer_indices(
    config: model_api.ReaLModelConfig, layer_indices: List[int]
) -> List[str]:
    # assert _is_integer_list_contiguous(layer_indices)
    sd_keys = []
    for layer_idx in sorted(layer_indices):
        if layer_idx == 0:
            sd_keys += ReaLModelParamKeys.embed(config)
        elif layer_idx == config.n_layers + 1:
            sd_keys += ReaLModelParamKeys.head(config)
        else:
            sd_keys += ReaLModelParamKeys.tblock(config, layer_idx - 1)
    return sd_keys
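
# Illustrative sketch (not part of the original module): the layer-index-to-key mapping
# produced above. For a hypothetical rotary, RMSNorm, llama-MLP config with
# n_layers == 2, no biases, and qk_layernorm disabled, the first two layers map to:
#
#     keys_from_layer_indices(config, [0, 1])
#     # ['0.wte.weight',                      # embedding layer (index 0)
#     #  '1.attn.c_attn.ln.weight',           # first transformer block (index 1)
#     #  '1.attn.c_attn.q_attn.weight',
#     #  '1.attn.c_attn.k_attn.weight',
#     #  '1.attn.c_attn.v_attn.weight',
#     #  '1.attn.c_proj.weight',
#     #  '1.mlp.ln.weight',
#     #  '1.mlp.gate_proj.weight',
#     #  '1.mlp.up_proj.weight',
#     #  '1.mlp.down_proj.weight']
#
# Index n_layers + 1 (here 3) would add the output head key '3.weight'.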