[Feat][Refactor] Support DeepSpeed AutoTP; Refactor hf_engine.py and unit test. (#161)

* refactor hf engine

* format file

* revert file format

* Squashed commit of the following:

commit 8d4b8dc90f
Author: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date:   Thu Jul 10 13:14:10 2025 +0800

    [Doc] Add an instruction about how to run the SFT example. (#164)

commit 3bf9c85e40
Author: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date:   Thu Jul 10 12:56:24 2025 +0800

    [Fix] Merge previous contributions from fw/refactor to lite (#163)

    * initial proposal

    * add arealite

    * .

    * change api

    * .

    * remove LOG_ROOT

    * remove MODEL_SAVE_PATH

    * remove PARAM_REALLOC_PATH, DATASET_CACHE

    * prepare for testing

    * prepare for testing

    * ready for run

    * local run

    * tests mainly pass

    * format

    * .

    * amend cluster.py

    * .

    * .

    * client test pass

    * pass rollout test

    * remove unused imports

    * add arealite readme

    * change api

    * .

    * .

    * .

    * .

    * .

    * .

    * .

    * .

    * format

    * .

    * implement interruptible generation (#112)

    Co-authored-by: zhaochenyang <zhaochenyang20@gmail.com>

    * .

    * fix

    * .

    * .

    * .

    * pass controller generate batch test

    * .

    * refactor rollout controller into worker and controller

    * .

    * .

    * .

    * change to async rollout

    * pass rollout controller test

    * pass test

    * .

    * update readme

    * .

    * sft debug

    * .

    * add license

    * remove unused files

    * remove unused args in ppo

    * add hf engine wrapper  (#116)

    * add hf engine

    * fix issues

    * fix ppo bugs and add test

    * add hf client interface and modify cli args

    * fix bugs

    * fix issues

    * Merge fw/refactor

    * Finish hf wrapper test

    * add test

    ---------

    Co-authored-by: Wei Fu <36355462+garrett4wade@users.noreply.github.com>

    * format

    * format

    * .

    * refine hf engine

    * .

    * fix

    * add fsdp engine and sft tests

    * .

    * .

    * .

    * pass ppo unittest

    * pass ppo and rollout controller tests

    * clear unused imports

    * rename ppo to grpo

    * change reward function organization

    * reorganize code

    * add dataset api

    * .

    * .

    * .

    * format

    * chmod fix

    * .

    * rename workflow to collector

    * refactor llm_client location

    * .

    * .

    * fix llm server api

    * refactor config structure

    * .

    * fix tests

    * .

    * .

    * .

    * Fix unresolved issue in SFTTrainer PR (#139)

    * .

    * .

    * efficient loading

    * format

    * .

    * .

    * .

    * .

    * .

    * .

    * Add CI for testing AReaLite (#150)

    * ci: add test-arealite

    * ci: add checkout before running test-arealite

    * ci: add USERNAME

    * ci: add test script

    * ci: add GitHub mirror

    * ci: fix typo

    * ci: clone one commit

    * ci: fix condition

    * ci: set command timeout to 60m

    * ci: enable pip cache

    * ci: optimize container lifecycle

    * ci: split into many stages

    * ci(test-arealite): fix typo

    * ci: fix wrong env

    * ci: fix pytest

    * ci: uninstall transformer-engine

    * ci: uninstall transformer-engine

    * ci: fix model paths

    * ci: show stdout/stderr

    * ci: fix not clean up

    * ci: backup sglang

    * ci: remove tmp repo dir when run

    * ci: fix docker run exit 1 condition

    * ci(test-arealite): limit the concurrency and extend command timeout

    * .

    * merge fw/refactor

    * revert some changes

    * fix

    ---------

    Co-authored-by: meizhiyu.mzy <meizhiyu.mzy@antgroup.com>
    Co-authored-by: Chayenne <zhaochen20@outlook.com>
    Co-authored-by: zhaochenyang <zhaochenyang20@gmail.com>
    Co-authored-by: Jayon02 <qiujiangc@outlook.com>
    Co-authored-by: root <meizhiyu.mzy>
    Co-authored-by: Zijian Zhang <futrime@outlook.com>

commit d48bf007cf
Merge: 42c717b b9dbd4a
Author: 博惟 <bowei.fw@antgroup.com>
Date:   Thu Jul 10 12:53:30 2025 +0800

    Merge branch 'main' of https://github.com/inclusionAI/AReaL into lite

commit 42c717b6e4
Merge: c38cffc a203c7c
Author: 博惟 <bowei.fw@antgroup.com>
Date:   Thu Jul 10 11:15:01 2025 +0800

    Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite

commit c38cffc023
Author: 博惟 <bowei.fw@antgroup.com>
Date:   Thu Jul 10 11:10:10 2025 +0800

    PullRequest: 340 [lite] Refactor trainer API into utilities and remove mb_spec in engine methods

    Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite
    https://code.alipay.com/inclusionAI/AReaL/pull_requests/340

    Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

    * support fsdp engine and sglang remote engine
    * minor fix
    * .
    * refactor trainer
    * add close
    * rm mb_spec
    * fix

commit b9dbd4a2c1
Author: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date:   Wed Jul 9 10:50:19 2025 +0800

    Update to persistent wechat QR code. (#159)

commit 17ea7fe94d
Author: xssstory <33601810+xssstory@users.noreply.github.com>
Date:   Mon Jul 7 15:49:13 2025 +0800

    fix math reward verifier (#156)

    * PullRequest: 293 fix get_param_realloc_path

    Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
    https://code.alipay.com/inclusionAI/AReaL/pull_requests/293

    Reviewed-by: 博惟 <bowei.fw@antgroup.com>

    * fix get_param_realloc_path

    * PullRequest: 297 bugfix: reward is always -5

    Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
    https://code.alipay.com/inclusionAI/AReaL/pull_requests/297

    Reviewed-by: 博惟 <bowei.fw@antgroup.com>

    * bugfix: reward is always -5

    * PullRequest: 321 fix checkpoint save dir

    Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
    https://code.alipay.com/inclusionAI/AReaL/pull_requests/321

    Reviewed-by: 博惟 <bowei.fw@antgroup.com>

    * fix checkpoint save dir

    * PullRequest: 328 [Doc] update installation

    Merge branch sxj/doc of git@code.alipay.com:inclusionAI/AReaL.git into gh
    https://code.alipay.com/inclusionAI/AReaL/pull_requests/328

    Reviewed-by: 博惟 <bowei.fw@antgroup.com>

    * [Doc] update installation

    * PullRequest: 329 bugfix: math verifier blocks the async training

    Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh
    https://code.alipay.com/inclusionAI/AReaL/pull_requests/329

    Reviewed-by: 博惟 <bowei.fw@antgroup.com>

    * bugfix: math verifier blocks the async training

    * format

    ---------

    Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
    Co-authored-by: garrett4wade <fuwth17@gmail.com>

* add autotp for hf

* refactor test

* fix bugs

* fix issues

* format files

* Squashed commit of the following:

commit 9ed043f6ab
Author: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date:   Tue Jul 15 10:24:48 2025 +0800

    format (#174)

commit 8cc9b1feb5
Author: Night <32424487+PrinsYin@users.noreply.github.com>
Date:   Mon Jul 14 19:22:00 2025 -0700

    added LocalSGlangEngine and test (#170)

    * added LocalSGLangEngine

    * upload test file

    * add build args

    * fix sgl_local generate

    * improved sgl local robustness

    * test

    * test updated

    * added fallback when sgl engine isn't initialized

    * finish test local engine

    * added LocalSGlangEngine and test

    * format and fix

    format and fix, raise when generate missing field

    format

    * change cli_args.py

    * add comment header

    format

    ---------

    Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>

---------

Co-authored-by: Jayon02 <12012211@mail.sustech.edu.cn>
Co-authored-by: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
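
For context, a minimal usage sketch of the refactored engine with AutoTP enabled (not part of the diff; the config fields follow the test code below, and the model path, experiment names, and TP size are placeholders):

```python
# Sketch only: assumes the arealite package layout introduced in this PR.
from arealite.api.cli_args import HFEngineConfig, OptimizerConfig, TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec
from arealite.engine.hf_engine import HFEngine

config = TrainEngineConfig(
    experiment_name="autotp-demo",          # placeholder
    trial_name="trial0",                    # placeholder
    path="/path/to/hf/checkpoint",          # placeholder local HF checkpoint
    optimizer=OptimizerConfig(),
    hf=HFEngineConfig(autotp_size=2),       # DeepSpeed AutoTP degree
)

# AutoTP only takes effect when WORLD_SIZE > 1, i.e. when one process per GPU is
# launched (e.g. via torchrun) so that WORLD_SIZE/LOCAL_RANK are set.
engine = HFEngine(config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)  # addr=None: HFEngine only supports local init
```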

View File

@@ -104,6 +104,14 @@ class FSDPEngineConfig:
)
@dataclass
class HFEngineConfig:
autotp_size: Optional[int] = field(
default=1,
metadata={"help": "DeepSpeed AutoTP size"},
)
@dataclass
class TrainEngineConfig:
experiment_name: str = MISSING
@@ -136,6 +144,7 @@ class TrainEngineConfig:
)
backend: str = ""
fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
hf: HFEngineConfig = field(default_factory=HFEngineConfig)
@dataclass

View File

@@ -175,6 +175,7 @@ class SaveLoadMeta:
with_optim: bool
tokenizer: PreTrainedTokenizerFast | None
base_model_path: str | None
naive_distributed: bool = False
@dataclass
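
The new naive_distributed flag switches HFEngine between a regular HuggingFace checkpoint and one safetensors shard per rank. A hedged sketch of filling in such a meta object (values are illustrative; field names are taken from this diff):

```python
from arealite.api.engine_api import SaveLoadMeta

meta = SaveLoadMeta(
    path="/tmp/ckpt",         # illustrative checkpoint directory
    weight_format="hf",       # the only format HFEngine implements in this PR
    with_optim=True,          # also writes/reads optim.pt under the same path
    tokenizer=None,
    base_model_path=None,
    naive_distributed=True,   # per-rank rank_{i:02d}_model.safetensors files
)
# engine.save(meta) / engine.load(meta) then handle the per-rank shards.
```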

View File

@@ -1,123 +1,131 @@
import asyncio
import functools
import math
import gc
import os
import time
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import save_file
from tensordict import TensorDict
from transformers import (
AutoConfig,
AutoModelForCausalLM,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from arealite.api.cli_args import EngineConfig, ParallelismConfig, TrainingArgs
from arealite.api.engine_api import TrainEngine
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient
from arealite.utils import (
get_state_dict_from_repo_id_or_path,
recorder_list,
split_dict_tensor_with_cu_seqlens,
from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
from arealite.api.engine_api import (
FinetuneSpec,
SaveLoadMeta,
TrainEngine,
WeightUpdateMeta,
)
from arealite.utils.data import (
MicroBatchList,
amend_position_ids,
pack_tensor_dict,
pad_and_stack_tensors_along_first_dim,
pad_mb_list,
reorder_list,
split_packed_tensor_dict_into_mb_list,
unpack_sequence,
unsqueeze_mb_list,
)
from realhf.base import constants
def get_cosine_schedule_with_warmup(
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
The minimum lr ratio w.r.t the maximum.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
coef = (1 - min_lr_ratio) * 0.5
intercept = (1 + min_lr_ratio) * 0.5
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.save_load import (
get_state_dict_from_repo_id_or_path,
is_existing_local_path,
)
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
return max(0.0, x * coef + intercept)
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, name_resolve, names
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
logger = logging.getLogger("HFEngine")
class HFEngine(TrainEngine):
"""Simplified HF engine for transformer models."""
def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
super().__init__(args, engine_config)
def __init__(self, config: TrainEngineConfig):
self.config = config
self.optimizer_config = config.optimizer
self.model = None
self.optimizer = None
self.tokenizer = None
# huggingface model config
self.model_config = None
# initialization
self.initialized = False
self.weight_update_group_initialized = False
def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
"""Initialize model in single node."""
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
if dist.get_world_size() > 1:
raise RuntimeError(
"Distributed training is not supported in this engine. "
"Please use FSDP for distributed training."
)
torch.cuda.set_device("cuda:0")
def train(self, mode: bool = True):
assert self.model is not None
self.model.train(mode=mode)
return self
dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
"""Initialize distributed communication and model."""
assert addr is None, "HFEngine does not support remote initialization."
world_size = int(os.environ.get("WORLD_SIZE", 0))
if not dist.is_initialized() and world_size > 1:
try:
import deepspeed
except ImportError:
print(
"Warning: deepspeed is not installed. Some functionality may be disabled."
)
deepspeed.init_distributed(dist_backend="nccl", world_size=world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
self.device = torch.device(f"cuda:{local_rank}")
dtype = torch.bfloat16 if self.config.bf16 else torch.float16
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.engine_config.path,
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
with torch.device("cuda"):
# initialize scratch model from config
model = AutoModelForCausalLM.from_config(
self.tokenizer = load_hf_tokenizer(self.config.path)
self.model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation="flash_attention_2",
attn_implementation=self.config.attn_impl,
).to(f"cuda:{local_rank}")
if not self.config.init_from_scratch:
# Load model from a initial checkpoint path,
# which should only be a huggingface checkpoint.
load_meta = SaveLoadMeta(
path=self.config.path,
weight_format="hf",
with_optim=False,
tokenizer=None,
base_model_path=self.config.path,
naive_distributed=False,
)
model = model.cuda()
self.load(load_meta)
self.model = model
if world_size > 1:
if self._check_autotp():
self.model = deepspeed.tp_model_init(
self.model, tp_size=self.config.hf.autotp_size, dtype=dtype
)
else:
raise RuntimeError("DeepSpeed AutoTP configuration error in HFEngine. ")
# Set up optimizer
optimizer_config = self.engine_config.optimizer
if optimizer_config is not None:
if self.optimizer_config is not None:
assert (
optimizer_config.type == "adam"
self.optimizer_config.type == "adam"
), "Only AdamW optimizer is supported in this engine."
lr = optimizer_config.lr
weight_decay = optimizer_config.weight_decay
beta1 = optimizer_config.beta1
beta2 = optimizer_config.beta2
eps = optimizer_config.eps
lr = self.optimizer_config.lr
weight_decay = self.optimizer_config.weight_decay
beta1 = self.optimizer_config.beta1
beta2 = self.optimizer_config.beta2
eps = self.optimizer_config.eps
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
@@ -128,80 +136,293 @@ class HFEngine(TrainEngine):
)
total_train_steps = ft_spec.total_train_steps
num_warmup_steps = int(
optimizer_config.warmup_steps_proportion * total_train_steps
self.optimizer_config.warmup_steps_proportion * total_train_steps
)
if self.optimizer_config.lr_scheduler_type == "cosine":
self.lr_scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
min_lr_ratio=optimizer_config.min_lr_ratio,
min_lr_ratio=self.optimizer_config.min_lr_ratio,
)
elif self.optimizer_config.lr_scheduler_type == "linear":
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
total_train_steps,
)
elif self.optimizer_config.lr_scheduler_type == "constant":
self.lr_scheduler = get_constant_schedule_with_warmup(
self.optimizer,
num_warmup_steps,
)
else:
raise ValueError(
f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
)
def train(self, mode: bool = True):
"""Set the module in training mode."""
return self.model.train(mode)
self.initialized = True
def _check_autotp(self):
tp_size = self.config.hf.autotp_size
config = self.model_config
num_attention_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
return (
num_attention_heads % tp_size == 0
and num_key_value_heads % tp_size == 0
and hidden_size % tp_size == 0
and intermediate_size % tp_size == 0
)
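
For illustration (this snippet is not part of the diff): _check_autotp only allows TP degrees that evenly divide the attention heads, KV heads, hidden size, and MLP intermediate size, since DeepSpeed AutoTP shards those dimensions across ranks. A standalone sketch of the same check with illustrative numbers:

```python
def autotp_divisible(tp_size, num_attention_heads, num_key_value_heads,
                     hidden_size, intermediate_size):
    # Every sharded dimension must split evenly across the tp_size ranks.
    return all(dim % tp_size == 0 for dim in (
        num_attention_heads, num_key_value_heads, hidden_size, intermediate_size))

print(autotp_divisible(2, 12, 2, 1536, 8960))  # True: all dims divide by 2
print(autotp_divisible(8, 12, 2, 1536, 8960))  # False: 12 heads / 2 KV heads don't divide by 8
```
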
def destroy(self):
"""Destroy the engine and release GPU memory."""
self.model = None
self.optimizer = None
gc.collect()
torch.cuda.empty_cache()
gc.collect()
self.initialized = False
def save(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._save_model_to_hf(meta.path, meta.tokenizer, meta.naive_distributed)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for HF
raise NotImplementedError("DCP format saving is not implemented yet. ")
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._save_optimizer_state(meta.path)
def load(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._load_model_from_hf(meta.path, meta.naive_distributed)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for HF
raise NotImplementedError("DCP format loading is not implemented yet. ")
else:
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if meta.with_optim:
self._load_optimizer_state(meta.path)
def _save_optimizer_state(self, path: str):
assert self.optimizer is not None
os.makedirs(path, exist_ok=True)
torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt"))
def _load_optimizer_state(self, path: str):
assert self.optimizer is not None
path = os.path.join(path, "optim.pt")
optimizer_state_dict = torch.load(path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
def _save_model_to_hf(
self,
path: str,
tokenizer: Optional[transformers.PreTrainedTokenizerFast],
naive_distributed: bool,
):
"""Save model in HuggingFace format."""
if self.model is None:
raise RuntimeError("Model not initialized")
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
os.makedirs(path, exist_ok=True)
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
if world_size > 1:
dist.barrier()
state_dict = self.model.state_dict()
if hasattr(self.model, "module"):
state_dict = {
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
for k, v in state_dict.items()
}
else:
state_dict = {k: v.cpu() for k, v in state_dict.items()}
if world_size > 1 and naive_distributed:
# Only supports saving each rank's model partition as a separate file
gathered_state_dicts = None
if rank == 0:
gathered_state_dicts = [None for _ in range(world_size)]
dist.gather_object(
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
)
if rank == 0:
for i, state_dict in enumerate(gathered_state_dicts):
save_file(state_dict, f"{path}/rank_{i:02d}_model.safetensors")
else:
self.model.save_pretrained(path, state_dict=state_dict)
if world_size > 1:
dist.barrier()
def _load_model_from_hf(self, path: str, naive_distributed: bool):
"""Load model from HuggingFace format."""
rank = dist.get_rank()
# Only supports loading the full model parameters from HuggingFace,
# or loading a per-rank model partition from a local path
if rank == 0 or is_existing_local_path(path):
if naive_distributed:
path = f"{path}/rank_{rank:02d}_model.safetensors"
full_state = get_state_dict_from_repo_id_or_path(path)
if hasattr(self.model, "module") and not any(k.startswith("module.") for k in full_state):
full_state = {
f"module.{k}" if not k.startswith("module.") else k: v
for k, v in full_state.items()
}
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()
def upload_weights(self, meta: WeightUpdateMeta):
if meta.type == "nccl":
if not self.weight_update_group_initialized:
self._init_distributed_weight_update(meta)
self._update_weights_from_distributed()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer, meta.naive_distributed)
update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
meta.model_version,
)
name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
else:
raise ValueError(f"Unknown weight update type {meta.type}")
def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
raise NotImplementedError(
"Distributed weight update is not implemented for HFEngine yet. "
)
def _update_weights_from_distributed(self):
raise NotImplementedError(
"Distributed weight update is not implemented for HFEngine yet. "
)
def step_lr_scheduler(self):
assert self.lr_scheduler is not None
return self.lr_scheduler.step()
def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
packed_input = pack_tensor_dict(input_)
mb_list = split_packed_tensor_dict_into_mb_list(
packed_input,
self.config.mb_spec,
)
mb_list = pad_mb_list(mb_list, pad_value=0.0)
# NOTE: We unsqueeze here because huggingface transformer models require
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
return mb_list
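
To illustrate the idea outside the engine: _prepare_mb_list packs the variable-length sequences into one flat token stream, splits it into micro-batches under mb_spec, pads each micro-batch, and finally unsqueezes to [1, total_len] because HF models expect a batch dimension. A torch-only sketch of the packing step, independent of the arealite utilities:

```python
import torch

# Two right-padded sequences of different lengths.
input_ids = torch.tensor([[11, 12, 13, 0], [21, 22, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

lens = attention_mask.sum(dim=1)                       # tensor([3, 2])
packed = input_ids[attention_mask.bool()]              # tensor([11, 12, 13, 21, 22])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lens.cumsum(0)])  # [0, 3, 5]
position_ids = torch.cat([torch.arange(int(l)) for l in lens])  # positions restart per sequence

# HF models want a batch dimension, hence the final unsqueeze to [1, total_len].
packed, position_ids = packed.unsqueeze(0), position_ids.unsqueeze(0)
print(packed.shape, cu_seqlens.tolist())               # torch.Size([1, 5]) [0, 3, 5]
```
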
def train_batch(
self,
input_: Dict,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> Dict:
) -> Dict[str, float]:
"""Train on a batch using gradient accumulation."""
input_ = input_.to(self.device)
assert self.optimizer is not None
assert self.optimizer_config is not None
assert self.lr_scheduler is not None
self.optimizer.zero_grad()
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
loss = loss_fn(outputs.logits, mb_input)
# Process microbatches with gradient accumulation
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
loss *= loss_scale
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.engine_config.optimizer.gradient_clipping,
self.optimizer_config.gradient_clipping,
norm_type=2.0,
error_if_nonfinite=False,
foreach=None,
)
if not torch.isfinite(grad_norm):
self.optimizer.zero_grad()
update_successful = False
else:
self.optimizer.step()
update_successful = True
current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return {
"grad_norm": grad_norm,
"lr": current_lr,
}
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
lr=current_lr,
)
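
As a standalone illustration of this pattern: train_batch scales each micro-batch loss by loss_weight_fn(mb) / total_loss_weight before backward, so micro-batches contribute in proportion to their weight (e.g. token count) rather than equally, and the optimizer step is skipped when the gradient norm is non-finite. A minimal sketch on a toy model:

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

microbatches = [torch.randn(3, 4), torch.randn(5, 4)]       # e.g. 3 and 5 "tokens"
loss_weight_fn = lambda mb: float(mb.shape[0])               # weight by token count
total_weight = sum(loss_weight_fn(mb) for mb in microbatches)

opt.zero_grad()
for mb in microbatches:
    loss = model(mb).pow(2).mean()                           # stand-in for the real loss_fn
    (loss * loss_weight_fn(mb) / total_weight).backward()    # scale, then accumulate grads

grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if torch.isfinite(grad_norm):
    opt.step()
else:
    opt.zero_grad()                                          # drop the update on bad gradients
```
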
@torch.no_grad()
def eval_batch(
self,
input_: Dict,
input_: TensorDict,
loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
loss_weight_fn: Callable[[Dict], float],
) -> torch.Tensor | None:
"""Evaluate on a batch."""
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
mb_list = self._prepare_mb_list(input_)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
)
assert total_loss_weight != 0
total_loss = 0.0
total_weight = 0.0
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
loss = loss_fn(outputs.logits, mb_input)
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)
# Simple weight calculation (could be improved)
loss_scale = loss_weight_fn(mb_input) / total_loss_weight
@@ -213,95 +434,34 @@ class HFEngine(TrainEngine):
@torch.no_grad()
def forward(
self,
input_: Dict,
input_: TensorDict,
output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
"""Forward pass with optional post-processing."""
mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
mb_list = self._prepare_mb_list(input_)
if output_seqlens is None:
cu_seqlens = input_["cu_seqlens"]
output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
results = []
for mb_input in mb_splits.mbs:
outputs = self.model(**mb_input)
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):
outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
if post_hook:
result = post_hook(outputs.logits, mb_input)
result = post_hook(logits, mb_input)
results.append(result)
else:
results.append(outputs.logits)
results.append(logits)
res = aggregate_fn(results)
output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))
def step_lr_scheduler(self):
"""Step the learning rate scheduler."""
return self.lr_scheduler.step()
def save_model_to_hf(
self,
path: str,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
base_model_path: Optional[str] = None,
):
"""Save model in HuggingFace format."""
if self.model is None:
raise RuntimeError("Model not initialized")
os.makedirs(path, exist_ok=True)
state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
self.model.save_pretrained(path, state_dict=state_dict)
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
def load_model_from_hf(self, path: str):
"""Load model from HuggingFace format."""
full_state = get_state_dict_from_repo_id_or_path(path)
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()
def save_optimizer_state(self, path: str):
"""Save optimizer state."""
if self.optimizer is None:
raise RuntimeError("Optimizer not initialized")
os.makedirs(path, exist_ok=True)
torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))
def load_optimizer_state(self, path: str):
"""Load optimizer state."""
if self.optimizer is None:
raise RuntimeError("Optimizer not initialized")
optimizer_path = os.path.join(path, "optimizer.pt")
if os.path.exists(optimizer_path):
self.optimizer.load_state_dict(
torch.load(optimizer_path, map_location="cpu")
)
else:
raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")
async def aupdate_weights_to(self, llm_client: LLMClient):
path = constants.get_param_realloc_path(self.args)
self.save_model_to_hf(path)
tasks = [
llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
for server_info in llm_client.get_healthy_servers()
]
await asyncio.gather(*tasks)
def update_weights_to(self, llm_client: LLMClient):
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(self.aupdate_weights_to(llm_client))
finally:
loop.close()
output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
reordered = reorder_list(unpacked, mb_list.backward_indices)
return pad_and_stack_tensors_along_first_dim(reordered)
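
Illustrated with plain torch (hypothetical reorder indices, not the real utilities): after the packed forward pass, the concatenated per-token outputs are split back per sequence, restored to the original order via mb_list.backward_indices, and then padded and stacked along the first dimension.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

flat = torch.arange(10).float()        # concatenated per-token outputs (10 tokens total)
output_seqlens = [4, 2, 3, 1]          # sequence lengths in micro-batch (forward) order
backward_indices = [2, 0, 3, 1]        # hypothetical permutation back to the input order

unpacked = list(torch.split(flat, output_seqlens))     # one tensor per sequence
reordered = [unpacked[i] for i in backward_indices]    # undo the micro-batch reordering
padded = pad_sequence(reordered, batch_first=True)     # stack to [batch, max_len]
print(padded.shape)                                    # torch.Size([4, 4])
```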

View File

@@ -1,7 +1,7 @@
# Copyright 2025 Ant Group Inc.
# Licensed under the Apache License, Version 2.0
"""Test script for HF Engine implementation."""
"""Test script for Engine implementation."""
import os
from typing import Dict
@@ -52,29 +52,43 @@ def mock_input(
)
def get_engine(engine_type: str, model_path: str):
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.engine.hf_engine import HFEngine
engine_cls = {"hf": HFEngine, "fsdp": FSDPEngine}[engine_type]
engine_config = TrainEngineConfig(
experiment_name=f"test-{engine_type}-engine",
trial_name="test0",
path=model_path,
optimizer=OptimizerConfig(),
)
engine = engine_cls(engine_config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)
return engine
def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
"""Mock loss function for testing."""
return torch.mean(logits)
@pytest.fixture(scope="module")
def engine():
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "7777"
engine_config = TrainEngineConfig(
experiment_name="test-fsdp-engine",
trial_name="test0",
path=MODEL_PATH,
optimizer=OptimizerConfig(),
@pytest.fixture(scope="module", params=["fsdp", "hf"])
def engine(request):
os.environ.update(
{
"WORLD_SIZE": "1",
"RANK": "0",
"LOCAL_RANK": "0",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "7777",
}
)
engine = FSDPEngine(engine_config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)
print("✓ Engine created successfully")
engine = get_engine(request.param, MODEL_PATH)
print(f"{request.param.upper()} Engine created successfully")
yield engine
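
For reference, because the fixture is parametrized over both backends, every test in the module now runs once against FSDPEngine and once against HFEngine. A stripped-down illustration of the pytest pattern with toy stand-ins (not the real engines):

```python
import pytest


class ToyFSDP:
    name = "fsdp"


class ToyHF:
    name = "hf"


@pytest.fixture(scope="module", params=["fsdp", "hf"])
def engine(request):
    # request.param selects which backend this run of the module exercises.
    return {"fsdp": ToyFSDP, "hf": ToyHF}[request.param]()


def test_engine_name(engine):
    assert engine.name in ("fsdp", "hf")  # collected twice, once per param
```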

View File

@@ -1,4 +1,5 @@
import os
from pathlib import Path
from typing import Dict
import torch
@@ -41,18 +42,21 @@ def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict:
else:
# Assume it's a local path
local_path = repo_id_or_path
if not os.path.isdir(local_path):
raise ValueError(
f"Local path {local_path} does not exist or is not a directory, "
f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
)
# Step 3: Load all .safetensors and .bin files
file_paths_to_load = []
if os.path.isdir(local_path):
for filename in os.listdir(local_path):
filepath = os.path.join(local_path, filename)
if filename.endswith(".safetensors") or filename.endswith(".bin"):
file_paths_to_load.append(filepath)
elif os.path.isfile(local_path):
file_paths_to_load.append(local_path)
else:
raise ValueError(
f"Local path {local_path} does not exist or is not a valid path, "
f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
)
def _load(filepath: str):
if filepath.endswith(".safetensors"):
@@ -82,3 +86,11 @@ def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict:
except Exception as e:
raise RuntimeError(f"Error loading checkpoint from {path}: {e}")
return state_dict
def is_existing_local_path(path: str) -> bool:
try:
path_obj = Path(path)
return path_obj.exists() and (path_obj.is_file() or path_obj.is_dir())
except (ValueError, OSError):
return False
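
A small, hedged sketch of the new behavior: get_state_dict_from_repo_id_or_path now accepts either a checkpoint directory or a single .safetensors/.bin file, which is what the per-rank naive_distributed load path relies on. Only the shard creation below actually runs standalone; the commented calls assume the utilities from this diff are importable.

```python
import os
import tempfile

import torch
from safetensors.torch import save_file

tmpdir = tempfile.mkdtemp()
shard = os.path.join(tmpdir, "rank_00_model.safetensors")
save_file({"w": torch.zeros(2, 2)}, shard)  # fake per-rank shard

# Hypothetical usage:
# from arealite.utils.save_load import get_state_dict_from_repo_id_or_path, is_existing_local_path
# assert is_existing_local_path(shard)                  # a real file counts as a local path
# state = get_state_dict_from_repo_id_or_path(shard)    # single-file branch
# state = get_state_dict_from_repo_id_or_path(tmpdir)   # directory branch: loads every shard
```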

View File

@@ -86,6 +86,7 @@ dependencies = [
# Distributed computing
"ray",
"redis",
"deepspeed>=0.17.2",
# Web frameworks
"fastapi>=0.115.12",

View File

@@ -74,3 +74,4 @@ swanlab[dashboard]
torchdata
autoflake
tensordict
deepspeed>=0.17.2