# Modified from Megatron-LM.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import Callable, List, Optional, Union

import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter

import realhf.base.constants as constants
from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.modules.mlp import LlamaLayerNormMLP, get_activation_fn
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
    copy_to_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
)
from realhf.impl.model.parallelism.tensor_parallel.utils import divide
from realhf.impl.model.utils.random import _initialize_affine_weight_gpu

try:
    import grouped_gemm
except ImportError:
    grouped_gemm = None


class SequentialMLP(torch.nn.Module):
    """An implementation of the Experts layer using a sequence of MLP layers.

    This class executes each expert sequentially.
    """

    def __init__(
        self,
        config: ReaLModelConfig,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        self.config = config

        self.num_experts = self.config.moe.num_experts
        self.local_experts = torch.nn.ModuleList()

        for _ in range(self.num_experts):
            expert = LlamaLayerNormMLP(
                hidden_dim=config.hidden_dim,
                intermediate_dim=config.intermediate_dim,
                activation_function=config.activation_function,
                use_bias=config.use_mlp_bias,
                is_expert=True,
                dtype=dtype,
                device=device,
            )
            self.local_experts.append(expert)

    def forward(
        self,
        permuted_local_hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ):
        output_local = torch.zeros_like(permuted_local_hidden_states)
        cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
        # Insert a zero at the beginning for convenient offset indexing.
        zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
        cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))

        for expert_num, expert in enumerate(self.local_experts):
            start = cumsum_num_tokens[expert_num]
            end = cumsum_num_tokens[expert_num + 1]
            hidden = permuted_local_hidden_states[start:end]
            output = expert(hidden)
            output_local[start:end] = output

        return output_local
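

# Illustrative sketch (not used by the library): how SequentialMLP dispatches a
# permuted token batch with cumulative-sum offsets. torch.nn.Linear stands in
# for the real LlamaLayerNormMLP experts; all shapes are toy values.
def _sequential_dispatch_sketch() -> torch.Tensor:
    hidden_dim, num_experts = 8, 4
    experts = torch.nn.ModuleList(
        torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)
    )
    # Tokens are assumed to be pre-sorted (permuted) by expert index.
    tokens_per_expert = torch.tensor([3, 0, 5, 2])
    permuted = torch.randn(int(tokens_per_expert.sum()), hidden_dim)
    output = torch.zeros_like(permuted)
    offsets = torch.cat(
        (torch.zeros(1, dtype=torch.long), torch.cumsum(tokens_per_expert, dim=0))
    )
    for i, expert in enumerate(experts):
        start, end = offsets[i], offsets[i + 1]
        # Each expert only sees its own contiguous slice of tokens.
        output[start:end] = expert(permuted[start:end])
    return output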


class ExpertParam(torch.nn.Module):
    """A dummy class that maps weight tensors in GroupedMLP to pytorch
    parameters for compatibility of weight saving/loading."""

    def __init__(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
    ):
        class LinearParam(torch.nn.Module):
            def __init__(self, param: torch.Tensor):
                super(LinearParam, self).__init__()
                self.weight = Parameter(param)

        super(ExpertParam, self).__init__()

        self.gate_proj = LinearParam(gate_proj)
        self.up_proj = LinearParam(up_proj)
        self.down_proj = LinearParam(down_proj)
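

# Illustrative sketch (not part of the library): wrapping per-expert views of a
# grouped weight tensor in ExpertParam exposes them under conventional
# per-expert state_dict keys (e.g. "0.gate_proj.weight") while the underlying
# storage stays shared with the grouped tensor. Toy shapes only.
def _expert_param_state_dict_sketch() -> dict:
    num_experts, hidden_dim, intermediate_dim = 2, 4, 3
    grouped_gate = torch.empty(num_experts, hidden_dim, intermediate_dim)
    grouped_up = torch.empty(num_experts, hidden_dim, intermediate_dim)
    grouped_down = torch.empty(num_experts, intermediate_dim, hidden_dim)
    experts = torch.nn.ModuleList(
        ExpertParam(
            # An out-of-place transpose gives the same view as the in-place
            # transpose_ used in GroupedMLP below.
            grouped_gate[i].transpose(0, 1),
            grouped_up[i].transpose(0, 1),
            grouped_down[i].transpose(0, 1),
        )
        for i in range(num_experts)
    )
    return experts.state_dict()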


class GroupedMLP(torch.nn.Module):
    """An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.

    See https://github.com/tgale96/grouped_gemm for details.

    This class is designed to execute multiple experts in parallel, thereby
    maximizing computational efficiency.
    """

    def __init__(
        self,
        config: ReaLModelConfig,
        init_method: Callable = init.xavier_normal_,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        assert (
            not constants.sequence_parallel()
        ), "Grouped GEMM does not support sequence parallel"

        self.config = config
        self.dtype = dtype
        self.device = device
        self.num_experts = config.moe.num_experts

        assert grouped_gemm is not None, "Grouped GEMM is not available."

        self.activation_func = get_activation_fn(self.config.activation_function)

        # How many features each rank holds for fc1 and fc2, respectively.
        tp_size = constants.tensor_parallel_world_size()
        intermediate_dim_per_partition = divide(self.config.intermediate_dim, tp_size)

        # Note: The current kernel implementations of grouped_gemm do not
        # support transposition with CUTLASS grouped GEMM, so we avoid
        # allocating the transpose of the weights.
        self.grouped_gate_proj = torch.empty(
            self.num_experts,
            self.config.hidden_dim,
            intermediate_dim_per_partition,
            device=self.device,
            dtype=self.dtype,
        )
        self.grouped_up_proj = torch.empty(
            self.num_experts,
            self.config.hidden_dim,
            intermediate_dim_per_partition,
            device=self.device,
            dtype=self.dtype,
        )
        self.grouped_down_proj = torch.empty(
            self.num_experts,
            intermediate_dim_per_partition,
            self.config.hidden_dim,
            device=self.device,
            dtype=self.dtype,
        )
        # Initialize weights.
        _initialize_affine_weight_gpu(
            self.grouped_gate_proj,
            init_method,
            partition_dim=1,
        )
        _initialize_affine_weight_gpu(
            self.grouped_up_proj,
            init_method,
            partition_dim=0,
        )
        _initialize_affine_weight_gpu(
            self.grouped_down_proj,
            init_method,
            partition_dim=0,
        )

        # Parameters for weight loading.
        self.local_experts = torch.nn.ModuleList()
        for i in range(self.num_experts):
            expert = ExpertParam(
                self.grouped_gate_proj[i, :].transpose_(0, 1),
                self.grouped_up_proj[i, :].transpose_(0, 1),
                self.grouped_down_proj[i, :].transpose_(0, 1),
            )
            self.local_experts.append(expert)

    def forward(
        self,
        permuted_local_hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ):
        tokens_per_expert = tokens_per_expert.cpu()
        if permuted_local_hidden_states.nelement() != 0:
            if constants.tensor_parallel_world_size() > 1:
                permuted_local_hidden_states = copy_to_tensor_model_parallel_region(
                    permuted_local_hidden_states
                )

            # Run the gate, up, and down projections as grouped GEMMs.
            o1 = grouped_gemm.ops.gmm(
                permuted_local_hidden_states,
                self.grouped_gate_proj,
                tokens_per_expert,
                trans_b=False,
            )
            o2 = grouped_gemm.ops.gmm(
                permuted_local_hidden_states,
                self.grouped_up_proj,
                tokens_per_expert,
                trans_b=False,
            )
            inter = self.activation_func(o1) * o2
            output = grouped_gemm.ops.gmm(
                inter, self.grouped_down_proj, tokens_per_expert, trans_b=False
            )
            if constants.tensor_parallel_world_size() > 1:
                output = reduce_from_tensor_model_parallel_region(output)
        else:
            # No token is allocated for local experts.
            assert torch.count_nonzero(tokens_per_expert) == 0

            # Make sure parameters still have gradients when no tokens are
            # routed to this set of experts.
            gate_proj = self.grouped_gate_proj.view(self.config.hidden_dim, -1)
            up_proj = self.grouped_up_proj.view(self.config.hidden_dim, -1)
            down_proj = self.grouped_down_proj.view(-1, self.config.hidden_dim)

            o1 = torch.matmul(permuted_local_hidden_states, gate_proj)
            o2 = torch.matmul(permuted_local_hidden_states, up_proj)
            # Apply the same gating as the grouped path; the inputs are empty
            # here, so only the gradient connectivity matters.
            inter = self.activation_func(o1) * o2
            output = torch.matmul(inter, down_proj)
        return output
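

# Illustrative sketch (not part of the library): what a call like
# grouped_gemm.ops.gmm(hidden, weights, tokens_per_expert, trans_b=False) in the
# forward pass above computes, written as a plain per-expert loop over toy
# tensors. The real path runs a single fused CUTLASS kernel instead.
def _grouped_gemm_reference_sketch(
    hidden: torch.Tensor,  # [total_tokens, hidden_dim], tokens sorted by expert
    weights: torch.Tensor,  # [num_experts, hidden_dim, out_dim]
    tokens_per_expert: torch.Tensor,  # [num_experts], on CPU
) -> torch.Tensor:
    outputs = []
    start = 0
    for e in range(weights.shape[0]):
        end = start + int(tokens_per_expert[e])
        # Expert e multiplies only its own contiguous slice of tokens.
        outputs.append(torch.matmul(hidden[start:end], weights[e]))
        start = end
    return torch.cat(outputs, dim=0)  # [total_tokens, out_dim]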