# AReaL/realhf/impl/model/utils/moe.py
# Modified from Megatron-LM.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import math
from typing import Any, Optional

import torch
import torch.distributed as dist

from realhf.base import constants, logging

def switch_load_balancing_loss_func(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    topk: int,
    moe_aux_loss_coeff: float,
    sequence_partition_group=None,
):
    """Calculate the auxiliary loss for load balancing.

    Refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.

    Args:
        probs (torch.Tensor): Softmax probabilities output by the router for each token. [num_tokens, num_experts]
        tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. [num_experts]
        topk (int): The number of experts selected for each token.
        moe_aux_loss_coeff (float): The coefficient for the auxiliary loss.
        sequence_partition_group (optional): The parallel group over which the sequence is partitioned. If None, no partitioning is applied. Defaults to None.

    Returns:
        torch.Tensor: The auxiliary loss for load balancing.
    """
    num_sub_sequence = 1

    # If the sequence is partitioned by certain parallelism strategies like Sequence
    # Parallelism or Context Parallelism, compute the gradient of the auxiliary loss
    # with respect to the full sequence.
    if sequence_partition_group is not None:
        # We can keep `aggregated_probs_per_expert` local since we don't need the
        # gradient for `tokens_per_expert`, saving one allreduce operation for
        # `aggregated_probs_per_expert`.
        # NOTE: Since the auxiliary loss is computed on the local
        # `aggregated_probs_per_expert`, it requires scaling by
        # `dist.get_world_size(sequence_partition_group)` when printing the loss.
        num_sub_sequence = dist.get_world_size(sequence_partition_group)
        dist.all_reduce(tokens_per_expert, group=sequence_partition_group)

    num_tokens = probs.shape[0] * topk * num_sub_sequence
    num_experts = probs.shape[1]

    # The formula of aux_loss:
    #   aux_loss = sum((probs_per_expert / num_tokens) * (tokens_per_expert / num_tokens))
    #              * num_experts * moe_aux_loss_coeff
    # which can be simplified to fuse the division and multiplication operations.
    aggregated_probs_per_expert = probs.sum(dim=0)
    aux_loss = torch.sum(aggregated_probs_per_expert * tokens_per_expert) * (
        num_experts * moe_aux_loss_coeff / (num_tokens * num_tokens)
    )
    return aux_loss

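
# Example (illustrative sketch): for a perfectly balanced top-1 router with a
# uniform probability distribution, the loss takes its minimum value, which is
# exactly ``moe_aux_loss_coeff``; any imbalance increases it.
#
#     probs = torch.full((8, 4), 0.25)                         # 8 tokens, 4 experts, uniform router
#     tokens_per_expert = torch.tensor([2.0, 2.0, 2.0, 2.0])   # balanced top-1 assignment
#     switch_load_balancing_loss_func(
#         probs, tokens_per_expert, topk=1, moe_aux_loss_coeff=1e-2
#     )  # -> tensor(0.0100)
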
def z_loss_func(logits, z_loss_coeff):
    """Encourages the router's logits to remain small to enhance stability.

    Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

    Args:
        logits (torch.Tensor): The logits of the router.
        z_loss_coeff (float): The coefficient for the z-loss.

    Returns:
        torch.Tensor: The computed z-loss.
    """
    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
    return z_loss

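
# Example (illustrative sketch): for all-zero logits over 4 experts, logsumexp is
# log(4) ~= 1.386 for every token, so the loss is roughly 1.386**2 * z_loss_coeff.
#
#     z_loss_func(torch.zeros(8, 4), z_loss_coeff=1e-3)  # -> ~0.0019
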
def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
    """Sinkhorn based MoE routing function."""
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
    eps = 0.00000001
    error = 1e9
    d1_old = d1
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)

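
# Example (illustrative sketch): the returned matrix is a rescaled, approximately
# doubly-stochastic assignment, so every expert (column) receives the same total
# mass; taking the per-token argmax then yields an approximately balanced choice.
#
#     logits = torch.randn(6, 3)            # 6 tokens, 3 experts
#     assignment = sinkhorn(logits.float())
#     assignment.sum(dim=0)                  # each column sums to ~1/3
#     assignment.argmax(dim=1)               # per-token expert choice
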
def get_capacity(
    num_tokens: int, num_experts: int, capacity_factor: float, min_capacity=None
):
    """Calculate the capacity of each expert.

    Args:
        num_tokens (int): Number of input tokens.
        num_experts (int): Number of experts.
        capacity_factor (float): Capacity factor.
        min_capacity (int, optional): Minimum capacity. Defaults to None.

    Returns:
        int: Capacity of each expert.
    """
    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
    if min_capacity is not None and capacity < min_capacity:
        capacity = min_capacity
    return capacity

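
# Example (illustrative sketch): with 100 tokens spread over 8 experts, a capacity
# factor of 1.25 lets each expert keep at most ceil(12.5 * 1.25) = 16 tokens.
#
#     get_capacity(num_tokens=100, num_experts=8, capacity_factor=1.25)                  # -> 16
#     get_capacity(num_tokens=100, num_experts=8, capacity_factor=1.0, min_capacity=32)  # -> 32
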
def custom_histc(input: torch.Tensor, bins: int, min: Any, max: Any):
    """A CPU-compatible version of torch.histc."""
    if input.is_cpu:
        out = torch.zeros(bins, dtype=torch.long)
        bin_width = (max - min) / bins
        # Iterate over the input tensor and increment the appropriate bin
        for value in input.flatten():
            if min <= value < max:
                bin_index = int((value - min) / bin_width)
                out[bin_index] += 1
            elif value == max:
                out[bins - 1] += 1
        return out
    else:
        return torch.histc(input, bins=bins, min=min, max=max)

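
# Example (illustrative sketch): counting how many routed tokens each expert received.
#
#     expert_ids = torch.tensor([0, 2, 2, 3])
#     custom_histc(expert_ids, bins=4, min=0, max=4)  # -> tensor([1, 0, 2, 1])
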
class MoEAuxLossAutoScaler(torch.autograd.Function):
    """An AutoScaler that computes and scales the gradient for the auxiliary loss."""

    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
        """Preserve the aux_loss by storing it in the context to avoid garbage
        collection.

        Args:
            output (torch.Tensor): The output tensor.
            aux_loss (torch.Tensor): The auxiliary loss tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        ctx.save_for_backward(aux_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Compute and scale the gradient for the auxiliary loss.

        Args:
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output and the scaled auxiliary loss gradient.
        """
        (aux_loss,) = ctx.saved_tensors
        aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
        scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
        return grad_output, scaled_aux_loss_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """Set the scale of the aux loss.

        Args:
            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
        """
        MoEAuxLossAutoScaler.main_loss_backward_scale = scale

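
# Typical usage inside a router/MoE forward pass (an illustrative sketch; the tensor
# names are placeholders): attaching the auxiliary loss to an activation makes its
# gradient flow during the regular backward pass without returning the loss separately.
#
#     aux_loss = switch_load_balancing_loss_func(probs, tokens_per_expert, topk, coeff)
#     hidden_states = MoEAuxLossAutoScaler.apply(hidden_states, aux_loss)
#     # If the main loss gradient is scaled (e.g. by gradient accumulation or a grad
#     # scaler), keep the aux-loss gradient consistent with it:
#     MoEAuxLossAutoScaler.set_loss_scale(torch.tensor(1.0))
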
def permute(
    tokens, indices, num_out_tokens: Optional[int] = None, padded_mode: bool = False
):
    """Permute the tokens based on the indices. Tokens with the same index will be grouped together.

    The input indices shape is [tokens, top_k]; it indicates which experts were selected by each token separately.

    Args:
        tokens (torch.Tensor): The input token tensor.
        indices (torch.Tensor): The token-to-expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk].
        num_out_tokens (int, optional): The effective output token count. When the capacity factor is enabled, this should equal the number of tokens not dropped. Defaults to None, meaning no tokens are dropped.
        padded_mode (bool, optional): If True, indicates the indices are padded to [num_expert, capacity] to denote the selected tokens per expert. Defaults to False.

    Returns:
        torch.Tensor: The permuted tensor.
        torch.Tensor: The sorted indices corresponding to the permuted tensor.
    """
    if padded_mode:
        return permute_with_padded_tokens(tokens, indices)

    if indices.dim() == 1:
        topk = 1
    else:
        topk = indices.size(1)
    flatten_indices = indices.view(-1)
    sorted_indices = torch.argsort(flatten_indices, stable=True)
    if num_out_tokens is not None:
        sorted_indices = sorted_indices[:num_out_tokens]
    permuted_tokens = tokens.index_select(0, sorted_indices // topk)
    return permuted_tokens, sorted_indices

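
# Example (illustrative sketch): with top-2 routing, every token is replicated topk
# times and the copies are grouped by expert id, so each expert sees a contiguous slice.
#
#     tokens = torch.arange(6, dtype=torch.float32).view(3, 2)   # 3 tokens, hidden=2
#     indices = torch.tensor([[1, 0], [0, 1], [1, 0]])           # top-2 expert choices
#     permuted, sorted_idx = permute(tokens, indices)
#     # permuted: copies of tokens 0, 1, 2 for expert 0 first, then copies for expert 1.
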
def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    probs: torch.Tensor = None,
    padded_mode: bool = False,
    restore_shape: torch.Size = None,
):
    """Unpermute a tensor of permuted tokens based on sorted indices, and
    optionally merge the tokens with their corresponding probabilities.

    Args:
        permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted.
        sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens.
        probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities.
        padded_mode (bool, optional): If True, indicates the indices are padded to [num_expert, capacity] to denote the selected tokens per expert. Defaults to False.
        restore_shape (torch.Size, optional): The input shape before permutation, only used in padding mode. Defaults to None.

    Returns:
        torch.Tensor: The unpermuted tokens, optionally merged with probabilities.
    """
    if padded_mode:
        return unpermute_with_padded_tokens(
            permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
        )

    assert sorted_indices.numel() == permuted_tokens.size(0)
    if probs is not None:
        # Unpermute and merge the tokens with their probabilities
        num_unpermuted_tokens = probs.numel()
        topk = probs.size(1)
    else:
        # Unpermute the tokens without merging
        num_unpermuted_tokens = permuted_tokens.size(0)
        topk = 1
    unpermuted_tokens = torch.zeros(
        [num_unpermuted_tokens, permuted_tokens.shape[-1]],
        dtype=permuted_tokens.dtype,
        device=permuted_tokens.device,
    )
    unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
    if probs is not None:
        unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
    unpermuted_tokens = unpermuted_tokens.sum(dim=1)
    return unpermuted_tokens

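
# Example (illustrative sketch, continuing the `permute` example above): feeding the
# permuted rows back with per-token routing weights produces the probability-weighted
# mixture over the top-k experts for each token. Here the "expert outputs" are just
# unmodified copies of the inputs, so the mixture recovers the original tokens.
#
#     weights = torch.tensor([[0.7, 0.3], [0.4, 0.6], [0.9, 0.1]])  # rows sum to 1
#     combined = unpermute(permuted, sorted_idx, probs=weights)
#     # combined == tokens, shape [3, 2]
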
def permute_with_padded_tokens(tokens, indices):
    """Permute the tokens based on the indices, only used in padding mode.

    The input indices shape is [num_expert, capacity]; it indicates which tokens were selected by each expert separately.

    Args:
        tokens (torch.Tensor): The input token tensor.
        indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.

    Returns:
        torch.Tensor: The permuted tensor.
        torch.Tensor: The indices corresponding to the permuted tensor.
    """
    permuted_tokens = tokens.index_select(dim=0, index=indices.view(-1))
    return permuted_tokens, indices

def unpermute_with_padded_tokens(
    permuted_tokens: torch.Tensor,
    indices: torch.Tensor,
    probs: torch.Tensor,
    restore_shape: torch.Size,
) -> torch.Tensor:
    """Unpermute padded permuted tokens based on sorted indices and merge
    the tokens with their corresponding probabilities.

    This function takes a tensor of permuted tokens and reorders them according to the provided indices. It also combines the tokens with their associated probabilities.

    Parameters:
        permuted_tokens (torch.Tensor): A 2D tensor containing permuted tokens.
        indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
        probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities corresponding to each token.
        restore_shape (torch.Size): The target shape for the unpermuted tokens tensor.

    Returns:
        torch.Tensor: A tensor of unpermuted tokens, merged with their probabilities.
    """
    # Ensure permuted_tokens is 2D
    assert permuted_tokens.dim() == 2, f"Got {permuted_tokens.dim()}D."

    # Reshape and expand probabilities and indices to match permuted_tokens
    probs = probs.view(-1).unsqueeze(-1)
    indices = indices.view(-1, 1).expand(-1, permuted_tokens.shape[1])
    assert (
        permuted_tokens.shape == indices.shape
    ), "Shape mismatch between permuted_tokens and indices."

    # Combine tokens with their probabilities
    combined_output = probs * permuted_tokens

    # Prepare a tensor of zeros with the desired output shape
    empty_tokens = torch.zeros(
        restore_shape,
        dtype=combined_output.dtype,
        device=combined_output.device,
        requires_grad=True,
    )

    # Scatter the combined tokens back to their original positions
    unpermuted_tokens = torch.scatter_add(empty_tokens, 0, indices, combined_output)
    return unpermuted_tokens

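
# Example (illustrative sketch): padded-mode round trip. Each expert's row in
# `indices` lists the token ids it kept (up to its capacity); the scatter-add then
# accumulates the probability-weighted rows back into token order.
#
#     tokens = torch.randn(4, 8)                      # 4 tokens, hidden=8
#     indices = torch.tensor([[0, 2], [1, 3]])        # 2 experts, capacity 2
#     probs = torch.tensor([[0.6, 1.0], [0.4, 1.0]])  # router weight for each kept slot
#     expert_in, idx = permute_with_padded_tokens(tokens, indices)            # [4, 8]
#     out = unpermute_with_padded_tokens(expert_in, idx, probs, tokens.shape)
#     # out[t] sums probs * (expert output) over every (expert, slot) that kept token t.
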
def topk_softmax_with_capacity(
    logits: torch.Tensor,
    topk: int,
    capacity_factor: Optional[float] = None,
    pad_to_capacity: bool = False,
    drop_policy: str = "probs",
):
    """Apply capacity and padding to the top-k selection.

    Args:
        logits (torch.Tensor): Logits tensor.
        topk (int): The number of experts to select for each token.
        capacity_factor (float, optional): The capacity factor of each expert. Tokens are dropped if the number of tokens assigned to an expert exceeds its capacity. Defaults to None (no capacity limit).
        pad_to_capacity (bool): Whether to pad the routing results to the expert capacity in token-drop mode.
        drop_policy (str): The policy used to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities are dropped. If "position", tokens at the end of each batch are dropped.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert tensor.

        (1) If there is no token padding, the shape of probs and indices is [tokens, top_k], indicating the selected experts for each token.
        (2) If there is token padding, the shape of probs and indices is [num_expert, capacity], indicating the tokens selected for each expert.
    """
    assert (
        logits.dim() == 2
    ), f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens = logits.shape[0]
    num_experts = logits.shape[1]

    scores, top_indices = torch.topk(logits, k=topk, dim=1)
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)

    if capacity_factor is None:
        # TopK without capacity
        tokens_per_expert = custom_histc(
            top_indices, bins=num_experts, min=0, max=num_experts
        )
        return probs, top_indices, tokens_per_expert
    else:
        # TopK with capacity
        expert_capacity = get_capacity(
            num_tokens=num_tokens * topk,
            num_experts=num_experts,
            capacity_factor=capacity_factor,
        )
        # TopK selection; mask out unused experts
        topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
        topk_mask = torch.zeros_like(logits).scatter(1, top_indices, 1)

        # Mask out tokens that exceed the expert capacity
        if drop_policy == "probs":
            capacity_probs, capacity_indices = torch.topk(
                topk_masked_gates, k=expert_capacity, dim=0, sorted=False
            )
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
        elif drop_policy == "position":
            _, capacity_indices = torch.topk(
                topk_mask, k=expert_capacity, dim=0, sorted=False
            )
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
            capacity_probs = torch.gather(topk_masked_gates, 0, capacity_indices)
        else:
            raise ValueError(f"Invalid drop_policy: {drop_policy}")

        if pad_to_capacity:
            final_probs, final_indices = (
                capacity_probs.T.contiguous(),
                capacity_indices.T.contiguous(),
            )
            tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
        else:
            # Get the exceed mask and mask out exceeded probs and indices
            final_mask = torch.logical_and(topk_mask, capacity_mask)
            drop_mask = torch.logical_not(final_mask)
            exceed_mask = torch.gather(drop_mask, 1, top_indices)
            final_probs = probs * torch.logical_not(exceed_mask)
            final_indices = top_indices.clone().masked_fill_(
                exceed_mask, torch.iinfo(torch.long).max
            )
            tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
        return final_probs, final_indices, tokens_per_expert_before_capacity

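
# Example (illustrative sketch): dropless vs. capacity-constrained routing.
#
#     logits = torch.randn(16, 4)                        # 16 tokens, 4 experts
#     probs, idx, counts = topk_softmax_with_capacity(logits, topk=2)
#     # probs/idx: [16, 2]; counts[e] = number of (token, choice) pairs routed to expert e
#
#     probs, idx, counts = topk_softmax_with_capacity(
#         logits, topk=2, capacity_factor=1.0, pad_to_capacity=True
#     )
#     # probs/idx: [4, 8]; each expert keeps at most capacity = ceil(16 * 2 / 4 * 1.0) = 8 tokens
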
# logging related
aux_loss_names = ["load_balancing_loss", "z_loss"]


def update_aux_losses_tracker(
    name: str, loss: torch.Tensor, layer_number: int, num_layers: int
):
    """Save the auxiliary loss for logging.

    Args:
        name (str): The name of the loss.
        loss (torch.Tensor): The loss tensor.
        layer_number (int): Layer index of the loss.
        num_layers (int): The number of total layers.
    """
    from realhf.base.stats_tracker import MOE_AUX_LOSSES

    assert name in aux_loss_names, f"Invalid aux loss name: {name}."
    losses = MOE_AUX_LOSSES.get(name, None)
    if losses is None:
        losses = torch.zeros(num_layers, device=loss.device)
    losses[layer_number] += loss.detach()