# Modified from Megatron-LM.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import functools
import itertools
import os
import warnings
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from packaging.version import Version
from torch.nn.parameter import Parameter

from realhf.base import constants
from realhf.impl.model.utils.random import _initialize_affine_weight_gpu

from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_sequence_parallel_region,
    gather_from_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
    scatter_to_tensor_model_parallel_region,
)
from .utils import VocabUtility, divide, set_tensor_model_parallel_attributes

if Version(Version(torch.__version__).base_version) >= Version("2.4"):
    # To disable an annoying FutureWarning.
    from torch.amp import custom_bwd, custom_fwd

    custom_bwd = functools.partial(custom_bwd, device_type="cuda")
    custom_fwd = functools.partial(custom_fwd, device_type="cuda")
else:
    from torch.cuda.amp import custom_bwd, custom_fwd

_grad_accum_fusion_available = True
try:
    import fused_weight_gradient_mlp_cuda
except ImportError:
    _grad_accum_fusion_available = False

import realhf.base.logging as logging

logger = logging.getLogger("tensor_parallel.modules")


def get_activation_fn(activation_function: str) -> Callable:
    if activation_function == "gelu":
        return nn.functional.gelu
    elif activation_function == "gelu_new":
        from realhf.impl.model.modules.activations import new_gelu_activation

        return new_gelu_activation
    elif activation_function == "silu":
        # nn.SiLU() is a module, but it is callable like the functional variants.
        return nn.SiLU()
    else:
        raise NotImplementedError(
            f"Activation function {activation_function!r} is not supported. "
            'Available options are "gelu", "gelu_new", and "silu".'
        )


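# Illustrative usage sketch (not part of the library API): the returned
# callable is applied elementwise to hidden states, e.g. inside an MLP block.
def _example_activation_usage() -> torch.Tensor:
    act = get_activation_fn("gelu")
    hidden = torch.randn(4, 16)
    return act(hidden)

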
class ParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.

    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.

    Keyword Arguments:
        init_method: method to initialize weights.
        perform_initialization: whether to initialize weights at construction time.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        init_method=init.xavier_normal_,
        perform_initialization: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super(ParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.0
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = constants.tensor_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = (
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings,
                constants.tensor_parallel_rank(),
                self.tensor_model_parallel_size,
            )
        )
        self.num_embeddings_per_partition = (
            self.vocab_end_index - self.vocab_start_index
        )

        logger.debug(
            f"ParallelEmbedding: num_embeddings={num_embeddings}, "
            f"per_partition={self.num_embeddings_per_partition}, "
            f"embedding_dim={embedding_dim}, "
            f"tp_rank={constants.tensor_parallel_rank()}, "
            f"tp_world_size={constants.tensor_parallel_world_size()}"
        )
        # Allocate weights and initialize.
        self.weight = Parameter(
            torch.empty(
                self.num_embeddings_per_partition,
                self.embedding_dim,
                device=device,
                dtype=dtype,
            )
        )
        if perform_initialization:
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=0, stride=1
            )

    def forward(self, input_) -> torch.Tensor:
        if self.tensor_model_parallel_size > 1:
            # Build the mask of out-of-range token ids.
            input_mask = (input_ < self.vocab_start_index) | (
                input_ >= self.vocab_end_index
            )
            # Shift ids into the local range and zero out masked positions.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output


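# Illustrative sketch (assumption: single process, shards simulated in a loop;
# not part of the library API): each vocabulary shard embeds only the ids in
# its range, zeroes the rest, and the sum over shards reproduces a full
# nn.Embedding lookup -- which is what the all-reduce does across TP ranks.
def _example_vocab_parallel_embedding_math() -> bool:
    vocab, dim, shards = 8, 4, 2
    full_weight = torch.randn(vocab, dim)
    ids = torch.randint(0, vocab, (3, 5))
    per_shard = vocab // shards
    output = torch.zeros(*ids.shape, dim)
    for rank in range(shards):
        start, end = rank * per_shard, (rank + 1) * per_shard
        mask = (ids < start) | (ids >= end)
        local_ids = (ids - start).masked_fill(mask, 0)
        local_out = F.embedding(local_ids, full_weight[start:end])
        local_out[mask, :] = 0.0
        # Stands in for reduce_from_tensor_model_parallel_region.
        output += local_out
    return torch.allclose(output, F.embedding(ids, full_weight))

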
class LinearWithFrozenWeight(torch.autograd.Function):
    """Linear operator that does not calculate gradient for weight. This op and
    LinearWithGradAccumulationAndAsyncCommunication perform mathematically
    identical forward and DGRAD.

    Conceptually this op is the same as torch.nn.functional.linear with
    weight.requires_grad==False, but in experiments they are not
    identical mathematically.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input,
        weight,
        bias,
    ):
        ctx.save_for_backward(weight)
        output = torch.matmul(input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        return grad_input, None, None


def linear_with_frozen_weight(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    async_grad_allreduce: bool,
    sequence_parallel: bool,
) -> torch.Tensor:
    """Linear layer execution with weight.requires_grad == False.

    This function handles linear layers with the weight frozen (untrainable).
    In the forward pass, it only saves the weight and does not save input
    activations. In the backward pass, it performs neither the weight gradient
    calculation nor the weight gradient all-reduce.

    Arguments:

    input (torch.Tensor required): input like torch.nn.functional.linear

    weight (torch.Tensor required): weight like torch.nn.functional.linear

    bias (torch.Tensor optional): bias like torch.nn.functional.linear

    gradient_accumulation_fusion (bool required): dummy argument, used to
    keep the API unified between all forward implementation functions.

    async_grad_allreduce (bool required): dummy argument, used to
    keep the API unified between all forward implementation functions.

    sequence_parallel (bool required): Indicates that sequence
    parallelism is used and thus in the forward pass the input is
    all-gathered, and in the backward pass the input gradients are
    reduce-scattered.
    """

    if sequence_parallel:
        input = gather_from_sequence_parallel_region(
            input, model_parallel_output_grad=True
        )

    args = [
        input,
        weight,
        bias,
    ]

    return LinearWithFrozenWeight.apply(*args)


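# Illustrative sketch (assumption: single process, no tensor parallelism; not
# part of the library API): the frozen-weight op matches F.linear in the
# forward pass and produces the same input gradient, while never populating a
# weight gradient.
def _example_frozen_weight_linear() -> bool:
    x = torch.randn(2, 3, requires_grad=True)
    w = torch.randn(4, 3)  # frozen: requires_grad stays False
    b = torch.zeros(4)
    out = LinearWithFrozenWeight.apply(x, w, b)
    out.sum().backward()  # populates x.grad only
    ref = F.linear(x.detach(), w, b)
    return torch.allclose(out, ref) and w.grad is None

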
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """See linear_with_grad_accumulation_and_async_allreduce."""

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        async_grad_allreduce,
        sequence_parallel,
    ):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel = sequence_parallel

        if sequence_parallel:
            assert (
                not ctx.async_grad_allreduce
            ), "async_grad_allreduce and sequence_parallel can not be both True"
            world_size = constants.tensor_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            # Sequence parallelism all-gathers the input into a preallocated
            # global memory buffer.
            all_gather_buffer = constants.get_global_memory_buffer().get_tensor(
                dim_size, input.dtype, "mpu"
            )
            torch.distributed._all_gather_base(
                all_gather_buffer, input, group=constants.tensor_parallel_group()
            )
            total_input = all_gather_buffer
        else:
            total_input = input

        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        if ctx.sequence_parallel:
            world_size = constants.tensor_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = constants.get_global_memory_buffer().get_tensor(
                dim_size, input.dtype, "mpu"
            )
            handle = torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=constants.tensor_parallel_group(),
                async_op=True,
            )

            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # gather is scheduled before the input gradient computation.
            total_input = all_gather_buffer
        else:
            total_input = input
        grad_input = grad_output.matmul(weight)

        if ctx.sequence_parallel:
            handle.wait()

        # Doing gather + slicing during the NeMo forward pass can make this tensor
        # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
        # clones it if it's not contiguous:
        # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility.
        grad_output = grad_output.view(-1, grad_output.shape[-1])
        total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce.
            handle = torch.distributed.all_reduce(
                grad_input,
                group=constants.tensor_parallel_group(),
                async_op=True,
            )
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # all-reduce is scheduled before the weight gradient computation.

        if ctx.sequence_parallel:
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(
                dim_size,
                dtype=input.dtype,
                device=constants.current_device(),
                requires_grad=False,
            )
            # Reduce-scatter the input gradient back to sequence-parallel shards.
            handle = torch.distributed._reduce_scatter_base(
                sub_grad_input,
                grad_input,
                group=constants.tensor_parallel_group(),
                async_op=True,
            )
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # reduce-scatter is scheduled before the weight gradient computation.

        if ctx.gradient_accumulation_fusion:
            if weight.main_grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                    total_input, grad_output, weight.main_grad
                )
            elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                    total_input, grad_output, weight.main_grad
                )
            else:
                raise RuntimeError(
                    "Unsupported gradient type for gradient accumulation fusion"
                )

            if hasattr(weight, "grad_added_to_main_grad"):
                # When overlap_grad_reduce is True, need to ensure that backward hooks
                # are all run on the main backprop thread to prevent deadlocks. Set up
                # a dummy grad_weight tensor to prevent backward hooks from being run
                # in a background thread.
                if getattr(weight, "zero_out_wgrad", False):
                    grad_weight = torch.zeros(
                        weight.main_grad.shape,
                        dtype=input.dtype,
                        device=constants.current_device(),
                        requires_grad=False,
                    )
                else:
                    grad_weight = torch.empty(
                        weight.main_grad.shape,
                        dtype=input.dtype,
                        device=constants.current_device(),
                        requires_grad=False,
                    )
                weight.grad_added_to_main_grad = True
            else:
                grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.sequence_parallel:
            handle.wait()
            return sub_grad_input, grad_weight, grad_bias, None, None, None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_grad_accumulation_and_async_allreduce(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    async_grad_allreduce: bool,
    sequence_parallel: bool,
) -> torch.Tensor:
    """Linear layer execution with asynchronous communication and gradient
    accumulation fusion in backprop.

    This has the option to accumulate the result of backprop
    calculation into an existing gradient buffer, preventing the need
    to do an additional addition kernel after the gradient
    calculation.

    Additionally, the tensor parallel all-reduce of the input
    gradients can be done asynchronously with the calculation of
    the weight gradients.

    In the case of sequence parallelism, the reduce-scatter of the
    input gradients is done asynchronously with the calculation of the
    weight gradients.

    Use of this module requires that the environment variable
    CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
    operations, noted in the code, that should be scheduled before
    compute kernels to overlap the communication with the computation.
    This overlap is necessary for a speedup but not for correctness, so
    the scheduler does not enforce the ordering. Setting
    CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
    in the order they are called.

    Arguments:

    input (torch.Tensor required): input like torch.nn.functional.linear

    weight (torch.Tensor required): weight like torch.nn.functional.linear

    bias (torch.Tensor optional): bias like torch.nn.functional.linear

    gradient_accumulation_fusion (bool required): Perform the gradient
    accumulation fusion, requires the custom CUDA extension
    fused_weight_gradient_mlp_cuda module. To use
    gradient_accumulation_fusion you must install APEX with
    --cpp_ext and --cuda_ext. For example: "pip install
    --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\""
    Note that the extension requires CUDA>=11. Otherwise, you
    must turn off gradient accumulation fusion.

    async_grad_allreduce (bool required): Do the all-reduce of input
    gradients asynchronously with the computation of weight
    gradients. If sequence_parallel is True, this must be
    False, as no all-reduce is performed.

    sequence_parallel (bool required): Indicates that sequence
    parallelism is used and thus in the forward pass the input is
    all-gathered, and in the backward pass the input gradients are
    reduce-scattered.
    """
    args = [
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        async_grad_allreduce,
        sequence_parallel,
    ]

    if not linear_with_grad_accumulation_and_async_allreduce.warned:
        if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") != "1":
            if sequence_parallel:
                warnings.warn(
                    "When using sequence parallelism it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                linear_with_grad_accumulation_and_async_allreduce.warned = True

            if async_grad_allreduce:
                warnings.warn(
                    "When using async grad allreduce it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                linear_with_grad_accumulation_and_async_allreduce.warned = True

    return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)


linear_with_grad_accumulation_and_async_allreduce.warned = False


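# Illustrative sketch (assumption: a pure-PyTorch stand-in, not the fused CUDA
# kernel; not part of the library API): gradient accumulation fusion computes
# the weight gradient and adds it into a persistent `main_grad` buffer in one
# step, instead of materializing a separate grad tensor and adding afterwards.
def _example_wgrad_accumulation(
    total_input: torch.Tensor,  # [tokens, in_features]
    grad_output: torch.Tensor,  # [tokens, out_features]
    main_grad: torch.Tensor,  # [out_features, in_features], fp32 accumulator
) -> None:
    # Equivalent math of fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32:
    # main_grad += grad_output^T @ total_input, accumulated in place.
    main_grad.add_(
        grad_output.t().to(main_grad.dtype) @ total_input.to(main_grad.dtype)
    )

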
class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """See merged_linear_with_grad_accumulation_and_async_allreduce."""

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input,
        gradient_accumulation_fusion,
        async_grad_allreduce,
        sequence_parallel,
        is_w_parallel,
        *wbs,
    ):
        # wbs is a flat list of interleaved (weight, bias) pairs.
        assert len(wbs) % 2 == 0
        weights = wbs[::2]
        biases = wbs[1::2]
        assert len(is_w_parallel) == len(weights)
        ctx.save_for_backward(input, *weights)
        ctx.use_bias = tuple(b is not None for b in biases)
        ctx.is_w_parallel = is_w_parallel

        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel = sequence_parallel

        if sequence_parallel:
            assert (
                not ctx.async_grad_allreduce
            ), "async_grad_allreduce and sequence_parallel can not be both True"
            world_size = constants.tensor_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            # Sequence parallelism all-gathers the input into a preallocated
            # global memory buffer, shared by all merged matmuls below.
            all_gather_buffer = constants.get_global_memory_buffer().get_tensor(
                dim_size, input.dtype, "mpu"
            )
            torch.distributed._all_gather_base(
                all_gather_buffer, input, group=constants.tensor_parallel_group()
            )
            total_input = all_gather_buffer
        else:
            total_input = input

        xs = []
        for w, b in zip(weights, biases):
            x = torch.matmul(total_input, w.t())
            if b is not None:
                x = x + b
            xs.append(x)
        return tuple(xs)

    @staticmethod
    @custom_bwd
    def backward(ctx, *grads):
        grads = list(grads)
        input, *weights = ctx.saved_tensors
        assert len(weights) == len(grads)
        use_bias = ctx.use_bias
        is_w_parallel = ctx.is_w_parallel

        if ctx.sequence_parallel:
            world_size = constants.tensor_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = constants.get_global_memory_buffer().get_tensor(
                dim_size, input.dtype, "mpu"
            )
            handle = torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=constants.tensor_parallel_group(),
                async_op=True,
            )

            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # gather is scheduled before the input gradient computation.
            total_input = all_gather_buffer
        else:
            total_input = input
        grad_input = 0
        # For non-parallel weights only rank 0 contributes, so the subsequent
        # all-reduce (or reduce-scatter) does not double count the gradient.
        for w, is_parallel, grad in zip(weights, is_w_parallel, grads):
            if is_parallel or constants.tensor_parallel_rank() == 0:
                grad_input = grad_input + grad.matmul(w)

        if ctx.sequence_parallel:
            handle.wait()

        # Doing gather + slicing during the NeMo forward pass can make this tensor
        # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
        # clones it if it's not contiguous:
        # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
        # Convert the tensor shapes to 2D for execution compatibility.
        for i in range(len(grads)):
            grads[i] = grads[i].contiguous().view(-1, grads[i].shape[-1])
        total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce.
            handle = torch.distributed.all_reduce(
                grad_input,
                group=constants.tensor_parallel_group(),
                async_op=True,
            )
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # all-reduce is scheduled before the weight gradient computation.

        if ctx.sequence_parallel:
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(
                dim_size,
                dtype=input.dtype,
                device=constants.current_device(),
                requires_grad=False,
            )
            # Reduce-scatter the input gradient back to sequence-parallel shards.
            handle = torch.distributed._reduce_scatter_base(
                sub_grad_input,
                grad_input,
                group=constants.tensor_parallel_group(),
                async_op=True,
            )
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # reduce-scatter is scheduled before the weight gradient computation.

        if ctx.gradient_accumulation_fusion:
            gws = []
            for weight, grad_output in zip(weights, grads):
                if weight.main_grad.dtype == torch.float32:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                        total_input, grad_output, weight.main_grad
                    )
                elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                        total_input, grad_output, weight.main_grad
                    )
                else:
                    raise RuntimeError(
                        "Unsupported gradient type for gradient accumulation fusion"
                    )
                if hasattr(weight, "grad_added_to_main_grad"):
                    # When overlap_grad_reduce is True, need to ensure that backward hooks
                    # are all run on the main backprop thread to prevent deadlocks. Set up
                    # a dummy grad_weight tensor to prevent backward hooks from being run
                    # in a background thread.
                    if getattr(weight, "zero_out_wgrad", False):
                        grad_weight = torch.zeros(
                            weight.main_grad.shape,
                            dtype=input.dtype,
                            device=constants.current_device(),
                            requires_grad=False,
                        )
                    else:
                        grad_weight = torch.empty(
                            weight.main_grad.shape,
                            dtype=input.dtype,
                            device=constants.current_device(),
                            requires_grad=False,
                        )
                    weight.grad_added_to_main_grad = True
                else:
                    grad_weight = None
                gws.append(grad_weight)
        else:
            gws = []
            for g in grads:
                gws.append(g.t().matmul(total_input))
        gbs = [g.sum(dim=0) if use_bias[i] else None for i, g in enumerate(grads)]

        if ctx.sequence_parallel:
            handle.wait()
            return (
                sub_grad_input,
                None,
                None,
                None,
                None,
                *list(itertools.chain.from_iterable(zip(gws, gbs))),
            )

        if ctx.async_grad_allreduce:
            handle.wait()

        return (
            grad_input,
            None,
            None,
            None,
            None,
            *list(itertools.chain.from_iterable(zip(gws, gbs))),
        )


def merged_linear_with_grad_accumulation_and_async_allreduce(
    input: torch.Tensor,
    gradient_accumulation_fusion: bool,
    async_grad_allreduce: bool,
    sequence_parallel: bool,
    is_w_parallel: List[bool],
    *wbs: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, ...]:
    """Similar to linear_with_grad_accumulation_and_async_allreduce, but does
    multiple linear-layer forward/backward calls with a single all-gather
    operation. `wbs` is a flat list of interleaved (weight, bias) pairs."""
    args = [
        input,
        gradient_accumulation_fusion,
        async_grad_allreduce,
        sequence_parallel,
        is_w_parallel,
        *wbs,
    ]

    if not merged_linear_with_grad_accumulation_and_async_allreduce.warned:
        if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") != "1":
            if sequence_parallel:
                warnings.warn(
                    "When using sequence parallelism it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                merged_linear_with_grad_accumulation_and_async_allreduce.warned = True

            if async_grad_allreduce:
                warnings.warn(
                    "When using async grad allreduce it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                merged_linear_with_grad_accumulation_and_async_allreduce.warned = True

    return MergedLinearWithGradAccumulationAndAsyncCommunication.apply(*args)


merged_linear_with_grad_accumulation_and_async_allreduce.warned = False


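# Illustrative sketch (assumption: single process, hypothetical shapes; not
# part of the library API): the point of the merged op is that several
# projections sharing one input -- e.g. Q, K, and V -- reuse a single
# (all-gathered) activation instead of gathering it once per matmul.
# Functionally, each output is an independent linear layer.
def _example_merged_linear_math() -> bool:
    x = torch.randn(6, 8)  # shared input, e.g. one all-gathered activation
    wq, wk, wv = (torch.randn(4, 8) for _ in range(3))
    merged = [x @ w.t() for w in (wq, wk, wv)]  # one input read, three matmuls
    separate = [F.linear(x, w) for w in (wq, wk, wv)]
    return all(torch.allclose(a, b) for a, b in zip(merged, separate))

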
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.

    Keyword Arguments:
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs; otherwise, every GPU will have its own
                       output, which is Y_i = XA_i.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other elementwise operations; we
                       skip adding bias but instead return it.
        is_expert: If true, the layer is used as an expert MLP inside an MoE
                   block and never applies sequence parallelism itself.
        perform_initialization: Whether to perform initialization.
        gradient_accumulation_fusion: Whether to enable gradient accumulation
                                      fusion.

    Whether sequence parallelism is applied is read from
    constants.sequence_parallel() at forward time rather than passed as an
    argument.
    """

    def __init__(
        self,
        input_size,
        output_size,
        bias=True,
        gather_output=False,
        init_method=init.xavier_normal_,
        stride=1,
        skip_bias_add=False,
        is_expert=False,
        perform_initialization=True,
        gradient_accumulation_fusion=False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters.
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = constants.tensor_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.is_expert = is_expert
        assert skip_bias_add is False

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        self.weight = Parameter(
            torch.empty(
                self.output_size_per_partition,
                self.input_size,
                device=device,
                dtype=dtype,
            )
        )
        if perform_initialization:
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=0, stride=stride
            )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, device=device, dtype=dtype)
            )
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

        if gradient_accumulation_fusion:
            if not _grad_accum_fusion_available:
                raise RuntimeError(
                    "ColumnParallelLinear was called with gradient_accumulation_fusion set "
                    "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
                    "module is not found. To use gradient_accumulation_fusion you must "
                    "install APEX with --cpp_ext and --cuda_ext. For example: "
                    'pip install --global-option="--cpp_ext" --global-option="--cuda_ext ." '
                    "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
                    "gradient accumulation fusion."
                )
        self.gradient_accumulation_fusion = gradient_accumulation_fusion

    def forward(self, input_) -> torch.Tensor:
        """Forward of ColumnParallelLinear.

        Args:
            input_: 3D tensor whose order of dimensions is [sequence, batch, hidden].

        Returns:
            output: the result of the linear layer, partition-local unless
            gather_output is True.
        """
        bias = self.bias if not self.skip_bias_add else None
        # NOTE: When sequence_parallel is enabled in MoE models, the gather and scatter of
        # sequence parallel are done in the MoE token dispatcher before and after permutation.
        # Therefore, when used as experts, ColumnParallelLinear and RowParallelLinear
        # in expert MLPs always behave as if sequence parallel is not enabled.
        sequence_parallel = constants.sequence_parallel() and not self.is_expert
        async_tensor_model_parallel_allreduce = (
            constants.tensor_parallel_world_size() > 1 and not sequence_parallel
        )

        if sequence_parallel:
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        if not self.weight.requires_grad:
            forward_impl = linear_with_frozen_weight
        else:
            forward_impl = linear_with_grad_accumulation_and_async_allreduce
        output_parallel = forward_impl(
            input=input_parallel,
            weight=self.weight,
            bias=bias,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=async_tensor_model_parallel_allreduce,
            sequence_parallel=sequence_parallel,
        )
        if self.gather_output:
            # All-gather across the partitions.
            assert not sequence_parallel
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output


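# Illustrative sketch (assumption: single process, shards simulated in a loop;
# not part of the library API): splitting the weight along its output
# dimension and concatenating the per-shard results reproduces the full linear
# layer -- which is what gather_from_tensor_model_parallel_region does when
# gather_output=True.
def _example_column_parallel_math() -> bool:
    x = torch.randn(5, 8)
    full_w = torch.randn(6, 8)  # [output_size, input_size]
    shards = torch.chunk(full_w, 2, dim=0)  # each rank holds one block of A
    partial = [F.linear(x, w_i) for w_i in shards]  # Y_i = X A_i
    return torch.allclose(torch.cat(partial, dim=-1), F.linear(x, full_w))

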
class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:

                -   -
               | A_1 |
               | .   |
           A = | .   |        X = [X_1, ..., X_p]
               | .   |
               | A_p |
                -   -

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.

    Keyword Arguments:
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other elementwise operations; we
                       skip adding bias but instead return it.
        is_expert: If true, the layer is used as an expert MLP inside an MoE
                   block and never applies sequence parallelism itself.
        perform_initialization: Whether to perform initialization.
        gradient_accumulation_fusion: Whether to enable gradient accumulation
                                      fusion.

    Whether sequence parallelism is applied is read from
    constants.sequence_parallel() at forward time rather than passed as an
    argument.
    """

    def __init__(
        self,
        input_size,
        output_size,
        bias=True,
        input_is_parallel=True,
        init_method=init.xavier_normal_,
        stride=1,
        skip_bias_add=False,
        is_expert=False,
        perform_initialization=True,
        gradient_accumulation_fusion=False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters.
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = constants.tensor_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.gradient_accumulation_fusion = gradient_accumulation_fusion
        self.is_expert = is_expert

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        self.weight = Parameter(
            torch.empty(
                self.output_size,
                self.input_size_per_partition,
                device=device,
                dtype=dtype,
            )
        )
        if perform_initialization:
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=1, stride=stride
            )
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, device=device, dtype=dtype)
            )
            setattr(self.bias, "sequence_parallel", False)

            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

    def forward(self, input_) -> torch.Tensor:
        """Forward of RowParallelLinear.

        Args:
            input_: 3D tensor whose order of dimensions is [sequence, batch, hidden].

        Returns:
            output: the result of the linear layer after the cross-partition
            reduction, with bias added if present.
        """
        # NOTE: ColumnParallelLinear and RowParallelLinear in expert MLPs always behave
        # as if sequence parallel is not enabled. See ColumnParallelLinear for details.
        sequence_parallel = constants.sequence_parallel() and not self.is_expert
        if sequence_parallel and not self.input_is_parallel:
            raise RuntimeError(
                "To enable `sequence_parallel`, `input_is_parallel` must be `True`"
            )

        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        if not self.weight.requires_grad:
            _forward_impl = linear_with_frozen_weight
        else:
            _forward_impl = linear_with_grad_accumulation_and_async_allreduce
        output_parallel = _forward_impl(
            input=input_parallel,
            weight=self.weight,
            bias=None,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=False,
            # Always False: any sequence-parallel reduce-scatter happens on the
            # output below, not inside the matmul.
            sequence_parallel=False,
        )

        # All-reduce (or reduce-scatter) across all the partitions.
        if sequence_parallel:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        output = output_ + self.bias if self.bias is not None else output_
        return output


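# Illustrative sketch (assumption: single process, shards simulated in a loop;
# not part of the library API): with the weight split along its input
# dimension and the input split along its feature dimension, the per-shard
# partial products sum to the full linear output -- which is what
# reduce_from_tensor_model_parallel_region does across TP ranks.
def _example_row_parallel_math() -> bool:
    x = torch.randn(5, 8)
    full_w = torch.randn(6, 8)  # [output_size, input_size]
    w_shards = torch.chunk(full_w, 2, dim=1)  # each rank holds one block of A
    x_shards = torch.chunk(x, 2, dim=1)  # input already split: [X_1, ..., X_p]
    partial = [F.linear(x_i, w_i) for x_i, w_i in zip(x_shards, w_shards)]
    return torch.allclose(sum(partial), F.linear(x, full_w))

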
def parallel_lm_logits(
    input_: torch.Tensor,
    word_embeddings_weight: torch.Tensor,
    parallel_output: bool = False,
    gradient_accumulation_fusion: bool = False,
    bias=None,
):
    """LM logits using word embedding weights."""
    tensor_parallel = constants.tensor_parallel_world_size() > 1
    sequence_parallel = constants.sequence_parallel()
    async_grad_allreduce = not sequence_parallel and tensor_parallel
    # Parallel logits.
    if sequence_parallel:
        input_parallel = input_
    else:
        input_parallel = copy_to_tensor_model_parallel_region(input_)
        # copy_to_tensor_model_parallel_region already all-reduces the input
        # gradient in backward, so the async all-reduce must be disabled.
        async_grad_allreduce = False

    # Matrix multiply.
    logits_parallel = linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=word_embeddings_weight,
        bias=bias,
        gradient_accumulation_fusion=gradient_accumulation_fusion,
        async_grad_allreduce=async_grad_allreduce,
        sequence_parallel=sequence_parallel,
    )
    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return gather_from_tensor_model_parallel_region(logits_parallel)


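# Illustrative sketch (assumption: single process; not part of the library
# API): LM logits with tied word embeddings are just a linear projection by
# the embedding matrix transpose, logits = h @ E^T, optionally left sharded
# over the vocabulary dimension across TP ranks.
def _example_tied_lm_logits() -> torch.Tensor:
    hidden = torch.randn(5, 2, 16)  # [sequence, batch, hidden]
    word_embeddings_weight = torch.randn(100, 16)  # [vocab, hidden]
    logits = F.linear(hidden, word_embeddings_weight)  # [sequence, batch, vocab]
    return logits

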
class _VocabParallelCrossEntropy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(
            logits_max,
            op=torch.distributed.ReduceOp.MAX,
            group=constants.tensor_parallel_group(),
        )
        # Subtract the maximum value for numerical stability.
        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

        # Get the partition's vocab indices.
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = constants.tensor_parallel_rank()
        world_size = constants.tensor_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size
        )

        # Create a mask of valid vocab ids (1 means it needs to be masked).
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(
            start=0, end=logits_2d.size()[0], device=logits_2d.device
        )
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # All-reduce is needed to get the contributions from other GPUs.
        torch.distributed.all_reduce(
            predicted_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=constants.tensor_parallel_group(),
        )

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(
            sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=constants.tensor_parallel_group(),
        )

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize and optionally smooth logits.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

        vocab_size = exp_logits.size(-1)
        if label_smoothing > 0:
            """
            We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
            = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
            = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
            = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
            = ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
            = (1 - (alpha * K) / (K - 1)) * y_gt + ((alpha * K) / (K - 1)) * \sum_{i} y_i / K
            From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
            """
            assert 1.0 > label_smoothing > 0.0
            smoothing = label_smoothing * vocab_size / (vocab_size - 1)

            # Exp logits at this point are normalized probabilities, so we can
            # just take the log to get log-probs.
            log_probs = torch.log(exp_logits)
            mean_log_probs = log_probs.mean(dim=-1)
            loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs

        ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size

        # Store softmax, target-mask and masked-target for backward pass.
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors
        label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size

        # All the inputs have softmax as their gradient.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Subtract the gradient contribution from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)

        softmax_update = 1.0 - target_mask.view(-1).float()

        if label_smoothing > 0:
            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
            grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
            average_grad = 1 / vocab_size
            grad_2d[arange_1d, :] -= smoothing * average_grad
        else:
            grad_2d[arange_1d, masked_target_1d] -= softmax_update

        # Finally, elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None, None


def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
    """Performs cross entropy loss when logits are split across tensor parallel
    ranks.

    Arguments:
        vocab_parallel_logits: logits split across tensor parallel ranks;
            dimension is [sequence_length, batch_size, vocab_size / tp_world_size].

        target: correct vocab ids of dimension [sequence_length, micro_batch_size].

        label_smoothing: smoothing factor, must be in range [0.0, 1.0);
            default is no smoothing (=0.0).
    """
    return _VocabParallelCrossEntropy.apply(
        vocab_parallel_logits, target, label_smoothing
    )
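

# Illustrative sketch (assumption: single process, no vocabulary sharding; not
# part of the library API): the math implemented above -- subtract the max for
# stability, pick the target logit, and form log(sum(exp(logits))) minus
# logit[target] -- matches torch's standard cross entropy when the whole
# vocabulary lives on one rank.
def _example_vocab_parallel_ce_math() -> bool:
    logits = torch.randn(7, 32)  # [tokens, vocab]
    target = torch.randint(0, 32, (7,))
    shifted = logits - logits.max(dim=-1, keepdim=True).values
    predicted = shifted[torch.arange(7), target]
    loss = torch.log(torch.exp(shifted).sum(dim=-1)) - predicted
    ref = F.cross_entropy(logits, target, reduction="none")
    return torch.allclose(loss, ref, atol=1e-5)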