# Mirror of https://github.com/inclusionAI/AReaL
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import functools
import math
import os
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.activations import ACT2FN

import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.impl.model.parallelism.tensor_parallel.modules import (
    ColumnParallelLinear,
    RowParallelLinear,
    merged_linear_with_grad_accumulation_and_async_allreduce,
)

logger = logging.getLogger("Modules")


def get_activation_fn(activation_function: str) -> Callable:
    return ACT2FN[activation_function]


SEQUENCE_PARALLEL_WARNED = False

class LayerNormQKVLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        head_dim: int,
        n_q_heads: int,
        n_kv_heads: int,
        layer_norm_epsilon: float,
        use_attention_bias: bool,
        layer_norm_type: Optional[str] = None,
        do_layernorm_before: bool = True,
        # dtype and device
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        layer_index=None,
    ):
        super().__init__()
        tensor_parallel = constants.tensor_parallel_world_size() > 1
        sequence_parallel = constants.sequence_parallel()
        gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
        if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion):
            global SEQUENCE_PARALLEL_WARNED
            if not SEQUENCE_PARALLEL_WARNED:
                logger.warning(
                    "sequence_parallel and gradient_accumulation_fusion are only available in model parallel mode"
                )
                SEQUENCE_PARALLEL_WARNED = True
            sequence_parallel = False
            gradient_accumulation_fusion = False
        if dtype is None:
            dtype = torch.float16
        if layer_norm_type is None:
            layer_norm_fn = nn.LayerNorm
        elif layer_norm_type == "rms":
            layer_norm_fn = LlamaRMSNorm
        elif layer_norm_type == "gemma":
            layer_norm_fn = GemmaRMSNorm
        else:
            raise NotImplementedError(f"Unknown layer_norm_type: {layer_norm_type}")
        self.ln = layer_norm_fn(
            input_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
        )

        self.tensor_parallel = tensor_parallel
        self.layer_index = layer_index
        self.tp_worldsize = constants.tensor_parallel_world_size()
        assert n_q_heads % self.tp_worldsize == 0, (
            f"n_q_heads {n_q_heads} must be divisible by "
            f"tp_worldsize {self.tp_worldsize}"
        )
        assert n_kv_heads % self.tp_worldsize == 0, (
            f"n_kv_heads {n_kv_heads} must be divisible by "
            f"tp_worldsize {self.tp_worldsize}"
        )
        hidden_dim = input_dim
        # TODO: we can fuse the forward of qkv attention
        self.q_attn = ColumnParallelLinear(
            hidden_dim,
            head_dim * n_q_heads,
            bias=use_attention_bias,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            dtype=dtype,
            device=device,
        )
        self.k_attn = ColumnParallelLinear(
            hidden_dim,
            head_dim * n_kv_heads,
            bias=use_attention_bias,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            dtype=dtype,
            device=device,
        )
        self.v_attn = ColumnParallelLinear(
            hidden_dim,
            head_dim * n_kv_heads,
            bias=use_attention_bias,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            dtype=dtype,
            device=device,
        )

        self.d = head_dim
        self.nq = n_q_heads
        self.nkv = n_kv_heads

        self.do_layernorm_before = do_layernorm_before

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.do_layernorm_before:
            hidden_states = self.ln(x)
        else:
            hidden_states = x
        _gradient_accumulation_fusion = self.q_attn.gradient_accumulation_fusion
        _sequence_parallel = constants.sequence_parallel()
        _async_grad_allreduce = not _sequence_parallel
        _is_w_parallel = [
            True,
            isinstance(self.k_attn, ColumnParallelLinear),
            isinstance(self.v_attn, ColumnParallelLinear),
        ]
        q, k, v = merged_linear_with_grad_accumulation_and_async_allreduce(
            hidden_states,
            _gradient_accumulation_fusion,
            _async_grad_allreduce,
            _sequence_parallel,
            _is_w_parallel,
            self.q_attn.weight,
            self.q_attn.bias,
            self.k_attn.weight,
            self.k_attn.bias,
            self.v_attn.weight,
            self.v_attn.bias,
        )
        q = q.view(*q.shape[:-1], self.nq // self.tp_worldsize, self.d)
        k = k.view(*k.shape[:-1], self.nkv // self.tp_worldsize, self.d)
        v = v.view(*v.shape[:-1], self.nkv // self.tp_worldsize, self.d)
        return q, k, v
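
# Usage sketch (hypothetical sizes; assumes realhf's parallelism constants have
# been initialized, e.g. a run with tensor_parallel_world_size() == 1):
#
#   qkv = LayerNormQKVLinear(
#       input_dim=4096, head_dim=128, n_q_heads=32, n_kv_heads=8,
#       layer_norm_epsilon=1e-5, use_attention_bias=False, layer_norm_type="rms",
#       dtype=torch.bfloat16, device="cuda",
#   )
#   q, k, v = qkv(x)  # x: [num_tokens, 4096]
#   # q: [num_tokens, 32 // tp, 128]; k, v: [num_tokens, 8 // tp, 128]
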
class LayerNormMLP(nn.Module):

    def __init__(
        self,
        hidden_dim: int,
        intermediate_dim: int,
        use_bias: bool,
        resid_pdrop: float,
        activation_function: str,
        layer_norm_epsilon: float,
        do_layernorm_before: bool = True,
        # dtype and device
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        tensor_parallel = constants.tensor_parallel_world_size() > 1
        sequence_parallel = constants.sequence_parallel()
        gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
        if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion):
            global SEQUENCE_PARALLEL_WARNED
            if not SEQUENCE_PARALLEL_WARNED:
                logger.warning(
                    "sequence_parallel and gradient_accumulation_fusion are only available in model parallel mode"
                )
                SEQUENCE_PARALLEL_WARNED = True
            sequence_parallel = False
            gradient_accumulation_fusion = False
        if dtype is None:
            dtype = torch.float16

        self.ln = nn.LayerNorm(
            hidden_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
        )
        self.c_fc = ColumnParallelLinear(
            hidden_dim,
            intermediate_dim,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            bias=use_bias,
            dtype=dtype,
            device=device,
        )
        self.c_proj = RowParallelLinear(
            intermediate_dim,
            hidden_dim,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            bias=use_bias,
            dtype=dtype,
            device=device,
        )
        self.act = get_activation_fn(activation_function)
        self.dropout = nn.Dropout(resid_pdrop)
        self.do_layernorm_before = do_layernorm_before

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.do_layernorm_before:
            hidden_states = self.ln(hidden_states)
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return self.dropout(hidden_states)
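
# Usage sketch (hypothetical sizes; assumes the parallelism constants are set up):
#
#   mlp = LayerNormMLP(
#       hidden_dim=4096, intermediate_dim=16384, use_bias=True, resid_pdrop=0.1,
#       activation_function="gelu", layer_norm_epsilon=1e-5,
#   )
#   y = mlp(x)  # y has the same shape as x
#
# c_fc (ColumnParallelLinear) shards the intermediate dimension across tensor
# parallel ranks and c_proj (RowParallelLinear) reduces back to the full hidden
# dimension, so the intermediate activation stays sharded between the two.
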
class LlamaLayerNormMLP(nn.Module):

    def __init__(
        self,
        hidden_dim: int,
        intermediate_dim: int,
        activation_function: str,
        use_bias: bool,
        # layer norm
        layer_norm_epsilon: float = 1e-5,
        layer_norm_type: str = "rms",
        # whether this MLP is used as an expert
        is_expert: bool = False,
        # dtype and device
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        self.tensor_parallel = constants.tensor_parallel_world_size() > 1
        gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
        self.is_expert = is_expert
        # When used as an expert, the MLP always computes without sequence parallelism.
        sequence_parallel = constants.sequence_parallel() and not is_expert
        if not self.tensor_parallel and (
            sequence_parallel or gradient_accumulation_fusion
        ):
            global SEQUENCE_PARALLEL_WARNED
            if not SEQUENCE_PARALLEL_WARNED:
                logger.warning(
                    "sequence_parallel and gradient_accumulation_fusion are only available in model parallel mode"
                )
                SEQUENCE_PARALLEL_WARNED = True
            gradient_accumulation_fusion = False

        if dtype is None:
            dtype = torch.float16
        self.hidden_size = hidden_dim
        self.intermediate_size = intermediate_dim
        # When used as an expert, layer norm is computed outside of this module.
        self.use_layer_norm = not is_expert

        if self.use_layer_norm:
            if layer_norm_type == "rms":
                self.ln = LlamaRMSNorm(
                    hidden_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
                )
            elif layer_norm_type == "gemma":
                self.ln = GemmaRMSNorm(
                    hidden_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
                )
            else:
                raise NotImplementedError()

        # TODO: we can fuse gate and up proj, as well as the silu and mul operations
        self.gate_proj = ColumnParallelLinear(
            self.hidden_size,
            self.intermediate_size,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            is_expert=is_expert,
            bias=use_bias,
            dtype=dtype,
            device=device,
        )
        self.up_proj = ColumnParallelLinear(
            self.hidden_size,
            self.intermediate_size,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            is_expert=is_expert,
            bias=use_bias,
            dtype=dtype,
            device=device,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            is_expert=is_expert,
            bias=use_bias,
            dtype=dtype,
            device=device,
        )
        self.act_fn = get_activation_fn(activation_function)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_layer_norm:
            x = self.ln(x)
        _gradient_accumulation_fusion = self.gate_proj.gradient_accumulation_fusion
        _sequence_parallel = constants.sequence_parallel() and not self.is_expert
        _async_grad_allreduce = not _sequence_parallel
        _is_w_parallel = [True, True]

        gate, upproj = merged_linear_with_grad_accumulation_and_async_allreduce(
            x,
            _gradient_accumulation_fusion,
            _async_grad_allreduce,
            _sequence_parallel,
            _is_w_parallel,
            self.gate_proj.weight,
            self.gate_proj.bias,
            self.up_proj.weight,
            self.up_proj.bias,
        )
        return self.down_proj(self.act_fn(gate) * upproj)
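
# For reference, the merged call above is equivalent (up to kernel fusion and
# the distributed bookkeeping) to the unfused SwiGLU-style computation:
#
#   gate = self.gate_proj(x)
#   upproj = self.up_proj(x)
#   y = self.down_proj(self.act_fn(gate) * upproj)
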
class _LlamaRMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """LlamaRMSNorm is equivalent to T5LayerNorm."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
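
# In short, RMSNorm computes y = w * x / sqrt(mean(x^2) + eps), with the
# normalization done in float32 and the result cast back to the input dtype
# before the weight is applied.
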
class GemmaRMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        self.eps = eps
        # The weight is stored as an offset from 1 (see forward), so it is
        # initialized to zeros, matching HF's GemmaRMSNorm.
        self.weight = nn.Parameter(
            torch.zeros(hidden_size, dtype=dtype, device=device)
        )

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)
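
# Note the two conventions: LlamaRMSNorm casts back to the input dtype and then
# scales by w, whereas GemmaRMSNorm scales by (1 + w) in float32 and only then
# casts back (see the HF PR linked above).
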
if constants.use_te_impl():
    try:
        # HACK: we use transformer engine's rms norm as long as we can find the transformer engine package
        import transformer_engine.pytorch as te

        def _TELlamaRMSNorm(
            hidden_size: int,
            eps: float = 1e-6,
            dtype: Optional[torch.dtype] = None,
            device: Optional[Union[str, torch.device]] = None,
        ):
            return te.module.rmsnorm.RMSNorm(
                hidden_size=hidden_size,
                eps=eps,
                sequence_parallel=constants.sequence_parallel(),
                params_dtype=dtype,
                device=device,
            )

        LlamaRMSNorm = _TELlamaRMSNorm
    except ImportError:
        # ImportError also covers ModuleNotFoundError.
        LlamaRMSNorm = _LlamaRMSNorm
else:
    LlamaRMSNorm = _LlamaRMSNorm

if constants.use_te_impl():
    from transformer_engine.pytorch.module.layernorm_mlp import (
        LayerNormMLP as _TELayerNormMLP,
    )

    # The same signature as LlamaLayerNormMLP
    def LlamaLayerNormMLP(
        hidden_dim: int,
        intermediate_dim: int,
        activation_function: str,
        use_bias: bool,
        # layer norm
        layer_norm_epsilon: float = 1e-5,
        layer_norm_type: str = "rms",
        # moe
        is_expert: bool = False,
        # dtype and device
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        assert not use_bias
        assert layer_norm_type == "rms"
        assert not is_expert
        assert activation_function == "silu"
        return _TELayerNormMLP(
            hidden_size=hidden_dim,
            ffn_hidden_size=intermediate_dim,
            eps=layer_norm_epsilon,
            sequence_parallel=constants.sequence_parallel(),
            return_bias=False,
            tp_group=constants.tensor_parallel_group(),
            tp_size=constants.tensor_parallel_world_size(),
            bias=False,
            normalization="RMSNorm",
            activation="swiglu",
            fuse_wgrad_accumulation=constants.gradient_accumulation_fusion(),
            params_dtype=dtype,
            set_parallel_mode=constants.tensor_parallel_world_size() > 1,
            device=device,
        )
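
# Usage note (sketch): when constants.use_te_impl() is True, the
# LlamaLayerNormMLP factory above returns a TransformerEngine LayerNormMLP that
# fuses the RMSNorm, the gate/up ("swiglu") projections, and the down projection
# into a single module, so call sites stay the same: y = mlp(x).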