# Modified from Megatron-LM.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import math
from typing import Any, Optional

import torch
import torch.distributed as dist

from realhf.base import constants, logging


def switch_load_balancing_loss_func(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    topk: int,
    moe_aux_loss_coeff: float,
    sequence_partition_group=None,
):
    """Calculate the auxiliary loss for load balancing.

    Refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.

    Args:
        probs (torch.Tensor): Softmax probabilities output by the router for each token. [num_tokens, num_experts]
        tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. [num_experts]
        topk (int): The number of experts selected for each token.
        moe_aux_loss_coeff (float): The coefficient for the auxiliary loss.
        sequence_partition_group (optional): The parallel group over which the sequence is partitioned. If None, no partitioning is applied. Defaults to None.

    Returns:
        torch.Tensor: The auxiliary loss for load balancing.
    """
    num_sub_sequence = 1

    # If the sequence is partitioned by certain parallelism strategies like Sequence Parallelism or Context Parallelism, compute the gradient of the auxiliary loss with respect to the full sequence.
    if sequence_partition_group is not None:
        # We can keep `aggregated_probs_per_expert` local since we don't need the gradient for `tokens_per_expert`, saving one allreduce operation for `aggregated_probs_per_expert`.
        # NOTE: Since the auxiliary loss is computed on the local `aggregated_probs_per_expert`, it requires scaling by `dist.get_world_size(sequence_partition_group)` when printing the loss.
        num_sub_sequence = dist.get_world_size(sequence_partition_group)
        dist.all_reduce(tokens_per_expert, group=sequence_partition_group)

    num_tokens = probs.shape[0] * topk * num_sub_sequence
    num_experts = probs.shape[1]

    # The formula of aux_loss: aux_loss = sum((probs_per_expert/num_tokens) * (tokens_per_expert/num_tokens)) * num_experts * moe_aux_loss_coeff.
    # This can be simplified to fuse the division and multiplication operations.
    aggregated_probs_per_expert = probs.sum(dim=0)
    aux_loss = torch.sum(aggregated_probs_per_expert * tokens_per_expert) * (
        num_experts * moe_aux_loss_coeff / (num_tokens * num_tokens)
    )
    return aux_loss


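# A minimal illustrative sketch (hypothetical values, not from the original module):
# with this normalization, a perfectly uniform and perfectly balanced router
# evaluates to moe_aux_loss_coeff / topk.
#
#   >>> probs = torch.full((4, 2), 0.5)                # 4 tokens, 2 experts, uniform router
#   >>> tokens_per_expert = torch.tensor([2.0, 2.0])   # top-1 routing, perfectly balanced
#   >>> switch_load_balancing_loss_func(probs, tokens_per_expert, topk=1, moe_aux_loss_coeff=1.0)
#   tensor(1.)
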
def z_loss_func(logits, z_loss_coeff):
    """Encourages the router's logits to remain small to enhance stability.

    Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

    Args:
        logits (torch.Tensor): The logits of the router.

    Returns:
        torch.Tensor: The computed z-loss.
    """

    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
    return z_loss


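# A minimal illustrative sketch (hypothetical values): the z-loss is the mean squared
# logsumexp of the router logits, scaled by z_loss_coeff, so large-magnitude logits
# are penalized directly.
#
#   >>> z_loss_func(torch.zeros(2, 4), z_loss_coeff=1.0)   # mean(log(4) ** 2)
#   tensor(1.9218)
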
def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
    """Sinkhorn based MoE routing function."""
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)


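# A minimal illustrative sketch (hypothetical values): sinkhorn() alternately rescales
# the rows and columns of exp(cost) until the result is approximately balanced in both
# dimensions; a sinkhorn-based router can then take the argmax over experts per token.
#
#   >>> logits = torch.randn(8, 4)              # [num_tokens, num_experts]
#   >>> assignment = sinkhorn(logits)
#   >>> expert_choice = torch.argmax(assignment, dim=1)
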
def get_capacity(
    num_tokens: int, num_experts: int, capacity_factor: float, min_capacity=None
):
    """Calculate the capacity of each expert.

    Args:
        num_tokens (int): Number of input tokens.
        num_experts (int): Number of experts.
        capacity_factor (float): Capacity factor.
        min_capacity (int, optional): Minimum capacity. Defaults to None.

    Returns:
        int: Capacity of each expert.
    """
    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
    if min_capacity is not None and capacity < min_capacity:
        capacity = min_capacity
    return capacity


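# A worked example (hypothetical values): with 16 routed token slots (e.g. 8 tokens
# with top-2 routing), 4 experts, and a capacity factor of 1.25, each expert keeps at
# most ceil(16 / 4 * 1.25) = 5 tokens.
#
#   >>> get_capacity(num_tokens=16, num_experts=4, capacity_factor=1.25)
#   5
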
def custom_histc(input: torch.Tensor, bins: int, min: Any, max: Any):
    """A CPU compatible version of torch.histc."""
    if input.is_cpu:
        out = torch.zeros(bins, dtype=torch.long)
        bin_width = (max - min) / bins

        # Iterate over the input tensor and increment the appropriate bin
        for value in input.flatten():
            if min <= value < max:
                bin_index = int((value - min) / bin_width)
                out[bin_index] += 1
            elif value == max:
                out[bins - 1] += 1
        return out
    else:
        return torch.histc(input, bins=bins, min=min, max=max)


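# A minimal illustrative sketch (hypothetical values): counting routed tokens per
# expert from a flat expert-index tensor on CPU.
#
#   >>> custom_histc(torch.tensor([0, 1, 1, 3]), bins=4, min=0, max=4)
#   tensor([1, 2, 0, 1])
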
class MoEAuxLossAutoScaler(torch.autograd.Function):
    """An AutoScaler that computes and scales the gradient for the auxiliary loss."""

    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
        """Preserve the aux_loss by storing it in the context to avoid garbage
        collection.

        Args:
            output (torch.Tensor): The output tensor.
            aux_loss (torch.Tensor): The auxiliary loss tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        ctx.save_for_backward(aux_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Compute and scale the gradient for the auxiliary loss.

        Args:
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
        """
        (aux_loss,) = ctx.saved_tensors
        aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
        scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
        return grad_output, scaled_aux_loss_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """Set the scale of the aux loss.

        Args:
            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
        """
        MoEAuxLossAutoScaler.main_loss_backward_scale = scale


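# A minimal illustrative sketch (hypothetical call site, not from this module): as a
# torch.autograd.Function the scaler is invoked through .apply(). The forward pass
# returns the activations unchanged; the backward pass feeds a gradient of
# `main_loss_backward_scale` into `aux_loss`.
#
#   >>> output = MoEAuxLossAutoScaler.apply(output, aux_loss)
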
def permute(tokens, indices, num_out_tokens: int = None, padded_mode: bool = False):
    """Permute the tokens based on the indices. Tokens with the same index will be grouped together.

    The input indices have shape [tokens, top_k]; they indicate which experts were selected by each token separately.

    Args:
        tokens (torch.Tensor): The input token tensor.
        indices (torch.Tensor): The token-to-expert indices tensor; should have a shape of [num_tokens] or [num_tokens, topk].
        num_out_tokens (int, optional): The effective output token count. When the capacity factor is enabled, it should equal the number of tokens not dropped. Defaults to None, meaning no tokens are dropped.
        padded_mode (bool, optional): If True, indicates the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.

    Returns:
        torch.Tensor: The permuted tensor.
        torch.Tensor: The sorted indices corresponding to the permuted tensor.
    """
    if padded_mode:
        return permute_with_padded_tokens(tokens, indices)

    if indices.dim() == 1:
        topk = 1
    else:
        topk = indices.size(1)

    flatten_indices = indices.view(-1)
    sorted_indices = torch.argsort(flatten_indices, stable=True)
    if num_out_tokens is not None:
        sorted_indices = sorted_indices[:num_out_tokens]
    permuted_tokens = tokens.index_select(0, sorted_indices // topk)
    return permuted_tokens, sorted_indices


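# A minimal illustrative sketch (hypothetical values): with top-1 routing, permute()
# groups token rows by expert index and returns the sort order needed to undo it.
#
#   >>> tokens = torch.arange(6.0).view(3, 2)   # 3 tokens, hidden size 2
#   >>> indices = torch.tensor([1, 0, 1])       # chosen expert per token
#   >>> permuted, sorted_idx = permute(tokens, indices)
#   >>> permuted                                # token 1 (expert 0) first, then tokens 0 and 2
#   tensor([[2., 3.],
#           [0., 1.],
#           [4., 5.]])
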
def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    probs: torch.Tensor = None,
    padded_mode: bool = False,
    restore_shape: torch.Size = None,
):
    """Unpermute a tensor of permuted tokens based on sorted indices, and
    optionally merge the tokens with their corresponding probabilities.

    Args:
        permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted.
        sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens.
        probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities.
        padded_mode (bool, optional): If True, indicates the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.
        restore_shape (torch.Size, optional): The input shape before permutation, only used in padding mode. Defaults to None.

    Returns:
        torch.Tensor: The unpermuted tokens, optionally merged with probabilities.
    """
    if padded_mode:
        return unpermute_with_padded_tokens(
            permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
        )

    assert sorted_indices.numel() == permuted_tokens.size(0)
    if probs is not None:
        # Unpermute and merge the tokens with their probabilities
        num_unpermuted_tokens = probs.numel()
        topk = probs.size(1)
    else:
        # Unpermute the tokens without merging
        num_unpermuted_tokens = permuted_tokens.size(0)
        topk = 1

    unpermuted_tokens = torch.zeros(
        [num_unpermuted_tokens, permuted_tokens.shape[-1]],
        dtype=permuted_tokens.dtype,
        device=permuted_tokens.device,
    )
    unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
    if probs is not None:
        unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
    unpermuted_tokens = unpermuted_tokens.sum(dim=1)

    return unpermuted_tokens


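# A minimal illustrative sketch (hypothetical values, continuing the permute() sketch
# above): unpermute() restores the original token order; when top-k probabilities are
# supplied, each token becomes the probability-weighted sum of its expert copies.
#
#   >>> restored = unpermute(permuted, sorted_idx)
#   >>> torch.equal(restored, tokens)
#   True
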
def permute_with_padded_tokens(tokens, indices):
    """Permute the tokens based on the indices, only used in padding mode.

    The input indices have shape [num_expert, capacity]; they indicate which tokens were selected by each expert separately.

    Args:
        tokens (torch.Tensor): The input token tensor.
        indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.

    Returns:
        torch.Tensor: The permuted tensor.
        torch.Tensor: The sorted indices corresponding to the permuted tensor.
    """
    permuted_tokens = tokens.index_select(dim=0, index=indices.view(-1))

    return permuted_tokens, indices


def unpermute_with_padded_tokens(
    permuted_tokens: torch.Tensor,
    indices: torch.Tensor,
    probs: torch.Tensor,
    restore_shape: torch.Size,
) -> torch.Tensor:
    """Unpermute padded permuted tokens based on sorted indices and merge
    the tokens with their corresponding probabilities.

    This function takes a tensor of permuted tokens and reorders them according to the provided indices. It also combines the tokens with their associated probabilities.

    Parameters:
        permuted_tokens (torch.Tensor): A 2D tensor containing permuted tokens.
        indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
        probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities corresponding to each token.
        restore_shape (torch.Size): The target shape for the unpermuted tokens tensor.

    Returns:
        torch.Tensor: A tensor of unpermuted tokens, merged with their probabilities.
    """
    # Ensure permuted_tokens is 2D
    assert permuted_tokens.dim() == 2, f"Got {permuted_tokens.dim()}D."

    # Reshape and expand probabilities and indices to match permuted_tokens
    probs = probs.view(-1).unsqueeze(-1)
    indices = indices.view(-1, 1).expand(-1, permuted_tokens.shape[1])
    assert (
        permuted_tokens.shape == indices.shape
    ), "Shape mismatch between permuted_tokens and indices."

    # Combine tokens with their probabilities
    combined_output = probs * permuted_tokens

    # Prepare a tensor of zeros with the desired output shape
    empty_tokens = torch.zeros(
        restore_shape,
        dtype=combined_output.dtype,
        device=combined_output.device,
        requires_grad=True,
    )

    # Scatter the combined tokens back to their original positions
    unpermuted_tokens = torch.scatter_add(empty_tokens, 0, indices, combined_output)

    return unpermuted_tokens


def topk_softmax_with_capacity(
    logits: torch.Tensor,
    topk: int,
    capacity_factor: Optional[float] = None,
    pad_to_capacity: bool = False,
    drop_policy: str = "probs",
):
    """Apply capacity and padding to the top-k selection.

    Args:
        logits (torch.Tensor): Logits tensor.
        topk (int): The number of experts to select for each token.
        capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number of tokens exceeds the capacity.
        pad_to_capacity (bool): Whether to pad each expert's selected tokens to the expert capacity in token drop mode.
        drop_policy (str): The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert tensor.

        (1) If there's no token padding, the shape of probs and indices is [tokens, top_k], indicating the selected experts for each token.
        (2) If there's token padding, the shape of probs and indices is [num_expert, capacity], indicating the tokens selected for each expert.
    """
    assert (
        logits.dim() == 2
    ), f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens = logits.shape[0]
    num_experts = logits.shape[1]

    scores, top_indices = torch.topk(logits, k=topk, dim=1)
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)

    if capacity_factor is None:
        # TopK without capacity
        tokens_per_expert = custom_histc(
            top_indices, bins=num_experts, min=0, max=num_experts
        )
        return probs, top_indices, tokens_per_expert
    else:
        # TopK with capacity
        expert_capacity = get_capacity(
            num_tokens=num_tokens * topk,
            num_experts=num_experts,
            capacity_factor=capacity_factor,
        )
        # TopK selection, Maskout unused experts
        topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
        topk_mask = torch.zeros_like(logits).scatter(1, top_indices, 1)

        # Maskout exceeded tokens
        if drop_policy == "probs":
            capacity_probs, capacity_indices = torch.topk(
                topk_masked_gates, k=expert_capacity, dim=0, sorted=False
            )
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
        elif drop_policy == "position":
            _, capacity_indices = torch.topk(
                topk_mask, k=expert_capacity, dim=0, sorted=False
            )
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
            capacity_probs = torch.gather(topk_masked_gates, 0, capacity_indices)
        else:
            raise ValueError(f"Invalid drop_policy: {drop_policy}")

        if pad_to_capacity:
            final_probs, final_indices = (
                capacity_probs.T.contiguous(),
                capacity_indices.T.contiguous(),
            )
            tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
        else:
            # Get exceed mask and maskout exceeded probs and indices
            final_mask = torch.logical_and(topk_mask, capacity_mask)
            drop_mask = torch.logical_not(final_mask)
            exceed_mask = torch.gather(drop_mask, 1, top_indices)
            final_probs = probs * torch.logical_not(exceed_mask)
            final_indices = top_indices.clone().masked_fill_(
                exceed_mask, torch.iinfo(torch.long).max
            )
            tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
        return final_probs, final_indices, tokens_per_expert_before_capacity


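# A minimal illustrative sketch (hypothetical values): without a capacity factor the
# helper returns [num_tokens, topk]-shaped probs and indices plus per-expert token
# counts; with pad_to_capacity=True the outputs instead describe each expert's
# [num_experts, capacity] token slots.
#
#   >>> logits = torch.randn(8, 4)
#   >>> probs, indices, counts = topk_softmax_with_capacity(logits, topk=2)
#   >>> probs.shape, indices.shape, counts.shape
#   (torch.Size([8, 2]), torch.Size([8, 2]), torch.Size([4]))
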
# logging related
aux_loss_names = ["load_balancing_loss", "z_loss"]


def update_aux_losses_tracker(
    name: str, loss: torch.Tensor, layer_number: int, num_layers: int
):
    """Save the auxiliary loss for logging.

    Args:
        name (str): The name of the loss.
        loss (torch.Tensor): The loss tensor.
        layer_number (int): Layer index of the loss.
        num_layers (int): The number of total layers.
    """
    from realhf.base.stats_tracker import MOE_AUX_LOSSES

    assert name in aux_loss_names, f"Invalid aux loss name: {name}."
    losses = MOE_AUX_LOSSES.get(name, None)
    if losses is None:
        losses = torch.zeros(num_layers, device=loss.device)
    losses[layer_number] += loss.detach()