# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import contextlib
import dataclasses
import functools
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed
import torch.nn as nn
import torch.utils.checkpoint
import transformers

from realhf.api.core import model_api
from realhf.api.core.config import ModelName
from realhf.base import constants, logging, topology
from realhf.base.monitor import CUDATimeMarkType, cuda_tmark, cuda_tmarked
from realhf.impl.model.comm.global_comm import NCCLProcessGroupInfo
from realhf.impl.model.comm.param_realloc import (
    ReparallelizeReceiverStep,
    ReparallelizeSenderStep,
    ReparallelizeTraget,
    _derive_reparallelize_comm_plan,
    is_trainable,
)
from realhf.impl.model.nn.flatten_param import set_intervals, slice_intervals
from realhf.impl.model.utils.padding import pad_input, unpad_input

from .flatten_param import build_param_spec, map_param_to_contigous_memory
from .real_llm_base import (
    OutputHead,
    ParallelActorHead,
    PipeCacheData,
    PipeTransferData,
    ReaLModelBlock,
    SequenceParallelCriticHead,
    VocabPositionEmbedding,
)
from .real_llm_generate import generate
from .real_llm_parallel import partition_pipeline_layers

logger = logging.getLogger("ReaLModel Interface")


def chunked_bcast(x, src, group, chunk_size_bytes=1024 * 1024**2):
    assert len(x.shape) == 1
    n_chunks = (x.numel() * x.dtype.itemsize + chunk_size_bytes - 1) // chunk_size_bytes
    chunk_size = chunk_size_bytes // x.dtype.itemsize
    for i in range(n_chunks):
        if isinstance(x, torch.Tensor):
            torch.distributed.broadcast(
                x[i * chunk_size : (i + 1) * chunk_size], src=src, group=group
            )
        else:
            assert isinstance(x, list) and len(x) == n_chunks
            torch.distributed.broadcast(x[i], src=src, group=group)

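# Why chunked broadcast (illustrative): broadcasting a multi-GB flat parameter
# tensor in a single NCCL call can spike communicator memory, so the 1D tensor
# is sent in roughly 1 GiB slices instead. A minimal usage sketch, assuming
# ``flat`` is a 1D CUDA tensor and ``group`` an initialized process group:
#
#   chunked_bcast(flat, src=0, group=group)  # same result as one big broadcast

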
@dataclasses.dataclass
class DuckModelOutput:
    logits: Optional[Union[List[torch.Tensor], torch.Tensor]] = None


@dataclasses.dataclass
class DuckGenerationOutput:
    sequences: torch.Tensor
    scores: Optional[torch.Tensor] = None
    logits_mask: Optional[torch.Tensor] = None


def _sync_embedding_and_output_weights(layers: nn.ModuleList):
    pp_size = constants.pipe_parallel_world_size()
    pp_rank = constants.pipe_parallel_rank()
    if pp_size == 1:
        old_head_w = layers[-1].weight.data
        layers[-1].weight = layers[0].wte.weight
        del old_head_w
        layers[0].wte.weight.zero_out_wgrad = True
        return

    if pp_rank != 0 and pp_rank != pp_size - 1:
        return

    if pp_rank == 0:
        weight = layers[0].wte.weight
        weight.shared_embedding = True
    else:
        weight = layers[-1].weight
        weight.data.fill_(0.0)
        # To make Megatron happy
        weight.shared = True
        weight.shared_embedding = True

    group = constants.grid().embedding_proc_group
    torch.distributed.all_reduce(weight.data, group=group)

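# Weight-tying sketch for ``_sync_embedding_and_output_weights`` above
# (illustrative): with a single pipeline stage, the output head simply aliases
# the input embedding,
#
#   layers[-1].weight is layers[0].wte.weight   # True after the call
#
# whereas with pipeline parallelism the first and last stages hold separate
# copies that are kept identical via the all-reduce over the embedding group.

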
class ReaLModel(nn.Module):
    """The transformer model used in ReaL.

    This model supports 3D parallelism, offloaded inference, and parameter
    reallocation. It is usually more efficient than HuggingFace implementations.

    During construction, model parameters are not instantiated immediately
    because the model may be redistributed across devices. Call ``instantiate``
    before using the parameters, e.g., before ``forward`` or ``state_dict``.

    :param config: The model configuration.
    :type config: model_api.ReaLModelConfig
    :param dtype: The data type of the model.
    :type dtype: Optional[torch.dtype], optional
    :param device: The device of the model.
    """

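    # Typical lifecycle (illustrative sketch; assumes parallelism constants have
    # already been set up by the caller, e.g. inside ``constants.model_scope``):
    #
    #   model = ReaLModel(cfg, dtype=torch.bfloat16, device="cuda")
    #   model._instantiation_hooks.append(
    #       lambda: model.load_from_hf(load_dir, init_critic_from_actor=False)
    #   )
    #   model.instantiate()           # builds layers and the flat parameter buffer
    #   x, ys = model.forward(x, ys)  # packed inputs; see ``forward`` below
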
    def __init__(
        self,
        config: model_api.ReaLModelConfig,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
        hf_model_family: Optional[str] = None,
    ):
        super().__init__()
        if dtype is None:
            dtype = torch.float16
        self.config = config
        self.dtype = dtype
        self.device = device

        # The main attribute of the model: layers,
        # including the embedding layer, decoder layers, and the output head.
        self.layer_mapping = partition_pipeline_layers(
            config,
            constants.pipe_parallel_world_size(),
        )
        self.layer_idx_start = self.layer_mapping[constants.pipe_parallel_rank()][0]
        self.layer_idx_end = self.layer_mapping[constants.pipe_parallel_rank()][1]
        self.num_stages = constants.pipe_parallel_world_size()

        self.layers = nn.ModuleList()

        # The model is lazily instantiated due to parameter reallocation.
        # For models that will be redistributed, we instantiate replica 0
        # and do not instantiate other replicas.
        self._instantiated = False
        self._instantiation_hooks = []

        # Attributes used for parameter reallocation.
        self._reparallelize_targets: Dict[
            Tuple[ModelName, ModelName], ReparallelizeTraget
        ] = {}

        # Attributes used for offload.
        self._offload_buffer = None
        self._offloaded = False

        # Attributes used for flattening parameters.
        self.head_param_point_to_embedding = (
            self.config.tied_embedding
            and not self.config.is_critic
            and constants.pipe_parallel_world_size() == 1
        )
        self._param_spec, self._param_size = build_param_spec(
            list(range(self.layer_idx_start, self.layer_idx_end)),
            self.config,
            tp_size=constants.tensor_parallel_world_size(),
            pp_size=constants.pipe_parallel_world_size(),
            dp_size=constants.data_parallel_world_size(),
            head_param_point_to_embedding=self.head_param_point_to_embedding,
        )
        self.contiguous_param = None

        self.hf_model_family = hf_model_family

    def save_to_hf(self, tokenizer, save_dir):
        return getattr(self, f"to_{self.hf_model_family}")(tokenizer, save_dir)

    def load_from_hf(self, load_dir, init_critic_from_actor):
        return getattr(self, f"from_{self.hf_model_family}")(
            load_dir, init_critic_from_actor
        )

    @property
    def pre_process(self):
        # A workaround to make the Megatron-LM backend happy.
        if constants.pipe_parallel_rank() == 0:
            return self.layers[0]
        elif constants.pipe_parallel_rank() == constants.pipe_parallel_world_size() - 1:
            return self.layers[-1]
        return None

    @property
    def post_process(self):
        # A workaround to make the Megatron-LM backend happy.
        if constants.pipe_parallel_rank() == constants.pipe_parallel_world_size() - 1:
            return self.layers[-1]
        return None

    def shared_embedding_or_output_weight(self) -> None | torch.Tensor:
        # NOTE: This name is kept consistent with Megatron-LM.
        if not self.config.tied_embedding or self.config.is_critic:
            return None
        if constants.is_first_pipe_stage():
            return self.layers[0].wte.weight
        elif constants.is_last_pipe_stage():
            return self.layers[-1].weight
        return None

    def instantiate(self):
        """Instantiate the model parameters.

        Note that users can append hooks to this method to do more
        processing, such as loading from HuggingFace models.
        """
        assert not self._instantiated or self.contiguous_param.numel() == 0
        layers = []
        for idx in range(self.layer_idx_start, self.layer_idx_end):
            layers.append(self._build_layer(idx, self.config))
        self.layers = nn.ModuleList(layers)

        if self.config.tied_embedding and not self.config.is_critic:
            _sync_embedding_and_output_weights(self.layers)

        self.contiguous_param = torch.empty(
            self._param_size, dtype=self.dtype, device=self.device
        )
        map_param_to_contigous_memory(
            self.layers,
            self.config,
            self.head_param_point_to_embedding,
            self._param_spec,
            self.contiguous_param,
            self.layer_idx_start,
            allocate_only=False,
        )

        for h in self._instantiation_hooks:
            h()

        self._instantiated = True
        self._instantiation_hooks = []

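    # After ``instantiate``, every local parameter is a view into
    # ``self.contiguous_param`` (illustrative sanity check, single process):
    #
    #   n_params = sum(p.numel() for p in model.layers.parameters())
    #   assert model.contiguous_param.numel() >= n_params
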
    @property
    def num_layers(self):
        """Return the number of embedding or transformer layers in this
        pipeline stage."""
        return self.layer_idx_end - self.layer_idx_start

    @property
    def is_critic(self):
        return self.config.is_critic

    def _build_layer(self, idx: int, config: model_api.ReaLModelConfig) -> nn.Module:
        dtype = self.dtype
        device = self.device
        if idx == 0:
            l = VocabPositionEmbedding(config, dtype=dtype, device=device)
        elif idx == config.n_layers + 1:
            l = self._build_output_head(config)
        else:
            l = ReaLModelBlock(
                config=config,
                layer_index=idx - 1,
                output_layernorm=(idx == config.n_layers),
                dtype=dtype,
                device=device,
            )
        return l

    def _build_output_head(self, config: model_api.ReaLModelConfig) -> nn.Module:
        dtype = self.dtype
        device = self.device
        if config.is_critic and constants.sequence_parallel():
            l = SequenceParallelCriticHead(
                config.hidden_dim,
                1,
                bias=False,
                device=device,
                dtype=dtype,
            )
        elif not config.is_critic and constants.tensor_parallel_world_size() > 1:
            l = ParallelActorHead(
                config.hidden_dim,
                config.vocab_size,
                norm_head=self.config.norm_head,
                norm_softmax=self.config.norm_softmax,
                bias=False,
                gradient_accumulation_fusion=constants.gradient_accumulation_fusion(),
                device=device,
                dtype=dtype,
            )
        else:
            l = OutputHead(
                config.hidden_dim,
                1 if config.is_critic else config.vocab_size,
                bias=False,
                norm_head=self.config.norm_head,
                norm_softmax=self.config.norm_softmax,
                device=device,
                dtype=dtype,
            )
        return l

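    # Layer index convention used by ``_build_layer`` (for reference):
    #   idx == 0              -> VocabPositionEmbedding
    #   1 <= idx <= n_layers  -> ReaLModelBlock for transformer layer ``idx - 1``
    #   idx == n_layers + 1   -> output head (actor logits or critic scalar)
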
    def async_offload(self):
        """Trigger offload asynchronously."""
        if not constants.use_cuda():
            return
        assert not self._offloaded
        assert self._instantiated
        if self._offload_buffer is None:
            self._offload_buffer = torch.empty_like(
                self.contiguous_param,
                dtype=self.dtype,
                device="cpu",
                pin_memory=True,
            )
        else:
            assert self._offload_buffer.shape == self.contiguous_param.shape
        dummy_tensor = torch.tensor((), device=self.device, dtype=self.dtype)
        self._offload_stream = torch.cuda.Stream()
        self._offload_event = torch.cuda.Event()
        self.contiguous_param = None
        for i, l in enumerate(self.layers):
            layer_idx = self.layer_idx_start + i
            with torch.cuda.stream(self._offload_stream):
                for k, p in l.named_parameters():
                    spec = self._param_spec[f"{layer_idx}.{k}"]
                    if (
                        self.head_param_point_to_embedding
                        and layer_idx == self.config.n_layers + 1
                    ):
                        continue
                    self._offload_buffer[spec.start_idx : spec.end_idx].copy_(
                        p.data.view(-1), non_blocking=True
                    )
                    p.data = dummy_tensor
        self._offload_event.record(self._offload_stream)
        self._offloaded = True

    def wait_for_offload(self):
        """Wait for offload to finish."""
        if not constants.use_cuda():
            return
        assert self._offloaded
        torch.cuda.current_stream().wait_event(self._offload_event)

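    # Offload usage sketch (illustrative; requires CUDA and an instantiated model):
    #
    #   model.async_offload()      # copy params to pinned CPU memory on a side stream
    #   ...                        # do other GPU work while the copy is in flight
    #   model.wait_for_offload()   # block the current stream until the copy finishes
    #
    # The next ``forward`` call notices ``self._offloaded`` and reloads parameters
    # layer by layer, overlapping host-to-device copies with computation.
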
    def __overlapped_load_forward(
        self, x: PipeTransferData, ys: List[PipeCacheData]
    ) -> Tuple[PipeTransferData, List[PipeCacheData]]:
        assert len(ys) == self.num_layers
        raw_pp_input = x.pp_input
        self.contiguous_param = torch.empty(
            self._param_size, dtype=self.dtype, device=self.device
        )
        map_param_to_contigous_memory(
            self.layers,
            self.config,
            self.head_param_point_to_embedding,
            self._param_spec,
            self.contiguous_param,
            self.layer_idx_start,
            allocate_only=True,
        )
        self.wait_for_offload()

        stream = torch.cuda.Stream()
        events: List[torch.cuda.Event] = [
            torch.cuda.Event() for _ in range(self.num_layers)
        ]
        with torch.cuda.stream(stream):
            for layer_idx, y, l, e in zip(
                range(self.layer_idx_start, self.layer_idx_end),
                ys,
                self.layers,
                events,
            ):
                # NOTE: although we can do more fine-grained overlapping, the overhead
                # that can be reduced is very small (~50ms), which is unnecessary for now.
                for k, v in l.named_parameters():
                    spec = self._param_spec[f"{layer_idx}.{k}"]
                    v.data.copy_(
                        self._offload_buffer[spec.start_idx : spec.end_idx].view(
                            spec.shape
                        ),
                        non_blocking=True,
                    )
                e: torch.cuda.Event
                e.record(stream)

        for layer_idx, y, l, e in zip(
            range(self.layer_idx_start, self.layer_idx_end),
            ys,
            self.layers,
            events,
        ):
            torch.cuda.default_stream().wait_event(e)
            x = l(x, y)
            x.pp_input = x.pp_output
        self._offloaded = False
        x.pp_input = raw_pp_input
        return x, ys

    def __forward(
        self, x: PipeTransferData, ys: List[PipeCacheData]
    ) -> Tuple[PipeTransferData, List[PipeCacheData]]:
        layers = self.layers
        assert len(ys) == len(layers), (len(ys), len(layers))
        raw_pp_input = x.pp_input
        for i, (layer, y) in enumerate(zip(layers, ys)):
            x = layer(x, y)
            x.pp_input = x.pp_output
        # Finally, pp_input is the input of this pipeline stage (maybe across several layers),
        # pp_output is the output of this pipeline stage.
        # In the first stage, pp_input is None.
        x.pp_input = raw_pp_input
        return x, ys

    def forward(
        self, x: PipeTransferData, ys: List[PipeCacheData]
    ) -> Tuple[PipeTransferData, List[PipeCacheData]]:
        if x.max_seqlen is not None and not isinstance(x.max_seqlen, int):
            x.max_seqlen = int(x.max_seqlen)
        if x.cu_seqlens is not None and not isinstance(x.cu_seqlens, torch.IntTensor):
            x.cu_seqlens = x.cu_seqlens.int()

        # Pad the packed inputs to a multiple of the TP world size,
        # as required by sequence parallelism.
        tp_size = constants.tensor_parallel_world_size()
        batch_length = None
        if ys[0].packed_input_ids is not None:
            batch_length = ys[0].packed_input_ids.shape[0]
        if x.pp_input is not None:
            batch_length = x.pp_input.shape[0]
        assert batch_length is not None
        padded_batch_length = (batch_length + tp_size - 1) // tp_size * tp_size
        pad_size = padded_batch_length - batch_length

        if (
            constants.sequence_parallel()
            and pad_size > 0
            and ys[0].packed_input_ids is not None
        ):
            _cu_seqlens = x.cu_seqlens
            _max_seqlen = x.max_seqlen
            _input_ids = ys[0].packed_input_ids
            _pp_input = x.pp_input

            x.cu_seqlens = torch.nn.functional.pad(
                x.cu_seqlens, (0, 1), value=padded_batch_length
            )
            x.max_seqlen = max(x.max_seqlen, padded_batch_length - batch_length)
            if ys[0].packed_input_ids is not None:
                input_ids_buf = torch.zeros(
                    (padded_batch_length,),
                    dtype=torch.long,
                    device=self.device,
                )
                input_ids_buf[:batch_length] = ys[0].packed_input_ids
                ys[0].packed_input_ids = input_ids_buf

            if x.pp_input is not None:
                pp_input_buf = torch.zeros(
                    (padded_batch_length, *x.pp_input.shape[1:]),
                    dtype=x.pp_input.dtype,
                    device=self.device,
                )
                pp_input_buf[:batch_length] = x.pp_input
                x.pp_input = pp_input_buf

        tmark_type = CUDATimeMarkType.forward
        with cuda_tmarked("fwd", tmark_type):
            # Main forward calls.
            if not self._offloaded:
                x, ys = self.__forward(x, ys)
            else:
                x, ys = self.__overlapped_load_forward(x, ys)

        # Resume from padding.
        if (
            constants.sequence_parallel()
            and pad_size > 0
            and ys[0].packed_input_ids is not None
        ):
            x.pp_output = x.pp_output[:-pad_size]

            x.pp_input = _pp_input
            ys[0].packed_input_ids = _input_ids
            x.cu_seqlens = _cu_seqlens
            x.max_seqlen = _max_seqlen

            if x.store_kv_cache:
                for y in ys:
                    if y.k_cache is not None:
                        y.k_cache = y.k_cache[:-pad_size]
                    if y.v_cache is not None:
                        y.v_cache = y.v_cache[:-pad_size]

        # Release the memory used for TP gathering.
        constants.clear_global_memory_buffer()
        return x, ys

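    # Padding arithmetic example (illustrative): with ``tp_size == 4`` and a
    # packed batch of 1023 tokens, ``padded_batch_length == 1024`` and
    # ``pad_size == 1``; the extra token becomes a dummy trailing sequence
    # (one more ``cu_seqlens`` entry) and its outputs are stripped afterwards.
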
    def _forward(
        self,
        input_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        position_ids: torch.LongTensor,
        hidden_states: Optional[torch.Tensor],
        k_caches: Optional[List[torch.Tensor]],
        v_caches: Optional[List[torch.Tensor]],
        cache_seqlens: Optional[torch.IntTensor],
        max_seqlen: Optional[int],
    ):
        if k_caches is None:
            assert v_caches is None
            assert cache_seqlens is None
            k_caches = [None] * self.num_layers
            v_caches = [None] * self.num_layers

        h = hidden_states
        for idx, l in enumerate(self.layers):
            if isinstance(l, VocabPositionEmbedding):
                h = l._forward(input_ids, position_ids)
            elif isinstance(l, ReaLModelBlock):
                h, _, _ = l._forward(
                    h,
                    cu_seqlens=cu_seqlens,
                    k_cache=k_caches[idx],
                    v_cache=v_caches[idx],
                    cache_seqlens=cache_seqlens,
                    max_seqlen=max_seqlen,
                )
            elif isinstance(
                l,
                (
                    OutputHead,
                    SequenceParallelCriticHead,
                    ParallelActorHead,
                ),
            ):
                h = l._forward(h)
            else:
                raise NotImplementedError(f"Unsupported layer type {type(l)}")

        return h

    def state_dict(self, *args, **kwargs):
        """Map local layer indices to global layer indices."""
        state_dict = self.layers.state_dict(*args, **kwargs)
        new_state_dict = {}
        for k, v in state_dict.items():
            # NOTE: str.lstrip strips a character set rather than a prefix,
            # so use removeprefix to drop the optional wrappers safely.
            k = k.removeprefix("module.").removeprefix("layers.")
            local_idx = int(k.split(".")[0])
            name = k.split(".", 1)[1]
            new_state_dict[f"{local_idx + self.layer_idx_start}.{name}"] = v
        return new_state_dict

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        new_state_dict = {}
        for k, v in state_dict.items():
            name = k.split(".", 1)[1]
            global_idx = int(k.split(".")[0])
            new_state_dict[f"layers.{global_idx - self.layer_idx_start}.{name}"] = v
        return super().load_state_dict(
            new_state_dict,
            strict=strict,
            assign=assign,
        )

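    # Key remapping example (illustrative key name): with ``layer_idx_start == 8``,
    # the local key ``"0.attn.wq.weight"`` is saved as ``"8.attn.wq.weight"``;
    # ``load_state_dict`` shifts it back, so checkpoints stay independent of the
    # pipeline partitioning.
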
    def build_reparallelization_plan(
        self,
        from_model_name: ModelName,
        to_model_name: ModelName,
        from_topo: topology.ProcessTopology,
        to_topo: topology.ProcessTopology,
        to_model_config: model_api.ReaLModelConfig,
        pg_info: NCCLProcessGroupInfo,
        from_model_config: None | model_api.ReaLModelConfig = None,
    ):
        if from_model_config is None:
            from_model_config = self.config
        to_layer_mapping = partition_pipeline_layers(
            to_model_config,
            to_topo.get_dim("pipe"),
        )
        to_layers_handle_dict = {}
        to_layer_indices = []
        if constants.has_model_name(to_model_name):
            with constants.model_scope(to_model_name):
                to_pp_rank = constants.pipe_parallel_rank()
                to_layer_indices = list(
                    range(
                        to_layer_mapping[to_pp_rank][0],
                        to_layer_mapping[to_pp_rank][1],
                    )
                )
                for _to_layer_idx in to_layer_indices:
                    l = self._build_layer(_to_layer_idx, to_model_config)
                    for v in l.parameters():
                        v.data = torch.tensor((), dtype=self.dtype, device=self.device)
                    to_layers_handle_dict[_to_layer_idx] = l
        to_model_head_param_point_to_embedding = (
            to_model_config.tied_embedding
            and not to_model_config.is_critic
            and to_topo.get_dim("pipe") == 1
        )
        to_param_spec, to_param_size = build_param_spec(
            to_layer_indices,
            to_model_config,
            tp_size=to_topo.get_dim("tensor"),
            dp_size=to_topo.get_dim("data"),
            pp_size=to_topo.get_dim("pipe"),
            head_param_point_to_embedding=to_model_head_param_point_to_embedding,
        )
        if len(to_layer_indices) > 0:
            to_layer_idx_start = min(to_layer_indices)
            to_layer_idx_end = max(to_layer_indices) + 1
        else:
            to_layer_idx_start = to_layer_idx_end = -1
        to_layers_handle = nn.ModuleList(
            [to_layers_handle_dict[i] for i in to_layer_indices]
        )

        comm_plan = _derive_reparallelize_comm_plan(
            from_model_name=from_model_name,
            to_model_name=to_model_name,
            from_topo=from_topo,
            to_topo=to_topo,
            from_model_config=from_model_config,
            to_model_config=to_model_config,
            pg_info=pg_info,
            dtype=self.dtype,
        )
        rtgt = ReparallelizeTraget(
            comm_plan=comm_plan,
            to_param_spec=to_param_spec,
            to_param_size=to_param_size,
            to_layers_handle=to_layers_handle,
            to_layer_start_idx=to_layer_idx_start,
            to_layer_end_idx=to_layer_idx_end,
        )
        self._reparallelize_targets[(from_model_name, to_model_name)] = rtgt

    # FIXME: we can get topo given model name from constants
    @cuda_tmark("param_realloc", CUDATimeMarkType.mem_layout)
    def build_reparallelized_layers_async(
        self,
        from_model_name: ModelName,
        to_model_name: ModelName,
        from_topo: topology.ProcessTopology,
        to_topo: topology.ProcessTopology,
        to_model_config: model_api.ReaLModelConfig,
        pg_info: NCCLProcessGroupInfo,
    ) -> Tuple[nn.ModuleList, torch.Tensor, torch.Tensor]:
        """Trigger the parameter reallocation from the source model to the
        target model."""

        assert not (is_trainable(from_model_name) and is_trainable(to_model_name))
        assert is_trainable(from_model_name) or is_trainable(to_model_name)

        if (from_model_name, to_model_name) not in self._reparallelize_targets:
            self.build_reparallelization_plan(
                from_model_name,
                to_model_name,
                from_topo,
                to_topo,
                to_model_config,
                pg_info,
            )
        rtgt = self._reparallelize_targets[(from_model_name, to_model_name)]

        # Since the default implementation of PyTorch optimizers holds
        # the reference of trainable parameters, we cannot deallocate
        # them even after parameter reallocation. Therefore, there is no
        # need to release and re-allocate the trainable parameters back-and-forth.
        # We simply store the layer handles and fetch them when converting back.
        with constants.model_scope(from_model_name):
            from_model_ranks = constants.parallelism_group_ranks()
        if not is_trainable(from_model_name):
            if torch.distributed.get_rank() in from_model_ranks:
                dummy_tensor = torch.tensor((), dtype=self.dtype, device=self.device)
                for p in self.layers.parameters():
                    p.data = dummy_tensor
                self.contiguous_param = dummy_tensor
            return None, None, 0.0

        # The following tensor holds the contiguous memory of incoming parameters.
        # If this process is not a receiver, to_param_size is 0 and it's an empty tensor.
        to_contiguous_param = torch.zeros(
            rtgt.to_param_size,
            dtype=self.dtype,
            device=constants.current_device(),
        )
        to_model_head_param_point_to_embedding = (
            to_model_config.tied_embedding
            and not to_model_config.is_critic
            and to_topo.get_dim("pipe") == 1
        )
        map_param_to_contigous_memory(
            rtgt.to_layers_handle,
            to_model_config,
            to_model_head_param_point_to_embedding,
            rtgt.to_param_spec,
            to_contiguous_param,
            rtgt.to_layer_start_idx,
            allocate_only=True,
        )

        # Allocate tensors in advance to reduce overhead.
        recv_buf_specs = []
        send_buf_specs = []
        comm_volume = torch.zeros(
            (), dtype=torch.long, device=constants.current_device()
        )
        for step in rtgt.comm_plan:
            if (
                isinstance(step, ReparallelizeReceiverStep)
                and step.rank == torch.distributed.get_rank()
            ):
                if step.rank == step.src:
                    # TODO: we can develop a kernel to directly move the
                    # memory without creating an intermediate buffer.
                    buf = slice_intervals(
                        self.contiguous_param,
                        step.sender_param_intervals_cpu,
                        intervals_cuda=step.sender_param_intervals_cuda,
                        max_interval_size=step.sender_max_interval_size,
                        output_size=step.param_size,
                    )
                else:
                    buf = torch.zeros(
                        step.param_size,
                        dtype=step.param_dtype,
                        device=constants.current_device(),
                    )
                    comm_volume += buf.numel()

                recv_buf_specs.append(
                    dict(
                        src=buf,
                        dst=to_contiguous_param,
                        intervals_cpu=step.receiver_param_intervals_cpu,
                        intervals_cuda=step.receiver_param_intervals_cuda,
                        max_interval_size=step.receiver_max_interval_size,
                    )
                )

            if (
                isinstance(step, ReparallelizeSenderStep)
                and step.rank == torch.distributed.get_rank()
            ):
                if step.group is not None:
                    buf = slice_intervals(
                        self.contiguous_param,
                        step.param_intervals_cpu,
                        intervals_cuda=step.param_intervals_cuda,
                        max_interval_size=step.max_interval_size,
                        output_size=step.param_size,
                    )
                    send_buf_specs.append(buf)

        # Run broadcast!
        recv_buf_cnt = 0
        recv_events = []
        for step in rtgt.comm_plan:
            if constants.use_cuda():
                s = torch.cuda.Stream()
                ctx = torch.cuda.stream(s)
            else:
                ctx = contextlib.nullcontext()
            with ctx:
                if (
                    isinstance(step, ReparallelizeReceiverStep)
                    and step.rank == torch.distributed.get_rank()
                ):
                    if constants.use_cuda():
                        e = torch.cuda.Event()
                    else:
                        e = None
                    if step.rank != step.src:
                        buf = recv_buf_specs[recv_buf_cnt]["src"]
                        chunked_bcast(buf, src=step.src, group=step.group)
                    if constants.use_cuda():
                        e.record(s)
                    recv_events.append(e)
                    recv_buf_cnt += 1

                if (
                    isinstance(step, ReparallelizeSenderStep)
                    and step.rank == torch.distributed.get_rank()
                ):
                    if step.group is not None:
                        buf = send_buf_specs.pop(0)
                        chunked_bcast(buf, src=step.rank, group=step.group)

        # Post-processing.
        assert len(send_buf_specs) == 0, len(send_buf_specs)
        assert recv_buf_cnt == len(recv_buf_specs), (
            len(recv_buf_specs),
            recv_buf_cnt,
        )
        # assert len(state_dict) == 0
        assert len(recv_events) == len(recv_buf_specs)
        for e, x in zip(recv_events, recv_buf_specs):
            if constants.use_cuda():
                torch.cuda.current_stream().wait_event(e)
            set_intervals(**x)

        return rtgt.to_layers_handle, to_contiguous_param, comm_volume

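    # Reallocation flow sketch (illustrative; variable names are placeholders
    # and both model wrappers plus the NCCL group info come from the caller):
    #
    #   handles, flat_param, comm_volume = trainable_model.build_reparallelized_layers_async(
    #       from_model_name, to_model_name, from_topo, to_topo, to_config, pg_info
    #   )
    #   inference_model.patch_reparallelization((handles, flat_param), eta=1.0)
    #
    # ``patch_reparallelization`` with ``eta < 1.0`` instead blends parameters as
    # an exponential moving average: new = eta * incoming + (1 - eta) * current.
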
    def patch_reparallelization(self, x, eta):
        if eta == 1.0:
            self.layers, self.contiguous_param = x
        else:
            new_layers, new_param = x
            self.contiguous_param = eta * new_param + (1 - eta) * self.contiguous_param
            map_param_to_contigous_memory(
                self.layers,
                self.config,
                self.head_param_point_to_embedding,
                param_spec=self._param_spec,
                contiguous_param=self.contiguous_param,
                layer_idx_offset=self.layer_idx_start,
                allocate_only=False,
            )
            dummy_tensor = torch.tensor((), dtype=self.dtype, device=self.device)
            for p in new_layers.parameters():
                p.data = dummy_tensor
        assert self.layers is not None
        assert self.contiguous_param is not None
        assert self.contiguous_param.shape[0] > 0
        for l in self.layers:
            for p in l.parameters():
                p.requires_grad_()


# a helper function to make real_model look like huggingface model
def generate_helper(
    self: ReaLModel,
    tokenizer: transformers.PreTrainedTokenizerFast,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    packed_input_ids: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    gconfig: Optional[model_api.GenerationHyperparameters] = None,
) -> DuckGenerationOutput:
    # NOTE: ``dataclasses.field`` is not evaluated for plain function defaults,
    # so construct the default hyperparameters explicitly here.
    if gconfig is None:
        gconfig = model_api.GenerationHyperparameters()
    assert (packed_input_ids is None) == (cu_seqlens is None) == (max_seqlen is None)
    if attention_mask is None and input_ids is not None:
        attention_mask = torch.ones_like(input_ids)
    if packed_input_ids is None and attention_mask is not None:
        packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
            input_ids, attention_mask
        )
    current_forward = self.forward
    self.forward = functools.partial(ReaLModel.forward, self)
    seq, scores, mask, _, _ = generate(
        model=self,
        tokenizer=tokenizer,
        packed_input_ids=packed_input_ids,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        gconfig=gconfig,
    )
    self.forward = current_forward
    return DuckGenerationOutput(seq, scores, mask)


# a helper function to make real_model look like huggingface model
def forward_helper(
    self: ReaLModel,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    packed_input_ids: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
) -> DuckModelOutput:
    assert (packed_input_ids is None) == (cu_seqlens is None) == (max_seqlen is None)
    build_packed = False
    if attention_mask is None and input_ids is not None:
        attention_mask = torch.ones_like(input_ids)
    if packed_input_ids is None and attention_mask is not None:
        build_packed = True
        packed_input_ids, indices, cu_seqlens, max_seqlen = unpad_input(
            input_ids, attention_mask
        )
        batch_size, seqlen = input_ids.shape[:2]
    x = PipeTransferData(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    ys = [PipeCacheData(packed_input_ids=packed_input_ids)] + [
        PipeCacheData() for _ in range(self.config.n_layers + 1)
    ]
    scores = ReaLModel.forward(self, x, ys)[0].pp_output
    if build_packed:
        scores = pad_input(scores, indices, batch_size, seqlen)
    return DuckModelOutput(logits=scores)


def add_helper_functions(m: ReaLModel):
    m.forward = functools.partial(forward_helper, m)
    m.generate = functools.partial(generate_helper, m)
    return m


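# HF-style usage sketch (illustrative; assumes a single pipeline stage so the
# helpers are attached, and that ``tok``, ``ids``, and ``mask`` come from the caller):
#
#   m = add_helper_functions(real_model)
#   out = m.forward(input_ids=ids, attention_mask=mask)        # DuckModelOutput
#   gen = m.generate(tok, input_ids=ids, attention_mask=mask)  # DuckGenerationOutput

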
def make_real_model(
    name: ModelName,
    device: torch.device,
    model_path: str,
    is_critic: bool,
    init_from_scratch: bool,
    init_critic_from_actor: bool,
    dtype: Optional[str] = None,
    hf_model_family: Optional[str] = None,
) -> model_api.Model:
    if dtype == "fp16" or dtype is None:
        dtype = torch.float16
    elif dtype == "bf16":
        dtype = torch.bfloat16
    elif dtype == "fp32":
        dtype = torch.float32
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")

    tokenizer = model_api.load_hf_tokenizer(model_path)
    mconfig = getattr(ReaLModel, f"config_from_{hf_model_family}")(
        model_path=model_path,
        is_critic=is_critic or init_critic_from_actor,
    )
    m = ReaLModel(mconfig, dtype=dtype, device=device, hf_model_family=hf_model_family)

    if not init_from_scratch:
        m._instantiation_hooks.append(
            lambda: getattr(m, f"from_{hf_model_family}")(
                load_dir=model_path, init_critic_from_actor=init_critic_from_actor
            )
        )

    if constants.pipe_parallel_world_size() == 1:
        m = add_helper_functions(m)
    return model_api.Model(name, m, tokenizer, device, dtype=dtype)


model_api.register_model("real_model", make_real_model)
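
# Factory usage sketch (illustrative; the path and family name are placeholders,
# and ``name`` is a ModelName taken from the experiment configuration):
#
#   model = make_real_model(
#       name, torch.device("cuda"), "/path/to/hf/checkpoint",
#       is_critic=False, init_from_scratch=False, init_critic_from_actor=False,
#       dtype="bf16", hf_model_family="llama",
#   )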