# Modified from Megatron-LM.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch
import torch.distributed

from realhf.base import constants

from .utils import split_tensor_along_last_dim

def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if constants.tensor_parallel_world_size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_, group=constants.tensor_parallel_group())

    return input_


def _split_along_last_dim(input_):
    """Split the tensor along its last dimension and keep the corresponding
    slice."""

    world_size = constants.tensor_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = constants.tensor_parallel_rank()
    output = input_list[rank].contiguous()

    return output
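
# Example: with a tensor parallel world size of 2 and an input of shape
# [seq, batch, hidden], each rank keeps a contiguous [seq, batch, hidden // 2]
# slice after the split above.

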
def _split_along_first_dim(input_):
    """Split the tensor along its first dimension and keep the corresponding
    slice."""

    world_size = constants.tensor_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along first dimension.
    dim_size = input_.size()[0]
    assert (
        dim_size % world_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"
    local_dim_size = dim_size // world_size
    rank = constants.tensor_parallel_rank()
    dim_offset = rank * local_dim_size

    output = input_[dim_offset : dim_offset + local_dim_size].contiguous()

    return output


def _gather_along_last_dim(input_):
    """Gather tensors and concatenate along the last dimension."""

    world_size = constants.tensor_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = constants.tensor_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(
        tensor_list, input_, group=constants.tensor_parallel_group()
    )

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output


def _gather_along_first_dim(input_):
    """Gather tensors and concatenate along the first dimension."""

    world_size = constants.tensor_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(
        dim_size, dtype=input_.dtype, device=constants.current_device()
    )
    torch.distributed._all_gather_base(
        output, input_.contiguous(), group=constants.tensor_parallel_group()
    )

    return output
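
# NOTE: `_all_gather_base` (above) and `_reduce_scatter_base` (below) are
# private torch.distributed APIs; recent PyTorch releases expose the same
# collectives publicly as `all_gather_into_tensor` and `reduce_scatter_tensor`.

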
def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter the input tensor across the model parallel group."""
    world_size = constants.tensor_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert (
        dim_size[0] % world_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"

    dim_size[0] = dim_size[0] // world_size

    # NOTE: We don't use in-place reduce-scatter because we don't want to
    # corrupt the activations which will be used during the backward pass.
    output = torch.empty(
        dim_size, dtype=input_.dtype, device=constants.current_device()
    )
    torch.distributed._reduce_scatter_base(
        output, input_.contiguous(), group=constants.tensor_parallel_group()
    )
    return output
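
# Sanity note (illustrative): for any tensor whose first dimension is divisible
# by the tensor parallel world size, reduce-scatter is equivalent to an
# all-reduce followed by taking this rank's first-dim slice, i.e.
# _reduce_scatter_along_first_dim(x) matches
# _split_along_first_dim(_reduce(x.clone())).

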
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
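
# NOTE: The two autograd functions above are conjugates: the copy is an
# identity in forward and all-reduces gradients in backward, while the reduce
# all-reduces in forward and passes gradients through unchanged. These are the
# `f` and `g` operators of the Megatron-LM tensor parallelism scheme.

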
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to this rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_last_dim(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split_along_last_dim(grad_output)


class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to this rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_, model_parallel_output_grad=True):
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, model_parallel_output_grad=True):
        ctx.model_parallel_output_grad = model_parallel_output_grad
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        model_parallel_output_grad = ctx.model_parallel_output_grad

        # If the computation graph after the gather operation runs in tensor
        # parallel mode, the output gradients need to be reduce-scattered;
        # whereas if the computation is duplicated across ranks, the output
        # gradients only need to be split (scattered).
        if model_parallel_output_grad:
            return _reduce_scatter_along_first_dim(grad_output), None
        else:
            return _split_along_first_dim(grad_output), None


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce-scatter the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


# -----------------
# Helper functions.
# -----------------


def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)


def scatter_to_sequence_parallel_region(input_):
    return _ScatterToSequenceParallelRegion.apply(input_)


def gather_from_sequence_parallel_region(input_, model_parallel_output_grad=True):
    return _GatherFromSequenceParallelRegion.apply(input_, model_parallel_output_grad)


def reduce_scatter_to_sequence_parallel_region(input_):
    return _ReduceScatterToSequenceParallelRegion.apply(input_)