mirror of https://github.com/inclusionAI/AReaL
[Feat][Refactor] Support DeepSpeed AutoTP; Refactor hf_engine.py and unit test. (#161)
* refactor hf engine
* format file
* revert file format
* Squashed commit of the following:

commit 8d4b8dc90f
Author: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date: Thu Jul 10 13:14:10 2025 +0800

    [Doc] Add an instruction about how to run the SFT example. (#164)

commit 3bf9c85e40
Author: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date: Thu Jul 10 12:56:24 2025 +0800

    [Fix] Merge previous contributions from fw/refactor to lite (#163)
    * initial proposal * add arealite * . * change api * . * remove LOG_ROOT * remove MODEL_SAVE_PATH * remove PARAM_REALLOC_PATH, DATASET_CACHE * prepare for testing * prepare for testing * ready for run * local run * tests mainly pass * format * . * amend cluster.py * . * . * client test pass * pass rollout test * remove unused imports * add arealite readme * change api * . * . * . * . * . * . * . * . * format * . * implement iteraptable generation (#112) Co-authored-by: zhaochenyang <zhaochenyang20@gmail.com> * . * fix * . * . * . * pass controller generate batch test * . * refactor rollout controller into worker and controller * . * . * . * change to async rollout * pass rollout controller test * pass test * . * update readme * . * sft debug * . * add lisence * remove unused files * remove unsed args in ppo * add hf engine wrapper (#116) * add hf engine * fix issues * fix ppo bugs and add test * add hf client interface and modify cli args * fix bugs * fix issues * Merge fw/refactor * Finish hf wrapper test * add test --------- Co-authored-by: Wei Fu <36355462+garrett4wade@users.noreply.github.com> * format * format * . * refine hf engine * . * fix * add fsdp engine and sft tests * . * . * . * pass ppo unittest * pass ppo and rollout controller tests * clear unused imports * rename ppo to grpo * change reward function organization * reorganize code * add dataset api * . * . * . * format * chmod fix * . * rename workflow to collector * refactor llm_client location * . * . * fix llm server api * refactor config structure * . * fix tests * . * . * . * Fix unresolved issue in SFTTrainer PR (#139) * . * . * efficient loading * format * . * . * . * . * . * . * Add CI for testing AReaLite (#150) * ci: add test-arealite * ci: add checkout before running test-arealite * ci: add USERNAME * ci: add test script * ci: add GitHub mirror * ci: fix typo * ci: clone one commit * ci: fix condition * ci: set command timeout to 60m * ci: enable pip cache * ci: optimize container lifecycle * ci: split into many stages * ci(test-arealite): fix typo * ci: fix wrong env * ci: fix pytest * ci: uninstall transformer-engine * ci: uninstall transformer-engine * ci: fix model paths * ci: show stdout/stderr * ci: fix not clean up * ci: backup sglang * ci: remove tmp repo dir when run * ci: fix docker run exit 1 condition * ci(test-arealite): limit the concurrency and extend command timeout * . * merge fw/refactor * revert some changes * fix ---------
    Co-authored-by: meizhiyu.mzy <meizhiyu.mzy@antgroup.com>
    Co-authored-by: Chayenne <zhaochen20@outlook.com>
    Co-authored-by: zhaochenyang <zhaochenyang20@gmail.com>
    Co-authored-by: Jayon02 <qiujiangc@outlook.com>
    Co-authored-by: root <meizhiyu.mzy>
    Co-authored-by: Zijian Zhang <futrime@outlook.com>

commit d48bf007cf
Merge: 42c717b b9dbd4a
Author: 博惟 <bowei.fw@antgroup.com>
Date: Thu Jul 10 12:53:30 2025 +0800

    Merge branch 'main' of https://github.com/inclusionAI/AReaL into lite

commit 42c717b6e4
Merge: c38cffc a203c7c
Author: 博惟 <bowei.fw@antgroup.com>
Date: Thu Jul 10 11:15:01 2025 +0800

    Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite

commit c38cffc023
Author: 博惟 <bowei.fw@antgroup.com>
Date: Thu Jul 10 11:10:10 2025 +0800

    PullRequest: 340 [lite] Refactor trainer API into utilities and remove mb_spec in engine methods
    Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite
    https://code.alipay.com/inclusionAI/AReaL/pull_requests/340
    Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
    * support fsdp engine and sglang remote engine * minor fix * . * refactor trainer * add close * rm mb_spec * fix

commit b9dbd4a2c1
Author: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date: Wed Jul 9 10:50:19 2025 +0800

    Update to persistent wechat QR code. (#159)

commit 17ea7fe94d
Author: xssstory <33601810+xssstory@users.noreply.github.com>
Date: Mon Jul 7 15:49:13 2025 +0800

    fix math reward verifier (#156)
    * PullRequest: 293 fix get_param_realloc_path Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/293 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * fix get_param_realloc_path * PullRequest: 297 bugfix: reward is always -5 Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/297 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * bugfix: reward is always -5 * PullRequest: 321 fix checkpoint save dir Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/321 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * fix checkpoint save dir * PullRequest: 328 [Doc] update installation Merge branch sxj/doc of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/328 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * [Doc] update installation * PullRequest: 329 bugfix: math verifier blocks the async training Merge branch xss/debug of git@code.alipay.com:inclusionAI/AReaL.git into gh https://code.alipay.com/inclusionAI/AReaL/pull_requests/329 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * bugfix: math verifier block the async training * format ---------
    Co-authored-by: 冰临 <shenxujie.sxj@antgroup.com>
    Co-authored-by: garrett4wade <fuwth17@gmail.com>

* add autotp for hf
* refactor test
* fix bugs
* fix issues
* format files
* Squashed commit of the following:

commit 9ed043f6ab
Author: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date: Tue Jul 15 10:24:48 2025 +0800

    format (#174)

commit 8cc9b1feb5
Author: Night <32424487+PrinsYin@users.noreply.github.com>
Date: Mon Jul 14 19:22:00 2025 -0700

    added LocalSGlangEngine and test (#170)
    * added LocalSGLangEngine * upload test file * add build args * fix sgl_local generate * improved sgl local robustness * test * test updated * added fallback when sgl engine isn't initialized * finish test local engine * added LocalSGlangEngine and test * format and fix format and fix, raise when generate missing field format * change cli_args.py * add comment header format ---------
    Co-authored-by: ChangyiYang <changyiyang2023@gmail.com>

---------
Co-authored-by: Jayon02 <12012211@mail..sustech.edu.cn>
Co-authored-by: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
parent 9ed043f6ab
commit ef4215d6f1
@@ -104,6 +104,14 @@ class FSDPEngineConfig:
     )


+@dataclass
+class HFEngineConfig:
+    autotp_size: Optional[int] = field(
+        default=1,
+        metadata={"help": "DeepSpeed AutoTP size"},
+    )
+
+
 @dataclass
 class TrainEngineConfig:
     experiment_name: str = MISSING
@@ -136,6 +144,7 @@ class TrainEngineConfig:
     )
     backend: str = ""
     fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
+    hf: HFEngineConfig = field(default_factory=HFEngineConfig)


 @dataclass
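For reference, a minimal sketch of how the new hf field might be wired up from Python, assuming the existing arealite.api.cli_args dataclasses; the experiment/trial names and the checkpoint path below are placeholders, not values from this diff:

from arealite.api.cli_args import HFEngineConfig, OptimizerConfig, TrainEngineConfig

# Hypothetical example: ask HFEngine for DeepSpeed AutoTP with a tensor-parallel size of 2.
config = TrainEngineConfig(
    experiment_name="autotp-demo",      # placeholder
    trial_name="trial0",                # placeholder
    path="/path/to/hf/checkpoint",      # placeholder
    optimizer=OptimizerConfig(),
    hf=HFEngineConfig(autotp_size=2),
)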
@@ -175,6 +175,7 @@ class SaveLoadMeta:
     with_optim: bool
     tokenizer: PreTrainedTokenizerFast | None
     base_model_path: str | None
+    naive_distributed: bool = False


 @dataclass
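A hedged sketch of what the new flag means at the engine API level: with naive_distributed=True, HFEngine gathers each rank's partition and writes per-rank rank_XX_model.safetensors files instead of a single save_pretrained() checkpoint (see the hf_engine.py hunks below). The path below is a placeholder:

from arealite.api.engine_api import SaveLoadMeta

meta = SaveLoadMeta(
    path="/tmp/ckpt",            # placeholder output directory
    weight_format="hf",
    with_optim=True,
    tokenizer=None,
    base_model_path=None,
    naive_distributed=True,      # field added in this commit
)
# engine.save(meta)  # `engine` would be an initialized HFEngine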
@@ -1,123 +1,131 @@
-import asyncio
-import functools
-import math
+import gc
 import os
+import time
 from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.distributed as dist
 import transformers
-from transformers import AutoConfig, AutoModelForCausalLM
-from arealite.api.cli_args import EngineConfig, ParallelismConfig, TrainingArgs
-from arealite.api.engine_api import TrainEngine
-from arealite.api.io_struct import FinetuneSpec
-from arealite.api.llm_client_api import LLMClient
-from arealite.utils import (
-    get_state_dict_from_repo_id_or_path,
-    recorder_list,
-    split_dict_tensor_with_cu_seqlens,
-    unpack_sequence,
+from safetensors.torch import save_file
+from tensordict import TensorDict
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    get_constant_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
 )
-from realhf.base import constants

+from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
+from arealite.api.engine_api import (
+    FinetuneSpec,
+    SaveLoadMeta,
+    TrainEngine,
+    WeightUpdateMeta,
+)
+from arealite.utils.data import (
+    MicroBatchList,
+    amend_position_ids,
+    pack_tensor_dict,
+    pad_and_stack_tensors_along_first_dim,
+    pad_mb_list,
+    reorder_list,
+    split_packed_tensor_dict_into_mb_list,
+    unpack_sequence,
+    unsqueeze_mb_list,
+)
+from arealite.utils.fsdp import get_cosine_schedule_with_warmup
+from arealite.utils.save_load import (
+    get_state_dict_from_repo_id_or_path,
+    is_existing_local_path,
+)
+from realhf.api.core.data_api import load_hf_tokenizer
+from realhf.base import logging, name_resolve, names

-def get_cosine_schedule_with_warmup(
-    optimizer: torch.optim.Optimizer,
-    num_warmup_steps: int,
-    num_training_steps: int,
-    min_lr_ratio: float = 0.0,
-    num_cycles: float = 0.5,
-    last_epoch: int = -1,
-):
-    """
-    Create a schedule with a learning rate that decreases following the values of the cosine function between the
-    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
-    initial lr set in the optimizer.
-    Args:
-        optimizer (:class:`~torch.optim.Optimizer`):
-            The optimizer for which to schedule the learning rate.
-        num_warmup_steps (:obj:`int`):
-            The number of steps for the warmup phase.
-        num_training_steps (:obj:`int`):
-            The total number of training steps.
-        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
-            The minimum lr ratio w.r.t the maximum.
-        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
-            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
-            following a half-cosine).
-        last_epoch (:obj:`int`, `optional`, defaults to -1):
-            The index of the last epoch when resuming training.
-    Return:
-        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
-    """
-    assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
-    coef = (1 - min_lr_ratio) * 0.5
-    intercept = (1 + min_lr_ratio) * 0.5
-
-    def lr_lambda(current_step):
-        if current_step < num_warmup_steps:
-            return float(current_step) / float(max(1, num_warmup_steps))
-        progress = float(current_step - num_warmup_steps) / float(
-            max(1, num_training_steps - num_warmup_steps)
-        )
-        x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
-        return max(0.0, x * coef + intercept)
-
-    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
+logger = logging.getLogger("HFEngine")


 class HFEngine(TrainEngine):
-    """Simplified HF engine for transformer models."""
-    def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
-        super().__init__(args, engine_config)
+    def __init__(self, config: TrainEngineConfig):
+        self.config = config
+        self.optimizer_config = config.optimizer

         self.model = None
         self.optimizer = None
+        self.tokenizer = None
+        # huggingface model config
         self.model_config = None
+        # initialization
+        self.initialized = False
         self.weight_update_group_initialized = False

-    def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
-        """Initialize model in single node."""
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        if dist.get_world_size() > 1:
-            raise RuntimeError(
-                "Distributed training is not supported in this engine. "
-                "Please use FSDP for distributed training."
-            )
-        torch.cuda.set_device("cuda:0")
+    def train(self, mode: bool = True):
+        assert self.model is not None
+        self.model.train(mode=mode)
+        return self

-        dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
+    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
+        """Initialize distributed communication and model."""
+        assert addr is None, "HFEngine does not support remote initialization."
+
+        world_size = int(os.environ.get("WORLD_SIZE", 0))
+        if not dist.is_initialized() and world_size > 1:
+            try:
+                import deepspeed
+            except ImportError:
+                print(
+                    "Warning: deepspeed is not installed. Some functionality may be disabled."
+                )
+            deepspeed.init_distributed(dist_backend="nccl", world_size=world_size)
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        torch.cuda.set_device(local_rank)
+        self.device = torch.device(f"cuda:{local_rank}")
+
+        dtype = torch.bfloat16 if self.config.bf16 else torch.float16
         self.model_config = AutoConfig.from_pretrained(
-            pretrained_model_name_or_path=self.engine_config.path,
+            pretrained_model_name_or_path=self.config.path,
             trust_remote_code=True,
         )
-        with torch.device("cuda"):
-            # initialize scratch model from config
-            model = AutoModelForCausalLM.from_config(
-                self.model_config,
-                torch_dtype=dtype,
-                attn_implementation="flash_attention_2",
-            )
-
-        model = model.cuda()
-
-        self.model = model
+        self.tokenizer = load_hf_tokenizer(self.config.path)
+        self.model = AutoModelForCausalLM.from_config(
+            self.model_config,
+            torch_dtype=dtype,
+            attn_implementation=self.config.attn_impl,
+        ).to(f"cuda:{local_rank}")
+
+        if not self.config.init_from_scratch:
+            # Load model from a initial checkpoint path,
+            # which should only be a huggingface checkpoint.
+            load_meta = SaveLoadMeta(
+                path=self.config.path,
+                weight_format="hf",
+                with_optim=False,
+                tokenizer=None,
+                base_model_path=self.config.path,
+                naive_distributed=False,
+            )
+            self.load(load_meta)
+
+        if world_size > 1:
+            if self._check_autotp():
+                self.model = deepspeed.tp_model_init(
+                    self.model, tp_size=self.config.hf.autotp_size, dtype=dtype
+                )
+            else:
+                raise RuntimeError("DeepSpeed AutoTP configuration error in HFEngine. ")

         # Set up optimizer
-        optimizer_config = self.engine_config.optimizer
-        if optimizer_config is not None:
+        if self.optimizer_config is not None:
             assert (
-                optimizer_config.type == "adam"
+                self.optimizer_config.type == "adam"
             ), "Only AdamW optimizer is supported in this engine."
-            lr = optimizer_config.lr
-            weight_decay = optimizer_config.weight_decay
-            beta1 = optimizer_config.beta1
-            beta2 = optimizer_config.beta2
-            eps = optimizer_config.eps
+            lr = self.optimizer_config.lr
+            weight_decay = self.optimizer_config.weight_decay
+            beta1 = self.optimizer_config.beta1
+            beta2 = self.optimizer_config.beta2
+            eps = self.optimizer_config.eps

             self.optimizer = torch.optim.AdamW(
                 self.model.parameters(),
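The cosine-with-warmup helper deleted from hf_engine.py above is re-imported from arealite.utils.fsdp instead. As a reference for the multiplier it applies to the base learning rate, here is a self-contained sketch of the same math as the deleted body (names local to this sketch):

import math

def cosine_multiplier(step, num_warmup_steps, num_training_steps,
                      min_lr_ratio=0.0, num_cycles=0.5):
    # Linear warmup, then a half-cosine decay from 1.0 down to min_lr_ratio.
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    coef = (1 - min_lr_ratio) * 0.5
    intercept = (1 + min_lr_ratio) * 0.5
    return max(0.0, math.cos(math.pi * num_cycles * 2.0 * progress) * coef + intercept)

# e.g. with min_lr_ratio=0.1: multiplier is 1.0 right after warmup and 0.1 at the last step.
assert abs(cosine_multiplier(100, 100, 1000, min_lr_ratio=0.1) - 1.0) < 1e-6
assert abs(cosine_multiplier(1000, 100, 1000, min_lr_ratio=0.1) - 0.1) < 1e-6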
@@ -128,80 +136,293 @@ class HFEngine(TrainEngine):
             )
             total_train_steps = ft_spec.total_train_steps
             num_warmup_steps = int(
-                optimizer_config.warmup_steps_proportion * total_train_steps
+                self.optimizer_config.warmup_steps_proportion * total_train_steps
             )

-            self.lr_scheduler = get_cosine_schedule_with_warmup(
-                self.optimizer,
-                num_warmup_steps,
-                total_train_steps,
-                min_lr_ratio=optimizer_config.min_lr_ratio,
+            if self.optimizer_config.lr_scheduler_type == "cosine":
+                self.lr_scheduler = get_cosine_schedule_with_warmup(
+                    self.optimizer,
+                    num_warmup_steps,
+                    total_train_steps,
+                    min_lr_ratio=self.optimizer_config.min_lr_ratio,
+                )
+            elif self.optimizer_config.lr_scheduler_type == "linear":
+                self.lr_scheduler = get_linear_schedule_with_warmup(
+                    self.optimizer,
+                    num_warmup_steps,
+                    total_train_steps,
+                )
+            elif self.optimizer_config.lr_scheduler_type == "constant":
+                self.lr_scheduler = get_constant_schedule_with_warmup(
+                    self.optimizer,
+                    num_warmup_steps,
+                )
+            else:
+                raise ValueError(
+                    f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
+                )
+
+        self.initialized = True
+
+    def _check_autotp(self):
+        tp_size = self.config.hf.autotp_size
+        config = self.model_config
+        num_attention_heads = config.num_attention_heads
+        num_key_value_heads = config.num_key_value_heads
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+
+        return (
+            num_attention_heads % tp_size == 0
+            and num_key_value_heads % tp_size == 0
+            and hidden_size % tp_size == 0
+            and intermediate_size % tp_size == 0
+        )
+
+    def destroy(self):
+        """Destroy the engine and release GPU memory."""
+        self.model = None
+        self.optimizer = None
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.initialized = False
+
+    def save(self, meta: SaveLoadMeta):
+        if meta.weight_format == "hf":
+            self._save_model_to_hf(meta.path, meta.tokenizer, meta.naive_distributed)
+        elif meta.weight_format == "dcp":
+            # TODO: implement DCP save/load for HF
+            raise NotImplementedError("DCP format saving is not implemented yet. ")
+        else:
+            raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+
+        if meta.with_optim:
+            self._save_optimizer_state(meta.path)
+
+    def load(self, meta: SaveLoadMeta):
+        if meta.weight_format == "hf":
+            self._load_model_from_hf(meta.path, meta.naive_distributed)
+        elif meta.weight_format == "dcp":
+            # TODO: implement DCP save/load for HF
+            raise NotImplementedError("DCP format loading is not implemented yet. ")
+        else:
+            raise ValueError(f"Unknown weight format {meta.weight_format}. ")
+
+        if meta.with_optim:
+            self._load_optimizer_state(meta.path)
+
+    def _save_optimizer_state(self, path: str):
+        assert self.optimizer is not None
+        os.makedirs(path, exist_ok=True)
+        torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt"))
+
+    def _load_optimizer_state(self, path: str):
+        assert self.optimizer is not None
+        path = os.path.join(path, "optim.pt")
+        optimizer_state_dict = torch.load(path, weights_only=False)
+        self.optimizer.load_state_dict(optimizer_state_dict)
+
+    def _save_model_to_hf(
+        self,
+        path: str,
+        tokenizer: Optional[transformers.PreTrainedTokenizerFast],
+        naive_distributed: bool,
+    ):
+        """Save model in HuggingFace format."""
+        if self.model is None:
+            raise RuntimeError("Model not initialized")
+
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        if rank == 0:
+            os.makedirs(path, exist_ok=True)
+            self.model_config.save_pretrained(path)
+            if tokenizer is not None:
+                tokenizer.save_pretrained(path)
+
+        if world_size > 1:
+            dist.barrier()
+
+        state_dict = self.model.state_dict()
+
+        if hasattr(self.model, "module"):
+            state_dict = {
+                k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
+                for k, v in state_dict.items()
+            }
+        else:
+            state_dict = {k: v.cpu() for k, v in state_dict.items()}
+
+        if world_size > 1 and naive_distributed:
+            # Only support store parameters from model partitions respectively
+            gathered_state_dicts = None
+            if rank == 0:
+                gathered_state_dicts = [None for _ in range(world_size)]
+
+            dist.gather_object(
+                obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
             )

-    def train(self, mode: bool = True):
-        """Set the module in training mode."""
-        return self.model.train(mode)
+            if rank == 0:
+                for i, state_dict in enumerate(gathered_state_dicts):
+                    save_file(state_dict, f"{path}/rank_{i:02d}_model.safetensors")
+        else:
+            self.model.save_pretrained(path, state_dict=state_dict)
+
+        if world_size > 1:
+            dist.barrier()
+
+    def _load_model_from_hf(self, path: str, naive_distributed: bool):
+        """Load model from HuggingFace format."""
+
+        rank = dist.get_rank()
+        # Only support load full model parameters from huggingface
+        # and load model partition locally
+        if rank == 0 or is_existing_local_path(path):
+            if naive_distributed:
+                path = f"{path}/rank_{rank:02d}_model.safetensors"
+            full_state = get_state_dict_from_repo_id_or_path(path)
+
+            if hasattr(self.model, "module") and not hasattr(full_state):
+                full_state = {
+                    f"module.{k}" if not k.startswith("module.") else k: v
+                    for k, v in full_state.items()
+                }
+            self.model.load_state_dict(
+                full_state, strict=not self.model_config.tie_word_embeddings
+            )
+
+        if self.model_config.tie_word_embeddings:
+            self.model.tie_weights()
+
+    def upload_weights(self, meta: WeightUpdateMeta):
+        if meta.type == "nccl":
+            if not self.weight_update_group_initialized:
+                self._init_distributed_weight_update(meta)
+            self._update_weights_from_distributed()
+        elif meta.type == "disk":
+            self._save_model_to_hf(meta.path, self.tokenizer, meta.naive_distributed)
+            update_name = names.update_weights_from_disk(
+                self.config.experiment_name,
+                self.config.trial_name,
+                meta.model_version,
+            )
+            name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120)
+        else:
+            raise ValueError(f"Unknown weight update type {meta.type}")
+
+    def _init_distributed_weight_update(self, meta: WeightUpdateMeta):
+        raise NotImplementedError(
+            "Distributed weight update is not implemented for HFEngine yet. "
+        )
+
+    def _update_weights_from_distributed(self):
+        raise NotImplementedError(
+            "Distributed weight update is not implemented for HFEngine yet. "
+        )
+
+    def step_lr_scheduler(self):
+        assert self.lr_scheduler is not None
+        return self.lr_scheduler.step()
+
+    def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
+        assert "attention_mask" in input_ and "input_ids" in input_
+        if isinstance(input_, dict):
+            input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
+        input_ = amend_position_ids(input_)
+        packed_input = pack_tensor_dict(input_)
+        mb_list = split_packed_tensor_dict_into_mb_list(
+            packed_input,
+            self.config.mb_spec,
+        )
+        mb_list = pad_mb_list(mb_list, pad_value=0.0)
+        # NOTE: We unsqueeze here because huggingface transformer models requires
+        # packed input to be of shape [1, total_seqlen].
+        mb_list = unsqueeze_mb_list(mb_list)
+        return mb_list

     def train_batch(
         self,
-        input_: Dict,
+        input_: TensorDict,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
-    ) -> Dict:
+    ) -> Dict[str, float]:
         """Train on a batch using gradient accumulation."""
+        input_ = input_.to(self.device)
         assert self.optimizer is not None
+        assert self.optimizer_config is not None
         assert self.lr_scheduler is not None

         self.optimizer.zero_grad()
-        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
+        mb_list = self._prepare_mb_list(input_)

         total_loss_weight = torch.tensor(
-            sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
+            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
         )
         assert total_loss_weight != 0

-        for mb_input in mb_splits.mbs:
-            outputs = self.model(**mb_input)
-            loss = loss_fn(outputs.logits, mb_input)
+        # Process microbatches with gradient accumulation
+        for pad_length, padded_mb_input, mb_input in zip(
+            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        ):
+            outputs = self.model(**padded_mb_input)
+
+            logits = outputs.logits.squeeze(0)
+            logits = logits[:-pad_length] if pad_length > 0 else logits
+            loss = loss_fn(logits, mb_input)
             loss_scale = loss_weight_fn(mb_input) / total_loss_weight

             loss *= loss_scale
             loss.backward()

         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(),
-            self.engine_config.optimizer.gradient_clipping,
+            self.optimizer_config.gradient_clipping,
             norm_type=2.0,
             error_if_nonfinite=False,
             foreach=None,
         )
+        if not torch.isfinite(grad_norm):
+            self.optimizer.zero_grad()
+            update_successful = False
+        else:
+            self.optimizer.step()
+            update_successful = True
+
         current_lr = self.lr_scheduler.get_last_lr()[0]
         # Optimizer step
         self.optimizer.step()

-        return {
-            "grad_norm": grad_norm,
-            "lr": current_lr,
-        }
+        return dict(
+            update_successful=float(update_successful),
+            grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
+            lr=current_lr,
+        )

     @torch.no_grad()
     def eval_batch(
         self,
-        input_: Dict,
+        input_: TensorDict,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:
         """Evaluate on a batch."""
-        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
+        mb_list = self._prepare_mb_list(input_)
         total_loss_weight = torch.tensor(
-            sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32
+            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
         )
         assert total_loss_weight != 0

         total_loss = 0.0
         total_weight = 0.0

-        for mb_input in mb_splits.mbs:
-            outputs = self.model(**mb_input)
-            loss = loss_fn(outputs.logits, mb_input)
+        for pad_length, padded_mb_input, mb_input in zip(
+            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        ):
+            outputs = self.model(**padded_mb_input)
+            logits = outputs.logits.squeeze(0)
+            logits = logits[:-pad_length] if pad_length > 0 else logits
+            loss = loss_fn(logits, mb_input)

             # Simple weight calculation (could be improved)
             loss_scale = loss_weight_fn(mb_input) / total_loss_weight
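A hedged usage sketch of the refactored train_batch() path above: a padded batch enters as a TensorDict, _prepare_mb_list packs and splits it into microbatches, and loss_weight_fn decides each microbatch's contribution to the weighted loss. The tensors, loss functions, and shapes below are illustrative, not part of this diff:

import torch
from tensordict import TensorDict

batch = TensorDict(
    {
        "input_ids": torch.randint(0, 1000, (2, 16)),
        "attention_mask": torch.ones(2, 16, dtype=torch.bool),
    },
    batch_size=[2],
)

def loss_fn(logits, mb):
    # Stand-in loss; a real SFT loss would use labels carried in `mb`.
    return logits.float().mean()

def loss_weight_fn(mb):
    # Weight each microbatch by its token count so gradient accumulation stays unbiased.
    return float(mb["input_ids"].numel())

# stats = engine.train_batch(batch, loss_fn, loss_weight_fn)
# stats -> {"update_successful": 1.0, "grad_norm": ..., "lr": ...}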
@@ -213,95 +434,34 @@ class HFEngine(TrainEngine):
     @torch.no_grad()
     def forward(
         self,
-        input_: Dict,
+        input_: TensorDict,
         output_seqlens: List[int] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
-        aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
+        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
     ) -> Any | None:
         """Forward pass with optional post-processing."""
-        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
+        cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
+        mb_list = self._prepare_mb_list(input_)
+
         if output_seqlens is None:
-            cu_seqlens = input_["cu_seqlens"]
             output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()

         results = []
-        for mb_input in mb_splits.mbs:
-            outputs = self.model(**mb_input)
+        for pad_length, padded_mb_input, mb_input in zip(
+            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        ):
+            outputs = self.model(**padded_mb_input)
+            logits = outputs.logits.squeeze(0)
+            logits = logits[:-pad_length] if pad_length > 0 else logits

             if post_hook:
-                result = post_hook(outputs.logits, mb_input)
+                result = post_hook(logits, mb_input)
                 results.append(result)
             else:
-                results.append(outputs.logits)
+                results.append(logits)

         res = aggregate_fn(results)
-        output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
-        unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
-        return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))
-
-    def step_lr_scheduler(self):
-        """Step the learning rate scheduler."""
-        return self.lr_scheduler.step()
-
-    def save_model_to_hf(
-        self,
-        path: str,
-        tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
-        base_model_path: Optional[str] = None,
-    ):
-        """Save model in HuggingFace format."""
-        if self.model is None:
-            raise RuntimeError("Model not initialized")
-
-        os.makedirs(path, exist_ok=True)
-
-        state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
-        self.model.save_pretrained(path, state_dict=state_dict)
-        self.model_config.save_pretrained(path)
-        if tokenizer is not None:
-            tokenizer.save_pretrained(path)
-
-    def load_model_from_hf(self, path: str):
-        """Load model from HuggingFace format."""
-        full_state = get_state_dict_from_repo_id_or_path(path)
-        self.model.load_state_dict(
-            full_state, strict=not self.model_config.tie_word_embeddings
-        )
-        if self.model_config.tie_word_embeddings:
-            self.model.tie_weights()
-
-    def save_optimizer_state(self, path: str):
-        """Save optimizer state."""
-        if self.optimizer is None:
-            raise RuntimeError("Optimizer not initialized")
-
-        os.makedirs(path, exist_ok=True)
-        torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))
-
-    def load_optimizer_state(self, path: str):
-        """Load optimizer state."""
-        if self.optimizer is None:
-            raise RuntimeError("Optimizer not initialized")
-
-        optimizer_path = os.path.join(path, "optimizer.pt")
-        if os.path.exists(optimizer_path):
-            self.optimizer.load_state_dict(
-                torch.load(optimizer_path, map_location="cpu")
-            )
-        else:
-            raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")
-
-    async def aupdate_weights_to(self, llm_client: LLMClient):
-        path = constants.get_param_realloc_path(self.args)
-        self.save_model_to_hf(path)
-        tasks = [
-            llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
-            for server_info in llm_client.get_healthy_servers()
-        ]
-        await asyncio.gather(*tasks)
-
-    def update_weights_to(self, llm_client: LLMClient):
-        loop = asyncio.new_event_loop()
-        try:
-            loop.run_until_complete(self.aupdate_weights_to(llm_client))
-        finally:
-            loop.close()
+        output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
+        unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
+        reordered = reorder_list(unpacked, mb_list.backward_indices)
+        return pad_and_stack_tensors_along_first_dim(reordered)
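Similarly, a small hedged sketch of the new forward() contract: post_hook now receives the squeezed, unpadded logits of each microbatch, and per-sequence results are reordered and stacked along the first dimension. The hook below is hypothetical:

def greedy_tokens(logits, mb):
    # logits: [seqlen_of_microbatch, vocab_size] after the squeeze/unpad step in forward()
    return logits.argmax(dim=-1)

# preds = engine.forward(batch, post_hook=greedy_tokens)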
@@ -1,7 +1,7 @@
 # Copyright 2025 Ant Group Inc.
 # Licensed under the Apache License, Version 2.0

-"""Test script for HF Engine implementation."""
+"""Test script for Engine implementation."""

 import os
 from typing import Dict
@@ -52,29 +52,43 @@ def mock_input(
     )


+def get_engine(engine_type: str, model_path: str):
+    from arealite.engine.fsdp_engine import FSDPEngine
+    from arealite.engine.hf_engine import HFEngine
+
+    engine_cls = {"hf": HFEngine, "fsdp": FSDPEngine}[engine_type]
+
+    engine_config = TrainEngineConfig(
+        experiment_name=f"test-{engine_type}-engine",
+        trial_name="test0",
+        path=model_path,
+        optimizer=OptimizerConfig(),
+    )
+    engine = engine_cls(engine_config)
+    ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
+    engine.initialize(None, ft_spec)
+    return engine
+
+
 def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
     """Mock loss function for testing."""
     return torch.mean(logits)


-@pytest.fixture(scope="module")
-def engine():
-    os.environ["WORLD_SIZE"] = "1"
-    os.environ["RANK"] = "0"
-    os.environ["LOCAL_RANK"] = "0"
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "7777"
-    engine_config = TrainEngineConfig(
-        experiment_name="test-fsdp-engine",
-        trial_name="test0",
-        path=MODEL_PATH,
-        optimizer=OptimizerConfig(),
+@pytest.fixture(scope="module", params=["fsdp", "hf"])
+def engine(request):
+    os.environ.update(
+        {
+            "WORLD_SIZE": "1",
+            "RANK": "0",
+            "LOCAL_RANK": "0",
+            "MASTER_ADDR": "localhost",
+            "MASTER_PORT": "7777",
+        }
     )
-    engine = FSDPEngine(engine_config)
-    ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
-    engine.initialize(None, ft_spec)
-    print("✓ Engine created successfully")
+    engine = get_engine(request.param, MODEL_PATH)
+    print(f"✓ {request.param.upper()} Engine created successfully")
     yield engine
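With params=["fsdp", "hf"], every test in this module that takes the engine fixture now runs once per backend. A hypothetical test, not part of this diff and assuming mock_input's defaults yield a small batch, could look like:

def test_train_batch_smoke(engine):
    batch = mock_input()  # helper defined earlier in this test module
    stats = engine.train_batch(batch, mock_loss_fn, lambda mb: 1.0)
    assert "grad_norm" in stats and "lr" in stats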
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 from typing import Dict

 import torch
@@ -41,18 +42,21 @@ def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict:
     else:
         # Assume it's a local path
         local_path = repo_id_or_path
-        if not os.path.isdir(local_path):
-            raise ValueError(
-                f"Local path {local_path} does not exist or is not a directory, "
-                f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
-            )

     # Step 3: Load all .safetensors and .bin files
     file_paths_to_load = []
-    for filename in os.listdir(local_path):
-        filepath = os.path.join(local_path, filename)
-        if filename.endswith(".safetensors") or filename.endswith(".bin"):
-            file_paths_to_load.append(filepath)
+    if os.path.isdir(local_path):
+        for filename in os.listdir(local_path):
+            filepath = os.path.join(local_path, filename)
+            if filename.endswith(".safetensors") or filename.endswith(".bin"):
+                file_paths_to_load.append(filepath)
+    elif os.path.isfile(local_path):
+        file_paths_to_load.append(local_path)
+    else:
+        raise ValueError(
+            f"Local path {local_path} does not exist or is not a valid path, "
+            f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
+        )

     def _load(filepath: str):
         if filepath.endswith(".safetensors"):
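A hedged sketch of the widened loader behavior above: get_state_dict_from_repo_id_or_path now accepts either a checkpoint directory or a single weight file, which is what the per-rank rank_XX_model.safetensors loading in HFEngine relies on. The paths below are placeholders:

from arealite.utils.save_load import get_state_dict_from_repo_id_or_path

full_state = get_state_dict_from_repo_id_or_path("/ckpts/my_model")  # directory of *.safetensors / *.bin
one_shard = get_state_dict_from_repo_id_or_path("/ckpts/my_model/rank_00_model.safetensors")  # single file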
@@ -82,3 +86,11 @@ def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict:
         except Exception as e:
             raise RuntimeError(f"Error loading checkpoint from {path}: {e}")
     return state_dict
+
+
+def is_existing_local_path(path: str) -> bool:
+    try:
+        path_obj = Path(path)
+        return path_obj.exists() and (path_obj.is_file() or path_obj.is_dir())
+    except (ValueError, OSError):
+        return False
@@ -86,6 +86,7 @@ dependencies = [
     # Distributed computing
     "ray",
    "redis",
+    "deepspeed>=0.17.2",

     # Web frameworks
     "fastapi>=0.115.12",
@@ -74,3 +74,4 @@ swanlab[dashboard]
 torchdata
 autoflake
 tensordict
+deepspeed>=0.17.2