# mirror of https://github.com/inclusionAI/AReaL
import os

import torch
import torch.distributed as dist
from safetensors.torch import save_file

from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.save_load import (
    get_state_dict_from_repo_id_or_path,
    is_existing_local_path,
)
from realhf.base import constants, logging

logger = logging.getLogger("DeepSpeedAutoTPEngine")

class DeepSpeedAutoTPEngine(BaseHFEngine):
    def __init__(self, config: TrainEngineConfig):
        super().__init__(config)

    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
        """Initialize distributed communication and the tensor-parallel model."""
        assert (
            addr is None
        ), "DeepSpeedAutoTPEngine does not support remote initialization."
        import deepspeed

        self.create_process_group()

        # WORLD_SIZE is set by the launcher (e.g. torchrun); fail fast if absent
        # rather than crashing later with int(None).
        world_size = int(os.environ["WORLD_SIZE"])
        deepspeed.init_distributed(
            dist_backend="nccl",
            world_size=world_size,
            timeout=constants.NCCL_DEFAULT_TIMEOUT,
        )
        self.create_device_model()
        # NOTE: the device context manager does not work here.
        self.model = deepspeed.tp_model_init(
            self.model,
            tp_size=self.config.ds_auto_tp.autotp_size,
            dtype=getattr(torch, self.config.dtype),
        ).to(self.device)
        self.create_optimizer(ft_spec)
        self.initialized = True
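
    # Usage sketch (hypothetical values; the exact TrainEngineConfig fields come
    # from arealite.api.cli_args). The engine assumes it runs under a launcher
    # that sets WORLD_SIZE/RANK, e.g. `torchrun --nproc_per_node=4 train.py`:
    #
    #   engine = DeepSpeedAutoTPEngine(config)  # config.ds_auto_tp.autotp_size == 4
    #   engine.initialize(addr=None, ft_spec=ft_spec)
    #   # self.model is now wrapped by DeepSpeed AutoTP and placed on self.device.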

    def _check_autotp(self) -> bool:
        """Return True iff every sharded model dimension is divisible by the TP degree."""
        tp_size = self.config.ds_auto_tp.autotp_size
        config = self.model_config
        num_attention_heads = config.num_attention_heads
        num_key_value_heads = config.num_key_value_heads
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        return (
            num_attention_heads % tp_size == 0
            and num_key_value_heads % tp_size == 0
            and hidden_size % tp_size == 0
            and intermediate_size % tp_size == 0
        )
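
    # Example with hypothetical Llama-7B-style dimensions:
    # num_attention_heads=32, num_key_value_heads=32, hidden_size=4096,
    # intermediate_size=11008. For tp_size=2 every modulus is 0, so
    # _check_autotp() returns True; for tp_size=3, 4096 % 3 == 1 and
    # 11008 % 3 == 1, so it returns False and AutoTP cannot shard the model.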

    def save(self, meta: SaveLoadMeta):
        if meta.weight_format != "naive_distributed":
            raise ValueError(f"Unknown weight format {meta.weight_format}.")
        if self.model is None:
            raise RuntimeError("Model not initialized")

        rank = dist.get_rank()
        world_size = dist.get_world_size()
        if rank == 0:
            os.makedirs(meta.path, exist_ok=True)
            self.model_config.save_pretrained(meta.path)
            if meta.tokenizer is not None:
                meta.tokenizer.save_pretrained(meta.path)

        # Strip the DeepSpeed "module." prefix (if any) and move tensors to CPU.
        state_dict = self.model.state_dict()
        if hasattr(self.model, "module"):
            state_dict = {
                k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
                for k, v in state_dict.items()
            }
        else:
            state_dict = {k: v.cpu() for k, v in state_dict.items()}

        # Each rank holds only its own model partition; gather all partitions
        # on rank 0 and store one safetensors file per rank.
        gathered_state_dicts = None
        if rank == 0:
            gathered_state_dicts = [None for _ in range(world_size)]
        dist.gather_object(
            obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
        )
        if rank == 0:
            for i, shard in enumerate(gathered_state_dicts):
                save_file(shard, f"{meta.path}/rank_{i:02d}_model.safetensors")
        if meta.with_optim:
            self.save_optimizer_state(meta.path)
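
    # On-disk layout produced by save() for world_size == 4 (config.json and the
    # tokenizer files come from the respective save_pretrained calls):
    #
    #   meta.path/
    #     config.json
    #     tokenizer files (tokenizer.json, ...)
    #     rank_00_model.safetensors
    #     rank_01_model.safetensors
    #     rank_02_model.safetensors
    #     rank_03_model.safetensors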

    def load(self, meta: SaveLoadMeta):
        if meta.weight_format != "naive_distributed":
            raise ValueError(f"Unknown weight format {meta.weight_format}.")
        rank = dist.get_rank()
        # Only two cases are supported: loading the full model from a
        # HuggingFace repo (rank 0), or loading per-rank partitions from a
        # local checkpoint produced by save().
        if rank == 0 or is_existing_local_path(meta.path):
            if is_existing_local_path(meta.path):
                load_path = f"{meta.path}/rank_{rank:02d}_model.safetensors"
            else:
                load_path = meta.path
            full_state = get_state_dict_from_repo_id_or_path(load_path)

            # Re-add the "module." prefix when the wrapped model expects it.
            if hasattr(self.model, "module") and not all(
                k.startswith("module.") for k in full_state
            ):
                full_state = {
                    f"module.{k}" if not k.startswith("module.") else k: v
                    for k, v in full_state.items()
                }
            self.model.load_state_dict(
                full_state, strict=not self.model_config.tie_word_embeddings
            )
            if self.model_config.tie_word_embeddings:
                self.model.tie_weights()

        if meta.with_optim:
            self.load_optimizer_state(meta.path)
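
    # Minimal offline sketch (not part of the engine) for inspecting the
    # per-rank shards written by save(); load_file is the standard
    # safetensors.torch loader. Hypothetical helper, shown for illustration:
    #
    #   from pathlib import Path
    #   from safetensors.torch import load_file
    #
    #   def inspect_shards(ckpt_dir: str):
    #       for f in sorted(Path(ckpt_dir).glob("rank_*_model.safetensors")):
    #           shard = load_file(str(f))
    #           print(f.name, len(shard), "tensors")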

    def upload_weights(self, meta: WeightUpdateMeta):
        raise ValueError(f"Weight update not implemented for {meta.type}")