mirror of https://github.com/inclusionAI/AReaL
224 lines
8.1 KiB
Python
224 lines
8.1 KiB
Python
import os
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Callable, Dict, Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from tensordict import TensorDict
|
|
from torch.distributed.checkpoint.state_dict import (
|
|
StateDictOptions,
|
|
get_model_state_dict,
|
|
)
|
|
from transformers import PreTrainedTokenizerFast
|
|
|
|
from arealite.api.cli_args import TrainEngineConfig
|
|
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
|
|
from arealite.engine.base_hf_engine import BaseHFEngine
|
|
from arealite.utils.fsdp import (
|
|
CPUOffloadPolicy,
|
|
MixedPrecisionPolicy,
|
|
apply_fsdp2,
|
|
create_fsdp_device_mesh,
|
|
fsdp2_clip_grad_norm_,
|
|
fsdp2_load_full_state_dict,
|
|
)
|
|
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
|
|
from realhf.base import logging, name_resolve, names, pkg_version
|
|
|
|
# Module-level logger for this engine, obtained from realhf.base.logging under
# the name "FSDPEngine".
logger = logging.getLogger("FSDPEngine")
|
|
|
|
|
|
class FSDPEngine(BaseHFEngine):
    """Training engine that shards a HuggingFace model with PyTorch FSDP2.

    Process-group creation, model construction, optimizer creation,
    micro-batch preparation and optimizer-state (de)serialization are called
    on ``self`` but defined outside this class (presumably in
    ``BaseHFEngine``). This class adds FSDP2 wrapping, HuggingFace-format
    save/load, disk-based weight publication and the training step.
    """

    def __init__(self, config: TrainEngineConfig):
        super().__init__(config)
        # FSDP options; all three are populated by ``initialize``.
        self.mixed_precision_policy = None
        self.device_mesh = None
        self.cpu_offload = None

    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
        """Set up distributed state, build and FSDP2-wrap the model, create the optimizer.

        Args:
            addr: Must be ``None``; remote initialization is not supported.
            ft_spec: Fine-tuning spec forwarded to ``create_optimizer``.

        Raises:
            AssertionError: If ``addr`` is not ``None`` or torch is older
                than 2.4.0 (FSDP2 requirement).
        """
        assert addr is None, "FSDPEngine does not support remote initialization."
        assert pkg_version.is_version_greater_or_equal(
            "torch", "2.4.0"
        ), "arealite only supports FSDP2, which requires torch>=2.4.0"

        # Initialize the distributed environment and load the model.
        self.create_process_group()
        self.create_device_model()

        # Parameters are cast to the configured dtype for compute, while
        # gradient reduction is done in float32.
        self.mixed_precision_policy = MixedPrecisionPolicy(
            param_dtype=getattr(torch, self.config.dtype),
            reduce_dtype=torch.float32,
            cast_forward_inputs=True,
        )
        # NOTE(review): both size arguments are the world size; presumably this
        # yields one shard group spanning all ranks (full sharding) — confirm
        # against create_fsdp_device_mesh.
        self.device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size)
        self.cpu_offload = (
            CPUOffloadPolicy() if self.config.fsdp.offload_params else None
        )
        fsdp_kwargs = {
            "mesh": self.device_mesh,
            "mp_policy": self.mixed_precision_policy,
            "offload_policy": self.cpu_offload,
            "reshard_after_forward": True,
        }
        tik = time.perf_counter()
        apply_fsdp2(self.model, fsdp_kwargs, self.config.fsdp.wrap_policy)
        logger.info(f"Applying FSDP2 time: {time.perf_counter() - tik}")

        self.create_optimizer(ft_spec)
        self.initialized = True

    def save(self, meta: SaveLoadMeta):
        """Save model weights (and optionally optimizer state) to ``meta.path``.

        Only ``meta.weight_format == "hf"`` is implemented.

        Raises:
            NotImplementedError: For the ``"dcp"`` format.
            ValueError: For any other format.
        """
        if meta.weight_format == "hf":
            self._save_model_to_hf(meta.path, meta.tokenizer)
        elif meta.weight_format == "dcp":
            # TODO: implement DCP save/load for FSDP
            raise NotImplementedError("DCP format saving is not implemented yet. ")
        else:
            raise ValueError(f"Unknown weight format {meta.weight_format}. ")

        if meta.with_optim:
            self.save_optimizer_state(meta.path)

    def load(self, meta: SaveLoadMeta):
        """Load model weights (and optionally optimizer state) from ``meta.path``.

        Only ``meta.weight_format == "hf"`` is implemented.

        Raises:
            NotImplementedError: For the ``"dcp"`` format.
            ValueError: For any other format.
        """
        if meta.weight_format == "hf":
            self._load_model_from_hf(meta.path)
        elif meta.weight_format == "dcp":
            # TODO: implement DCP save/load for FSDP
            raise NotImplementedError("DCP format loading is not implemented yet. ")
        else:
            raise ValueError(f"Unknown weight format {meta.weight_format}. ")

        if meta.with_optim:
            self.load_optimizer_state(meta.path)

    def _save_model_to_hf(
        self, path: str, tokenizer: Optional[PreTrainedTokenizerFast]
    ):
        """Save the model in HuggingFace format to ``path``.

        Every rank takes part in gathering the full (unsharded) state dict.
        Only rank 0 writes the weights, model config and optional tokenizer.
        All ranks meet at a barrier before returning.

        Raises:
            RuntimeError: If the model has not been created.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized")
        os.makedirs(path, exist_ok=True)

        # Gather the full state dict from the FSDP2 shards, offloaded to CPU.
        # This is a collective call, so every rank must execute it.
        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        state_dict = get_model_state_dict(self.model, options=options)

        # Save the HuggingFace model on rank 0 only; ``path`` already exists.
        if dist.get_rank() == 0:
            self.model.save_pretrained(path, state_dict=state_dict)
            self.model_config.save_pretrained(path)
            if tokenizer is not None:
                tokenizer.save_pretrained(path)

        dist.barrier()

    def _load_model_from_hf(self, path: str):
        """Load HuggingFace-format weights from ``path`` into the sharded model.

        Only rank 0 reads the checkpoint. The other ranks pass an empty dict
        and receive their shards through ``fsdp2_load_full_state_dict``.
        """
        if dist.get_rank() == 0:
            full_state = get_state_dict_from_repo_id_or_path(path)
        else:
            full_state = {}

        fsdp2_load_full_state_dict(
            self.model,
            full_state,
            self.cpu_offload,
            tie_word_embeddings=self.model_config.tie_word_embeddings,
        )

    def upload_weights(self, meta: WeightUpdateMeta):
        """Publish the current weights according to ``meta.type``.

        ``"disk"`` saves an HF checkpoint to ``meta.path``. Rank 0 then
        registers an update key (carrying the current timestamp) in
        name_resolve for ``meta.model_version``. ``"nccl"`` routes to
        distributed-update helpers that are not implemented yet.

        Raises:
            ValueError: For an unknown ``meta.type``.
            NotImplementedError: For ``"nccl"``.
        """
        if meta.type == "nccl":
            if not self.weight_update_group_initialized:
                self._init_distributed_weight_update(meta)
            self._update_weights_from_distributed()
        elif meta.type == "disk":
            self._save_model_to_hf(meta.path, self.tokenizer)
            # _save_model_to_hf ends with dist.barrier(), so every rank has
            # finished writing before rank 0 announces the update.
            if dist.get_rank() == 0:
                update_name = names.update_weights_from_disk(
                    self.config.experiment_name,
                    self.config.trial_name,
                    meta.model_version,
                )
                name_resolve.add(
                    update_name, str(datetime.now().timestamp()), keepalive_ttl=120
                )
        else:
            raise ValueError(f"Unknown weight update type {meta.type}")

    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
        raise NotImplementedError(
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )

    def _update_weights_from_distributed(self):
        raise NotImplementedError(
            "Distributed weight update is not implemented for FSDPEngine yet. "
        )

    def train_batch(
        self,
        input_: TensorDict,
        loss_fn: Callable[[torch.Tensor, TensorDict], torch.Tensor],
        loss_weight_fn: Callable[[TensorDict], float],
    ) -> Dict[str, float]:
        """Run one optimizer update over ``input_`` using gradient accumulation.

        The batch is split into micro-batches by ``prepare_mb_list``. Each
        micro-batch loss is scaled by its weight divided by the global
        (all-rank) weight sum, so the accumulated gradient is a
        weight-averaged gradient over the whole data-parallel batch.

        Args:
            input_: Batch to train on; moved to ``self.device``.
            loss_fn: Maps ``(logits, micro_batch)`` to a scalar loss tensor.
            loss_weight_fn: Returns the weight of a micro-batch.

        Returns:
            Dict with ``update_successful`` (1.0 or 0.0), ``grad_norm`` and
            ``lr`` (the scheduler's last learning rate).

        Raises:
            AssertionError: If the optimizer, its config or the LR scheduler
                is missing, or if the global loss weight is zero.
        """
        input_ = input_.to(self.device)
        assert self.optimizer is not None
        assert self.optimizer_config is not None
        assert self.lr_scheduler is not None

        self.optimizer.zero_grad()
        mb_list = self.prepare_mb_list(input_)

        # Sum of micro-batch weights across all ranks. The tensor lives on
        # ``self.device`` so the collective also works with a CUDA-only (NCCL)
        # process group. The zero check runs on the reduced value: a rank may
        # hold zero local weight, and a per-rank check before the collective
        # would abort that rank while its peers block in ``all_reduce``.
        total_loss_weight = torch.tensor(
            sum(loss_weight_fn(mb) for mb in mb_list.mbs),
            dtype=torch.float32,
            device=self.device,
        )
        dist.all_reduce(total_loss_weight)
        assert total_loss_weight != 0

        # Process micro-batches with gradient accumulation.
        for pad_length, padded_mb_input, mb_input in zip(
            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
        ):
            outputs = self.model(**padded_mb_input)

            # Remove the leading size-1 dim, then strip the trailing
            # ``pad_length`` positions that were added for padding.
            logits = outputs.logits.squeeze(0)
            if pad_length > 0:
                logits = logits[:-pad_length]
            loss = loss_fn(logits, mb_input)

            # Scale the loss for accumulation. Multiplying by world_size
            # reverts the gradient averaging across dp ranks.
            loss_scale = loss_weight_fn(mb_input) / total_loss_weight
            loss_scale *= self.world_size
            loss = loss * loss_scale
            loss.backward()

        # NOTE(review): uses the FSDP2-specific clipping helper; presumably
        # plain torch.nn.utils.clip_grad_norm_ is unsuitable for sharded
        # parameters — confirm.
        grad_norm = fsdp2_clip_grad_norm_(
            self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
        )
        # The optimizer is stepped at most once per batch. On inf/NaN
        # gradients the update is skipped and the gradients are discarded.
        if not torch.isfinite(grad_norm):
            self.optimizer.zero_grad()
            update_successful = False
        else:
            self.optimizer.step()
            update_successful = True

        # NOTE(review): the LR scheduler is only read here, not advanced;
        # presumably a separate method steps it — confirm.
        current_lr = self.lr_scheduler.get_last_lr()[0]
        return dict(
            update_successful=float(update_successful),
            grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
            lr=current_lr,
        )
|