# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
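"""Helpers for flattening ReaLModel parameters into a single contiguous 1-D buffer.

This module builds a per-parameter index specification (`ContiguousParamSpec`),
re-binds module parameters to views of the flat buffer, and provides
interval-based gather/scatter helpers (`slice_intervals` / `set_intervals`)
with optional C++/CUDA fast paths from `realhf._C`, used during parameter
reallocation.
"""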
import contextlib
import dataclasses
import math
import os
import subprocess
from typing import *
import numpy as np
import torch
import torch.utils.cpp_extension as torch_cpp_ext
from packaging.version import Version, parse
import realhf
from realhf.api.core import model_api
from realhf.api.core.config import ModelName
from realhf.base import constants, logging
from .real_llm_base import ReaLModelParamKeys
from .real_llm_parallel import (
get_real_model_param_shape,
intervals_partition_fn,
shape_partition_fn,
tp_partition_key,
)
try:
from realhf._C.interval_op import merge_intervals
except ImportError:
merge_intervals = None
logger = logging.getLogger("FlattenParam")
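# Caches zero-based flat-parameter intervals, keyed by
# (model_name, param_name, tp_size, portion_rank, portion_size).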
_FLAT_PARAM_INDICES_CACHE = {}
@dataclasses.dataclass
class ContiguousParamSpec:
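    """Slice of the flat buffer holding one parameter: `flat[start_idx:end_idx].view(shape)`."""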
start_idx: int
end_idx: int
shape: torch.Size
def _is_integer_list_contiguous(l: List[int]) -> bool:
return np.all(np.array(l) == np.arange(len(l)) + l[0])
def _are_intervals_contiguous(l: List[Tuple[int, int]]) -> bool:
l = sorted(l, key=lambda x: x[0])
res = True
for i in range(len(l) - 1):
res &= l[i][1] == l[i + 1][0]
return res
def recursive_getattr(obj, attr_string):
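    """Resolve a dotted attribute path: `recursive_getattr(obj, "a.b.c")` returns `obj.a.b.c`."""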
attrs = attr_string.split(".")
for attr in attrs:
obj = getattr(obj, attr)
return obj
def _slice_intervals_py(src: torch.Tensor, intervals: List[Tuple[int, int]]):
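    """Gather `src[i:j]` for each (i, j) in `intervals` and concatenate the pieces.

    Example:
        >>> _slice_intervals_py(torch.arange(10), [(0, 2), (5, 7)])
        tensor([0, 1, 5, 6])
    """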
# Drop-in replacement for the C++ implementation.
assert len(src.shape) == 1
assert all([x[0] >= 0 for x in intervals])
assert all([x[1] <= src.shape[0] for x in intervals])
slices = []
for i, j in intervals:
slices.append(src[i:j])
return torch.cat(slices, dim=0)
def _set_intervals_py(
src: torch.Tensor,
dst: torch.Tensor,
intervals: List[Tuple[int, int]],
):
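    """Scatter consecutive chunks of `src` into `dst[i:j]` for each (i, j) in `intervals`, in place.

    The interval lengths must sum to `src.shape[0]`.
    """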
# Drop-in replacement for the C++ implementation.
assert len(dst.shape) == len(src.shape) == 1
offset = 0
for i, j in intervals:
assert i >= 0
assert j <= dst.shape[0], (j, dst.shape[0])
dst[i:j] = src[offset : offset + j - i]
offset += j - i
assert offset == src.shape[0]
_SLICE_INTERVAL_EXT_WARNED = False
_SET_INTERVAL_EXT_WARNED = False
def slice_intervals(
    src: torch.Tensor,
    intervals_cpu: Optional[List[Tuple[int, int]]] = None,
    intervals_cuda: Optional[torch.Tensor] = None,
    output_size: Optional[int] = None,
    max_interval_size: Optional[int] = None,
) -> torch.Tensor:
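    """Gather the given intervals of the 1-D tensor `src` into a new contiguous tensor.

    Dispatches to the CUDA extension (`realhf._C.interval_op_cuda`) when it is
    available and the number of intervals is large enough to amortize the
    per-interval thread launch; otherwise falls back to the pure-Python
    implementation, which only requires `intervals_cpu`.
    """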
if intervals_cpu is not None:
N = len(intervals_cpu)
else:
N = intervals_cuda.size(0)
if not constants.use_cuda() or N < 1024:
# NOTE: The CUDA implementation will launch a thread for each interval,
# which has a negative effect when the number of intervals is small.
return _slice_intervals_py(src, intervals_cpu)
try:
from realhf._C.interval_op_cuda import (
slice_intervals_bf16,
slice_intervals_fp16,
slice_intervals_fp32,
)
if src.dtype == torch.float32:
return slice_intervals_fp32(
src, intervals_cuda, output_size, max_interval_size
)
elif src.dtype == torch.float16:
return slice_intervals_fp16(
src, intervals_cuda, output_size, max_interval_size
)
elif src.dtype == torch.bfloat16:
return slice_intervals_bf16(
src, intervals_cuda, output_size, max_interval_size
)
else:
raise NotImplementedError(src.dtype)
except ImportError:
global _SLICE_INTERVAL_EXT_WARNED
if not _SLICE_INTERVAL_EXT_WARNED:
_SLICE_INTERVAL_EXT_WARNED = True
            logger.warning(
                "The `slice_intervals` CUDA extension was not found. "
                "Falling back to the Python implementation, which can be very slow. "
                "Re-install the package with REAL_CUDA=1 or "
                "set REAL_PARAM_REALLOC_OPT_LEVEL=1."
            )
return _slice_intervals_py(src, intervals_cpu)
def set_intervals(
    src: torch.Tensor,
    dst: torch.Tensor,
    intervals_cpu: Optional[List[Tuple[int, int]]] = None,
    intervals_cuda: Optional[torch.Tensor] = None,
    max_interval_size: Optional[int] = None,
):
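    """Scatter consecutive chunks of `src` into the given intervals of the 1-D tensor `dst`, in place.

    Uses the CUDA extension when it is available, both tensors live on GPU,
    and the number of intervals is large; otherwise falls back to the
    pure-Python implementation, which only requires `intervals_cpu`.
    """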
if intervals_cpu is not None:
N = len(intervals_cpu)
else:
N = intervals_cuda.size(0)
if not constants.use_cuda() or N < 512 or not (src.is_cuda and dst.is_cuda):
# NOTE: The CUDA implementation will launch a thread for each interval,
# which has a negative effect when the number of intervals is small.
return _set_intervals_py(src, dst, intervals_cpu)
try:
from realhf._C.interval_op_cuda import (
set_intervals_bf16,
set_intervals_fp16,
set_intervals_fp32,
)
if src.dtype == torch.float32:
return set_intervals_fp32(src, dst, intervals_cuda, max_interval_size)
elif src.dtype == torch.float16:
return set_intervals_fp16(src, dst, intervals_cuda, max_interval_size)
elif src.dtype == torch.bfloat16:
return set_intervals_bf16(src, dst, intervals_cuda, max_interval_size)
else:
raise NotImplementedError(src.dtype)
except ImportError:
global _SET_INTERVAL_EXT_WARNED
if not _SET_INTERVAL_EXT_WARNED:
_SET_INTERVAL_EXT_WARNED = True
            logger.warning(
                "The `set_intervals` CUDA extension was not found. "
                "Falling back to the Python implementation, which can be very slow. "
                "Re-install the package with REAL_CUDA=1 or "
                "set REAL_PARAM_REALLOC_OPT_LEVEL=1."
            )
return _set_intervals_py(src, dst, intervals_cpu)
def param_size_from_keys(
config: model_api.ReaLModelConfig,
src_tp_size: int,
sd_keys: List[str],
src2dst_tp_size: int,
src2dst_tp_rank: int,
head_param_point_to_embedding: bool,
) -> int:
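    """Total number of flat-buffer elements occupied by `sd_keys` for one destination portion.

    Each key's shape (under TP degree `src_tp_size`) is re-partitioned to the
    (`src2dst_tp_rank`, `src2dst_tp_size`) portion before counting. When the
    output head is tied to the word embedding, the head weight is skipped.
    """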
param_size = 0
for k in sd_keys:
if (
head_param_point_to_embedding
and k == f"{config.n_layers + 1}.weight"
and "0.wte.weight" in sd_keys
):
continue
new_shape = tp_partition_key(
k,
get_real_model_param_shape(k, config, src_tp_size),
src2dst_tp_rank,
src2dst_tp_size,
config,
partition_fn=shape_partition_fn,
)
param_size += int(np.prod(new_shape))
return param_size
def build_param_spec(
layer_indices: List[int],
config: model_api.ReaLModelConfig,
dp_size: int,
tp_size: int,
pp_size: int,
head_param_point_to_embedding: bool,
bucket_size: int = 40000000,
) -> Tuple[Dict[str, ContiguousParamSpec], int]:
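    """Assign each parameter of the given layers a slice of one flat buffer.

    Keys are laid out in reverse order (as in backpropagation, consistent with
    Megatron) and grouped into buckets of roughly `bucket_size` elements, with
    each bucket boundary padded to a multiple of `dp_size`; tied embedding/head
    weights get a bucket of their own. Returns the per-key
    `ContiguousParamSpec` mapping and the total (padded) buffer size.
    """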
# TODO: omit parameters that do not require gradient?
# TODO: allow different dtypes for different buckets
if len(layer_indices) == 0:
return {}, 0
disable_bucketing = 0 not in layer_indices
sd_keys = []
for layer_idx in sorted(layer_indices):
if layer_idx == 0:
sd_keys += ReaLModelParamKeys.embed(config)
elif layer_idx == config.n_layers + 1:
sd_keys += ReaLModelParamKeys.head(config)
else:
sd_keys += ReaLModelParamKeys.tblock(config, layer_idx - 1)
# In the reverse order as backpropagation, consistent with Megatron.
sd_keys = list(reversed(sd_keys))
data_start_index = 0
bucket_data_start_index = data_start_index
bucket_params = set()
def _requires_new_allreduce_bucket(k):
if pp_size == 1:
return False
if config.is_critic:
return False
if not config.tied_embedding:
return False
return k == f"{config.n_layers + 1}.weight" or k == "0.wte.weight"
def _pad_to_multiple(x, m):
return int(math.ceil(x / m)) * m
def _create_fake_bucket(data_end_index) -> int:
nonlocal bucket_data_start_index, bucket_params
data_end_index = _pad_to_multiple(data_end_index, dp_size)
# Update bucket metadata.
bucket_data_start_index = data_end_index
# Re-set bucket_params and increment bucket_id for next bucket.
bucket_params = set()
# Return the potentially padded data_end_index.
return data_end_index
param_spec = {}
for k in sd_keys:
if head_param_point_to_embedding and k == f"{config.n_layers + 1}.weight":
continue
shape = get_real_model_param_shape(k, config, tp_size)
numel = int(np.prod(shape))
data_end_index = data_start_index + numel
if _requires_new_allreduce_bucket(k) and len(bucket_params) > 0:
_create_fake_bucket(data_start_index)
param_spec[k] = ContiguousParamSpec(
data_start_index,
data_end_index,
shape,
)
bucket_params.add(k)
if (
not disable_bucketing
and (data_end_index - bucket_data_start_index) >= bucket_size
) or _requires_new_allreduce_bucket(k):
data_end_index = _create_fake_bucket(data_end_index)
data_start_index = data_end_index
if len(bucket_params) > 0:
data_end_index = _create_fake_bucket(data_end_index)
if head_param_point_to_embedding and f"{config.n_layers + 1}.weight" in sd_keys:
param_spec[f"{config.n_layers + 1}.weight"] = param_spec["0.wte.weight"]
return param_spec, data_end_index
def param_intervals_from_keys(
model_name: ModelName,
config: model_api.ReaLModelConfig,
head_param_point_to_embedding: bool,
param_spec: Dict[str, ContiguousParamSpec],
tp_size: int,
sd_keys: List[str],
portion_size: int,
portion_rank: int,
) -> List[Tuple[int, int]]:
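    """Compute the flat-buffer intervals holding `portion_rank`'s share of `sd_keys`.

    Zero-based intervals for each key are computed with
    `intervals_partition_fn`, cached in `_FLAT_PARAM_INDICES_CACHE`, shifted by
    the key's `start_idx` in the flat buffer, and finally merged when the C++
    `merge_intervals` helper is available.
    """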
param_size = param_size_from_keys(
config=config,
src_tp_size=tp_size,
sd_keys=sd_keys,
src2dst_tp_size=portion_size,
src2dst_tp_rank=portion_rank,
head_param_point_to_embedding=head_param_point_to_embedding,
)
interval_size = 0
intervals = []
for k in sd_keys:
if (
head_param_point_to_embedding
and k == f"{config.n_layers + 1}.weight"
and "0.wte.weight" in sd_keys
):
continue
        cache_key = (
            model_name,
            k.split(".", 1)[1],
            tp_size,
            portion_rank,
            portion_size,
        )
        if cache_key not in _FLAT_PARAM_INDICES_CACHE:
            zero_start_intervals = tp_partition_key(
                k,
                get_real_model_param_shape(k, config, tp_size),
                portion_rank,
                portion_size,
                config,
                partition_fn=intervals_partition_fn,
            )
            _FLAT_PARAM_INDICES_CACHE[cache_key] = zero_start_intervals
        else:
            zero_start_intervals = _FLAT_PARAM_INDICES_CACHE[cache_key]
intervals += (zero_start_intervals + param_spec[k].start_idx).tolist()
interval_size += sum(zero_start_intervals[:, 1] - zero_start_intervals[:, 0])
# assert len(set([x[0] for x in intervals])) == len(intervals)
assert interval_size == param_size, (interval_size, param_size)
if merge_intervals is not None:
intervals = merge_intervals(intervals)
return intervals
def map_param_to_contigous_memory(
layers: torch.nn.ModuleList,
config: model_api.ReaLModelConfig,
head_param_point_to_embedding: bool,
param_spec: Dict[str, ContiguousParamSpec],
contiguous_param: torch.Tensor,
layer_idx_offset: int,
allocate_only: bool,
):
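    """Re-bind the parameters of `layers` to views of `contiguous_param`.

    When `allocate_only` is False, existing parameter data is first copied into
    the corresponding slice of the flat buffer; otherwise the parameters are
    expected to be empty placeholders (except a head tied to the embedding) and
    are only re-pointed to views of the buffer.
    """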
for local_layer_idx, l in enumerate(layers):
layer_idx = local_layer_idx + layer_idx_offset
for k, v in l.named_parameters():
spec = param_spec[f"{layer_idx}.{k}"]
old_param_data = v.data
target = contiguous_param[spec.start_idx : spec.end_idx].view(spec.shape)
if not allocate_only:
target.copy_(old_param_data)
else:
if not (
head_param_point_to_embedding and layer_idx == config.n_layers + 1
):
assert old_param_data.shape == torch.Size([0]), (
old_param_data.shape,
spec.shape,
f"{layer_idx}.{k}",
)
recursive_getattr(l, k).data = target