AReaL/realhf/impl/model/conversion/hf_registry.py


# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import *

import torch
import torch.distributed as dist
import transformers

from realhf.api.core import model_api
from realhf.base import constants, logging
from realhf.base.saveload_utils import (
    copy_hf_configs,
    load_safetensor,
    split_state_dict_into_shards,
)
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_parallel import (
    tp_merge_key,
    tp_partition_real_model_state_dict,
)

logger = logging.getLogger("HF Registry")


@dataclasses.dataclass
class HFModelRegistry:
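    """Conversion registry for one HuggingFace model architecture.

    Each field is a callable that maps between HuggingFace and ReaLModel
    configs, state dicts, or parameter names. A minimal usage sketch, with
    hypothetical converter callables for a LLaMA-like model::

        registry = HFModelRegistry(
            name="llama",
            hf_cls_name="LlamaForCausalLM",
            config_from_hf_converter=convert_config_llama,  # hypothetical
            config_to_hf_converter=convert_config_back_llama,  # hypothetical
            sd_from_hf_converter=sd_from_llama,  # hypothetical
            sd_to_hf_converter=sd_to_llama,  # hypothetical
            embedding_param_names=llama_embedding_layer_names,  # hypothetical
            tblock_param_names=llama_tblock_param_names,  # hypothetical
            head_param_names=llama_output_head_param_names,  # hypothetical
        )
        config = registry.config_from_hf(model_path="/path/to/hf/llama")

    ``load`` and ``save`` then move weights between HuggingFace checkpoints
    and a ``ReaLModel`` under the current tensor-, pipeline-, and
    data-parallel layout.
    """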
    name: str
    hf_cls_name: str
    config_from_hf_converter: Callable[
        [transformers.PretrainedConfig], model_api.ReaLModelConfig
    ]
    config_to_hf_converter: Callable[
        [model_api.ReaLModelConfig], transformers.PretrainedConfig
    ]
    sd_from_hf_converter: Callable[[Dict, model_api.ReaLModelConfig], Dict]
    sd_to_hf_converter: Callable[[Dict, model_api.ReaLModelConfig], Dict]
    embedding_param_names: Callable[[model_api.ReaLModelConfig], List[str]]
    tblock_param_names: Callable[[model_api.ReaLModelConfig, int], List[str]]
    head_param_names: Callable[[model_api.ReaLModelConfig], List[str]]
    real_config_maker: Optional[Callable] = None

    def config_from_hf(
        self,
        hf_config: Optional[transformers.PretrainedConfig] = None,
        model_path: Optional[str] = None,
        is_critic: bool = False,
    ) -> model_api.ReaLModelConfig:
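        """Build a ReaLModelConfig from a HuggingFace config.

        If ``hf_config`` is not given, it is loaded from ``model_path`` via
        ``transformers.AutoConfig``. Critic models never tie the input and
        output embeddings, so ``tied_embedding`` is forced to False.
        """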
        if hf_config is None:
            hf_config = transformers.AutoConfig.from_pretrained(
                model_path,
                trust_remote_code=True,
                force_download=True,
            )
        config = self.config_from_hf_converter(hf_config)
        config.base_model_path = model_path
        config.is_critic = is_critic
        if config.is_critic:
            config.tied_embedding = False
        return config

    def config_to_hf(
        self, real_config: model_api.ReaLModelConfig
    ) -> transformers.PretrainedConfig:
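        """Convert a ReaLModelConfig back into a HuggingFace PretrainedConfig."""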
        return self.config_to_hf_converter(real_config)

    def load(
        self,
        model: ReaLModel,
        load_dir: str,
        init_critic_from_actor: bool = False,
    ):
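        """Load a HuggingFace checkpoint from ``load_dir`` into ``model``.

        Only the parameters of the layers owned by the local pipeline stage
        are read from disk; they are renamed with ``sd_from_hf_converter``
        and partitioned for the local tensor-parallel rank before being
        copied into ``model``.
        """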
        tik = time.perf_counter()
        with open(os.path.join(load_dir, "config.json"), "r") as f:
            hf_config = json.load(f)
        if "architectures" in hf_config:
            assert (
                self.hf_cls_name == hf_config["architectures"][0]
            ), f"{self.hf_cls_name} != {hf_config['architectures'][0]}"

        layer_indices = range(model.layer_idx_start, model.layer_idx_end)
        required_hf_sd_names = []
        for lidx in layer_indices:
            if lidx == 0:
                required_hf_sd_names += self.embedding_param_names(model.config)
            elif lidx == model.config.n_layers + 1:
                required_hf_sd_names += self.head_param_names(model.config)
            else:
                required_hf_sd_names += self.tblock_param_names(model.config, lidx - 1)

        # Load embedding weights as well if tied_embedding is True.
        required_hf_sd_names = set(required_hf_sd_names)
        if (
            model.config.tied_embedding
            and not model.config.is_critic
            and constants.is_last_pipe_stage()
        ):
            required_hf_sd_names.update(self.embedding_param_names(model.config))
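
        # Resolve which checkpoint files contain the required parameters. Four
        # layouts are supported: sharded .bin or .safetensors checkpoints with
        # an index file, and single-file pytorch_model.bin or model.safetensors
        # checkpoints.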
        if os.path.exists(os.path.join(load_dir, "pytorch_model.bin.index.json")):
            with open(os.path.join(load_dir, "pytorch_model.bin.index.json"), "r") as f:
                hf_sd_mapping = json.load(f)["weight_map"]
            files_to_load = set()
            for name in required_hf_sd_names:
                if name in hf_sd_mapping:
                    files_to_load.add(hf_sd_mapping[name])
        elif os.path.exists(os.path.join(load_dir, "model.safetensors.index.json")):
            with open(os.path.join(load_dir, "model.safetensors.index.json"), "r") as f:
                hf_sd_mapping = json.load(f)["weight_map"]
            files_to_load = set()
            for name in required_hf_sd_names:
                if name in hf_sd_mapping:
                    files_to_load.add(hf_sd_mapping[name])
        elif os.path.exists(os.path.join(load_dir, "pytorch_model.bin")):
            files_to_load = ["pytorch_model.bin"]
        elif os.path.exists(os.path.join(load_dir, "model.safetensors")):
            files_to_load = ["model.safetensors"]
        else:
            raise ValueError(
                f"Could not find model file in {load_dir}. "
                "Make sure you have downloaded the model correctly."
            )
        setup_time = time.perf_counter() - tik
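
        # Worker run inside the thread pool below: load one checkpoint file,
        # keep only the required parameters, rename them to the ReaL scheme,
        # and partition them for the local tensor-parallel rank.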
        def _load_ckpt(fn):
            load_tik = time.perf_counter()
            if fn.endswith(".safetensors"):
                sd = load_safetensor(os.path.join(load_dir, fn))
            else:
                # Setting map_location to CPU is a little faster.
                sd = torch.load(
                    os.path.join(load_dir, fn), map_location="cpu", weights_only=True
                )
            partition_tik = time.perf_counter()
            sd = {k: v for k, v in sd.items() if k in required_hf_sd_names}
            sd = self.sd_from_hf_converter(sd, model.config)
            psd = tp_partition_real_model_state_dict(
                sd,
                model.config,
                constants.tensor_parallel_world_size(),
                constants.tensor_parallel_rank(),
            )
            return psd, partition_tik - load_tik, time.perf_counter() - partition_tik

        load_times, partition_times = [], []
        state_dict = {}
        with ThreadPoolExecutor(
            max_workers=min(4, max(1, os.cpu_count() // 8))
        ) as executor:
            future_to_checkpoint = {
                executor.submit(_load_ckpt, path): path for path in files_to_load
            }
            for future in as_completed(future_to_checkpoint):
                path = future_to_checkpoint[future]
                try:
                    psd, load_t, part_t = future.result()
                    state_dict.update(psd)
                    load_times.append(load_t)
                    partition_times.append(part_t)
                except Exception as e:
                    raise RuntimeError(f"Error loading checkpoint from {path}: {e}")

        # Remap embedding weights to the last layer if tied_embedding is True.
        if (
            model.config.tied_embedding
            and not model.config.is_critic
            and constants.is_last_pipe_stage()
        ):
            state_dict[f"{model.config.n_layers + 1}.weight"] = state_dict[
                "0.wte.weight"
            ]
            if not constants.is_first_pipe_stage() and "0.wte.weight" in state_dict:
                state_dict.pop("0.wte.weight")

        copy_tik = time.perf_counter()
        if init_critic_from_actor and constants.is_last_pipe_stage():
            if f"{model.config.n_layers + 1}.weight" in state_dict:
                state_dict.pop(f"{model.config.n_layers + 1}.weight")
            assert len(state_dict) == len(model.state_dict()) - 1, (
                len(state_dict),
                len(model.state_dict()),
            )
            model.load_state_dict(state_dict, strict=False)
        else:
            try:
                model.load_state_dict(state_dict, strict=True)
            except Exception as e:
                logger.error(
                    f"Loading state dict with strict=True failed. "
                    f"Have you set init_critic_from_actor=True "
                    f"in the model config if you are initializing "
                    f"a critic model from a regular LLM? Err: {e}"
                )
                raise e

        # Some logging info.
        copy_time = time.perf_counter() - copy_tik
        load_times = "[" + ", ".join(f"{t:.2f}" for t in load_times) + "]"
        partition_times = "[" + ", ".join(f"{t:.2f}" for t in partition_times) + "]"
        logger.debug(
            f"Loading from HuggingFace model: setup time cost={setup_time:.2f}s, load time cost={load_times}, "
            f"partition time cost={partition_times}, copy time cost={copy_time:.2f}s"
        )
        return model

    def save(
        self,
        model: ReaLModel,
        tokenizer: Optional[transformers.PreTrainedTokenizer],
        save_dir: str,
    ):
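        """Save ``model`` as a HuggingFace-format checkpoint under ``save_dir``.

        Parameters are all-gathered across the tensor-parallel group, renamed
        with ``sd_to_hf_converter``, and dumped either as a single
        ``pytorch_model.bin`` or as indexed shards, with each pipeline stage
        writing the shards for its own layers.
        """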
        tik = time.perf_counter()
        os.makedirs(save_dir, exist_ok=True)

        dp_rank = constants.data_parallel_rank()
        pp_rank = constants.pipe_parallel_rank()
        tp_rank = constants.tensor_parallel_rank()
        tp_size = constants.tensor_parallel_world_size()
        pp_size = constants.pipe_parallel_world_size()
        dp_size = constants.data_parallel_world_size()

        # We will gather parameters across the model parallel group,
        # and save parameters to separate shards across the pipeline parallel group.
        # To decrease the size of each saved file, we split the file
        # of each pipeline stage into smaller shards.
        approx_param_size = (
            sum(v.numel() * v.element_size() for v in model.state_dict().values())
            * tp_size
        )
        # By default a shard is at most 10GB; override with the
        # REAL_SAVE_MAX_SHARD_SIZE_BYTE env var. A smaller size enables
        # parallel saving during training.
        max_shard_size_byte = int(os.getenv("REAL_SAVE_MAX_SHARD_SIZE_BYTE", int(1e10)))
        n_shards_this_stage = (
            approx_param_size + max_shard_size_byte - 1
        ) // max_shard_size_byte
        if approx_param_size <= 0 or n_shards_this_stage <= 0:
            raise ValueError(
                f"Invalid param_size={approx_param_size}, n_shards_this_stage={n_shards_this_stage}. "
                "Have you instantiated the model?"
            )
        n_shards_this_stage = torch.tensor(
            n_shards_this_stage, dtype=torch.int32, device=model.device
        )
        pp_stage_n_shards = [
            torch.zeros_like(n_shards_this_stage) for _ in range(pp_size)
        ]
        dist.all_gather(
            pp_stage_n_shards,
            n_shards_this_stage,
            group=constants.pipe_parallel_group(),
        )
        pp_stage_n_shards = [int(n.item()) for n in pp_stage_n_shards]
        assert all(x >= 1 for x in pp_stage_n_shards)
        t1 = time.perf_counter()

        # Gather parameters across the model parallel group.
        sd = model.state_dict()
        cpu_sd = {}
        for k, v in sd.items():
            if (
                model.config.tied_embedding
                and not model.config.is_critic
                and k == f"{model.config.n_layers + 1}.weight"
            ):
                continue
            gather_list = [torch.zeros_like(v) for _ in range(tp_size)]
            dist.all_gather(gather_list, v, group=constants.tensor_parallel_group())
            gathered = tp_merge_key(k, gather_list, model.config)
            cpu_sd[k] = gathered.cpu()
        t2 = time.perf_counter()
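
        # Convert the gathered parameters to HuggingFace names and build the
        # config object that is saved alongside the weights.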
        hf_sd = self.sd_to_hf_converter(cpu_sd, model.config)
        hf_config = self.config_to_hf_converter(model.config)
        hf_config.architectures = [self.hf_cls_name]
        hf_config.name_or_path = str(save_dir)
        hf_config.torch_dtype = str(model.dtype).strip("torch.")

        param_size = sum(
            [value.numel() * value.element_size() for value in hf_sd.values()]
        )
        param_size = torch.tensor(param_size, dtype=torch.int64, device=model.device)
        dist.all_reduce(
            param_size,
            op=dist.ReduceOp.SUM,
            group=constants.pipe_parallel_group(),
        )
        param_size = param_size.item()

        # Save tokenizer and huggingface model config.
        if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
            hf_config.save_pretrained(save_dir)
            if tokenizer is not None:
                tokenizer.save_pretrained(save_dir)

        # Dump parameters to disk.
        if len(pp_stage_n_shards) == 1 and pp_stage_n_shards[0] == 1:
            fn = "pytorch_model.bin"
            if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
                torch.save(hf_sd, os.path.join(save_dir, fn))
        else:
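            # Shards follow the HuggingFace naming convention
            # pytorch_model-XXXXX-of-YYYYY.bin; shard numbers are global
            # across all pipeline stages.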
            output_fn = (
                "pytorch_model"
                + "-{shard:05d}"
                + f"-of-{sum(pp_stage_n_shards):05d}.bin"
            )
            n_shards = pp_stage_n_shards[pp_rank]
            shard_offset = sum(pp_stage_n_shards[:pp_rank])
            shards = split_state_dict_into_shards(hf_sd, n_shards)
            bin_index = {}
            bin_index["metadata"] = dict(total_size=param_size)
            bin_index["weight_map"] = {}
            weight_map = {}
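
            # Distribute this stage's shards over the data x tensor parallel
            # mesh so that different ranks dump different shard files; each
            # rank writes at most n_shards_per_gpu shards.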
            mesh_size = dp_size * tp_size
            mesh_idx = dp_rank * tp_size + tp_rank
            n_shards_per_gpu = (n_shards + mesh_size - 1) // mesh_size
            if mesh_idx < len(range(0, n_shards, n_shards_per_gpu)):
                s = list(range(0, n_shards, n_shards_per_gpu))[mesh_idx]
            else:
                s = n_shards
            # Since torch.save requires pickling, which is CPU-bound,
            # multi-threading the dump within a single rank is not beneficial.
            for i, shard in enumerate(shards[s : s + n_shards_per_gpu]):
                shard_idx = shard_offset + i + s
                torch.save(
                    shard,
                    os.path.join(save_dir, output_fn.format(shard=shard_idx + 1)),
                )
            for i, shard in enumerate(shards):
                shard_idx = shard_offset + i
                for k in shard:
                    weight_map[k] = output_fn.format(shard=shard_idx + 1)

            weight_map_list = [None for _ in range(pp_size)]
            dist.all_gather_object(
                weight_map_list,
                weight_map,
                group=constants.pipe_parallel_group(),
            )
            for wm in weight_map_list:
                bin_index["weight_map"].update(wm)
            if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
                with open(
                    os.path.join(save_dir, "pytorch_model.bin.index.json"), "w"
                ) as f:
                    json.dump(bin_index, f, indent=4)

        # Copy other configs and remote code files.
        if (
            constants.parallelism_rank() == 0
            and model.config.base_model_path is not None
        ):
            copy_hf_configs(model.config.base_model_path, save_dir)
        t3 = time.perf_counter()

        metadata_t = t1 - tik
        gather_cpu_t = t2 - t1
        dump_t = t3 - t2
        logger.debug(
            f"Saving to HuggingFace model: metadata cost={metadata_t:.2f}s, "
            f"gather/cpu copy cost={gather_cpu_t:.2f}s, "
            f"dump cost={dump_t:.2f}s"
        )