mirror of https://github.com/inclusionAI/AReaL
PullRequest: 340 [lite] Refactor trainer API into utilities and remove mb_spec in engine methods
Merge branch fw/lite-dev of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/340
Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* support fsdp engine and sglang remote engine
* minor fix
* .
* refactor trainer
* add close
* rm mb_spec
* fix
This commit is contained in: parent 7be4ab0d18, commit c38cffc023
@@ -187,7 +186,6 @@ class TrainEngine(abc.ABC):
    def train_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict[str, float]:

@@ -197,7 +196,6 @@ class TrainEngine(abc.ABC):
    def eval_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:

@@ -207,7 +205,6 @@ class TrainEngine(abc.ABC):
    def forward(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        output_seqlens: List[List[int]] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,

@@ -323,7 +320,7 @@ Extended engines (such as Actor in PPO) provide convenient organization and call
class Actor(Engine):

    @torch.no_grad()
    def compute_logps(self, input_: Dict[str, Tensor], mb_spec: MicroBatchSpec) -> torch.Tensor:
    def compute_logps(self, input_: Dict[str, Tensor]) -> torch.Tensor:
        ... # unpad
        logps = self.forward(xxx)
        ... # pad back

@@ -332,8 +329,7 @@ class Actor(Engine):
    def compute_advantages_and_returns(self, input_: Dict) -> Dict:
        pass

    def ppo_update(self, input_: Dict,
                   mb_spec: MicroBatchSpec) -> List[Dict[str, float]]:
    def ppo_update(self, input_: Dict) -> List[Dict[str, float]]:
        ...
        all_stats = []
        for _ in range(self.ppo_n_minibatches):

@@ -344,11 +340,10 @@ class Actor(Engine):
class Critic(Engine):

    @torch.no_grad()
    def compute_values(self, input_: Dict, mb_spec: MicroBatchSpec) -> torch.Tensor:
    def compute_values(self, input_: Dict) -> torch.Tensor:
        pass

    def ppo_update(self, input_: Dict,
                   mb_spec: MicroBatchSpec) -> List[Dict[str, float]]:
    def ppo_update(self, input_: Dict) -> List[Dict[str, float]]:
        ...
        all_stats = []
        for _ in range(self.ppo_n_minibatches):
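Note (not part of the diff): the hunks above drop mb_spec from every abstract TrainEngine method; micro-batching is now configured once on the engine instead of per call. A minimal sketch of the resulting calling convention, assuming only classes shown elsewhere in this diff; the checkpoint path, batch, and loss functions below are stand-ins:

import torch

from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
from arealite.engine.fsdp_engine import FSDPEngine

def my_loss_fn(logits: torch.Tensor, batch: dict) -> torch.Tensor:
    # stand-in loss, only to keep the sketch self-contained
    return logits.float().mean()

def my_loss_weight_fn(batch: dict) -> float:
    # stand-in weight: number of attended tokens
    return float(batch["attention_mask"].sum())

batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.bool),
}

config = TrainEngineConfig(
    experiment_name="demo",
    trial_name="run0",
    path="/path/to/hf/checkpoint",                   # hypothetical checkpoint path
    mb_spec=MicroBatchSpec(max_tokens_per_mb=4096),  # set once, read by the engine
)
engine = FSDPEngine(config=config)
# engine.initialize(None, ft_spec) must still be called before training (elided here).

# Before this PR: engine.train_batch(input_=batch, mb_spec=MicroBatchSpec(...), ...)
# After this PR: the engine reads self.config.mb_spec internally.
stats = engine.train_batch(
    input_=batch,
    loss_fn=my_loss_fn,
    loss_weight_fn=my_loss_weight_fn,
)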
@@ -1,5 +1,4 @@
import argparse
import getpass
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path

@@ -9,6 +8,7 @@ from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from omegaconf import MISSING, OmegaConf

from arealite.utils.fs import get_user_tmp
from realhf.api.cli_args import OptimizerConfig

@@ -103,8 +103,8 @@ class FSDPEngineConfig:

@dataclass
class TrainEngineConfig:
    experiment_name: str
    trial_name: str
    experiment_name: str = MISSING
    trial_name: str = MISSING
    path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
    attn_impl: str = field(
        default="flash_attention_2",

@@ -120,6 +120,8 @@ class TrainEngineConfig:
        default=False,
        metadata={"help": "Initialize critic/reward model from LM checkpoint"},
    )
    # Runtime microbatch limit
    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)

    # Training Backend Configuration
    gradient_checkpointing: bool = field(
@@ -307,92 +309,38 @@ class SGLangEngineConfig:


@dataclass
class ExperimentSaveEvalControl:
    """Controls the frequency of model saving and evaluation during training.
class _Timer:
    experiment_name: str = MISSING
    trial_name: str = MISSING
    fileroot: str = MISSING
    freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Trigger frequency in epochs. None disables epoch-based saving."
        },
    )
    freq_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Trigger frequency in steps. None disables step-based saving."
        },
    )
    freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Trigger frequency in seconds. None disables time-based saving."
        },
    )

    Manages independent counters for epochs, steps, and seconds. The model will be saved
    or evaluated when any specified frequency condition is met.

    Note:
        - Epoch: Number of full passes through the training dataset
        - Step: Number of individual training iterations
        - Seconds: Wall-clock time duration
    """
@dataclass
class EvaluatorConfig(_Timer):
    pass

    total_train_epochs: int = field(
        default=1, metadata={"help": "Total number of epochs to train the model."}
    )
    # Save control
    save_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Save frequency in epochs. None disables epoch-based saving."
        },
    )
    save_freq_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Save frequency in steps. None disables step-based saving."},
    )
    save_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Save frequency in seconds. None disables time-based saving."
        },
    )
    # Checkpointing control
    ckpt_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in epochs. None uses save_freq_epochs. "
            "Checkpointing is used for recover. Previous checkpoint is overwritten to save space."
        },
    )
    ckpt_freq_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in steps. None disables step-based checkpointing."
        },
    )
    ckpt_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Checkpoint frequency in seconds. None disables time-based checkpointing."
        },
    )
    # Evaluation control
    eval_freq_epochs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in epochs. None disables epoch-based evaluation."
        },
    )
    eval_freq_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in steps. None disables step-based evaluation."
        },
    )
    eval_freq_secs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Evaluation frequency in seconds. None disables time-based evaluation."
        },
    )
    # Benchmark control
    benchmark_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after this number of steps. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
    benchmark_n_seqs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after consuming this number of samples. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )

@dataclass
class SaverConfig(_Timer):
    pass


@dataclass
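Note (not part of the diff): the interleaved hunk above folds the old ExperimentSaveEvalControl into a shared _Timer base carrying the experiment/trial/fileroot identity plus three trigger frequencies, with SaverConfig and EvaluatorConfig as thin subclasses. A condensed sketch of the resulting shape, using a stand-in for omegaconf.MISSING:

from dataclasses import dataclass
from typing import Optional

MISSING = "???"  # stand-in for omegaconf.MISSING

@dataclass
class _Timer:
    experiment_name: str = MISSING
    trial_name: str = MISSING
    fileroot: str = MISSING
    freq_epochs: Optional[int] = None  # trigger every N epochs; None disables
    freq_steps: Optional[int] = None   # trigger every N steps; None disables
    freq_secs: Optional[int] = None    # trigger every N seconds; None disables

@dataclass
class SaverConfig(_Timer):
    pass

@dataclass
class EvaluatorConfig(_Timer):
    pass

# One frequency controller (realhf.base.timeutil.EpochStepTimeFreqCtl in this diff)
# consumes the same three fields for saving, checkpointing, and evaluation alike,
# replacing the separate save_/ckpt_/eval_freq_* fields removed above.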
@@ -423,11 +371,23 @@ class TensorBoardConfig:
    path: Optional[str] = None


def get_user_tmp():
    user = getpass.getuser()
    user_tmp = os.path.join("/home", user, ".cache", "realhf")
    os.makedirs(user_tmp, exist_ok=True)
    return user_tmp
@dataclass
class StatsLoggerConfig:
    experiment_name: str = MISSING
    trial_name: str = MISSING
    fileroot: str = MISSING
    wandb: WandBConfig = field(
        default_factory=WandBConfig,
        metadata={"help": "Weights & Biases configuration."},
    )
    swanlab: SwanlabConfig = field(
        default_factory=SwanlabConfig,
        metadata={"help": "SwanLab configuration."},
    )
    tensorboard: TensorBoardConfig = field(
        default_factory=TensorBoardConfig,
        metadata={"help": "TensorBoard configuration. Only 'path' field required."},
    )


@dataclass
@@ -498,7 +458,9 @@ class ClusterSpecConfig:

@dataclass
class DatasetConfig:
    type: str = field(default="", metadata={"help": "Type of implemented dataset"})
    type: Optional[str] = field(
        default=None, metadata={"help": "Type of implemented dataset"}
    )
    batch_size: int = field(
        default=1, metadata={"help": "Batch size of the dataloader"}
    )
@@ -517,51 +479,6 @@ class DatasetConfig:
    drop_last: bool = field(default=True)


@dataclass
class TrainerConfig:
    experiment_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the experiment (no '_' or '/'). Required."},
    )
    trial_name: str = field(
        default=MISSING,
        metadata={"help": "Name of the trial (no '-' or '/'). Required."},
    )
    fileroot: str = field(
        default=get_user_tmp(),
        metadata={
            "help": "Root for logs and checkpoints. Should be available to all nodes."
        },
    )
    wandb: WandBConfig = field(
        default_factory=WandBConfig,
        metadata={"help": "Weights & Biases configuration."},
    )
    swanlab: SwanlabConfig = field(
        default_factory=SwanlabConfig,
        metadata={"help": "SwanLab configuration."},
    )
    tensorboard: TensorBoardConfig = field(
        default_factory=TensorBoardConfig,
        metadata={"help": "TensorBoard configuration. Only 'path' field required."},
    )
    allocation_mode: str = field(
        default="",
        metadata={
            "help": "GPU parallel strategy allocation mode. "
            "Options: manual/heuristic or pattern-based."
        },
    )
    seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
    exp_ctrl: ExperimentSaveEvalControl = field(
        default_factory=ExperimentSaveEvalControl,
        metadata={"help": "Experiment save/evaluation control configuration."},
    )

    tokenizer_path: str = field(default="")
    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)


@dataclass
class BaseExperimentConfig:
    # NOTE: we need this unified config class because different experiments
@@ -585,14 +502,45 @@ class BaseExperimentConfig:
    n_gpus_per_node: int = field(
        default=8, metadata={"help": "Number of GPUs per node for this experiment."}
    )
    allocation_mode: str = field(
        default="",
        metadata={
            "help": "GPU parallel strategy allocation mode. "
            "Options: manual/heuristic or pattern-based."
        },
    )
    seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."})
    total_train_epochs: int = field(
        default=1, metadata={"help": "Total number of epochs to train the model."}
    )
    total_train_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after this number of steps. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
    total_train_n_seqs: Optional[int] = field(
        default=None,
        metadata={
            "help": "Terminate training after consuming this number of samples. "
            "For benchmarking purposes only. None indicates normal training."
        },
    )
    tokenizer_path: str = field(default="")

    train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
    valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)

    saver: SaverConfig = field(default_factory=SaverConfig)
    checkpointer: SaverConfig = field(default_factory=SaverConfig)
    evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)
    stats_logger: StatsLoggerConfig = field(default_factory=StatsLoggerConfig)


@dataclass
class SFTConfig(BaseExperimentConfig):
    model: TrainEngineConfig = field(default_factory=TrainEngineConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)


def load_expr_config(argv: List[str], config_cls) -> Tuple[BaseExperimentConfig, str]:
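Note (not part of the diff): with TrainerConfig removed, experiment scripts read everything from BaseExperimentConfig/SFTConfig directly. A hedged sketch of how the flattened layout is consumed, using only names defined in this diff:

import sys

from arealite.api.cli_args import SFTConfig, load_expr_config

def main(argv):
    # second return value is a str, per the signature shown above
    config, _ = load_expr_config(argv, SFTConfig)
    # Fields that used to hang off config.trainer now live at the top level
    # or on dedicated utility configs:
    print(config.total_train_epochs)  # was config.trainer.exp_ctrl.total_train_epochs
    print(config.tokenizer_path)      # was config.trainer.tokenizer_path
    print(config.saver.freq_epochs)   # was config.trainer.exp_ctrl.save_freq_epochs
    print(config.model.mb_spec)       # was config.trainer.mb_spec

if __name__ == "__main__":
    main(sys.argv[1:])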
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from tensordict import TensorDict

from arealite.api.cli_args import MicroBatchSpec
from arealite.api.io_struct import (
    FinetuneSpec,
    LLMRequest,

@@ -79,7 +78,6 @@ class TrainEngine(abc.ABC):
    def train_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict[str, float]:

@@ -90,7+88,6 @@ class TrainEngine(abc.ABC):
    def eval_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:

@@ -101,7 +98,6 @@ class TrainEngine(abc.ABC):
    def forward(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        output_seqlens: List[List[int]] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
@@ -1,130 +0,0 @@
import getpass
import os
from typing import Dict

import torch.distributed as dist
import wandb
from tensorboardX import SummaryWriter
from torchdata.stateful_dataloader import StatefulDataLoader

from arealite.api.cli_args import TrainerConfig
from arealite.api.engine_api import InferenceEngine, TrainEngine
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, timeutil


class Trainer:
    def __init__(
        self,
        config: TrainerConfig,
        train_dataloader: StatefulDataLoader,
        valid_dataloader: StatefulDataLoader,
        engine: TrainEngine,
        inf_engine: InferenceEngine | None = None,
    ):
        self.config = config

        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader

        self.engine = engine
        self.inf_engine = inf_engine

        self.tokenizer = load_hf_tokenizer(config.tokenizer_path)

        self.save_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.exp_ctrl.save_freq_epochs,
            freq_step=config.exp_ctrl.save_freq_steps,
            freq_sec=config.exp_ctrl.save_freq_secs,
        )
        self.eval_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.exp_ctrl.eval_freq_epochs,
            freq_step=config.exp_ctrl.eval_freq_steps,
            freq_sec=config.exp_ctrl.eval_freq_steps,
        )
        self.logger = logging.getLogger(self.__class__.__name__)
        self.init_stats_logging()

    def init_stats_logging(self):
        """
        Initialize wandb and/or tensorboard according to config.
        If torch.distributed is initialized

        Return:
            tensorboard SummaryWriter if self.config.tensorboard.path is not None
        """
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        # wandb init, connect to remote wandb host
        if self.config.wandb.mode != "disabled":
            wandb.login()
        wandb.init(
            mode=self.config.wandb.mode,
            entity=self.config.wandb.entity,
            project=self.config.wandb.project or self.config.experiment_name,
            name=self.config.wandb.name or self.config.trial_name,
            job_type=self.config.wandb.job_type,
            group=self.config.wandb.group
            or f"{self.config.experiment_name}_{self.config.trial_name}",
            notes=self.config.wandb.notes,
            tags=self.config.wandb.tags,
            config=self.config.wandb.config,
            dir=Trainer.get_log_path(self.config),
            force=True,
            id=f"{self.config.experiment_name}_{self.config.trial_name}_train",
            resume="allow",
            settings=wandb.Settings(start_method="fork"),
        )
        # tensorboard logging
        self.summary_writer = None
        if self.config.tensorboard.path is not None:
            self.summary_writer = SummaryWriter(log_dir=self.config.tensorboard.path)

    def log_wandb_tensorboard(self, step: int, data: Dict):
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        wandb.log(data, step=step)
        if self.summary_writer is not None:
            for key, val in data.items():
                self.summary_writer.add_scalar(f"{key}", val, step)

    def close_wandb_tensorboard(self):
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        wandb.finish()
        if self.summary_writer is not None:
            self.summary_writer.close()

    @staticmethod
    def get_save_checkpoint_path(
        config: TrainerConfig,
        epoch: int,
        step: int,
        globalstep: int,
        name: str = "default",
    ):
        path = os.path.join(
            f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
            name,
            f"epoch{epoch}epochstep{step}globalstep{globalstep}",
        )
        os.makedirs(path, exist_ok=True)
        return path

    @staticmethod
    def get_log_path(config: TrainerConfig):
        path = f"{config.fileroot}/logs/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}"
        os.makedirs(path, exist_ok=True)
        return path

    def log(self, msg: str, level="info"):
        if dist.is_initialized() and dist.get_rank() > 0:
            return
        log_fn = getattr(self.logger, level, "info")
        return log_fn(msg)

    def train(self):
        raise NotImplementedError()
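Note (not part of the diff): the deleted Trainer above bundled checkpoint paths, save/eval gating, and wandb/TensorBoard logging. This PR splits those responsibilities into Saver, Evaluator, and StatsLogger (added later in this diff), and the training loop moves into the experiment script. A condensed sketch of the replacement wiring, mirroring the gsm8k example at the end of this diff; config, ft_spec, engine, and the dataloaders are assumed to be built as in that example:

from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.base import stats_tracker

def run(config, ft_spec, engine, train_dataloader, valid_dataloader):
    saver = Saver(config.saver, ft_spec, for_recover=False)
    evaluator = Evaluator(config.evaluator, ft_spec)
    logger = StatsLogger(config.stats_logger, ft_spec)

    global_step = 0
    for epoch in range(config.total_train_epochs):
        for step, data in enumerate(train_dataloader):
            with stats_tracker.record_timing("train_step"), stats_tracker.scope("sft"):
                stats_tracker.scalar(**engine.train_lm(data))
                engine.step_lr_scheduler()
            with stats_tracker.record_timing("save"):
                saver.save(engine, epoch, step, global_step)  # frequency-gated internally
            with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
                evaluator.evaluate(valid_dataloader, engine.evaluate_lm, epoch, step, global_step)
            logger.commit(epoch, step, global_step, stats_tracker.export())
            global_step += 1
    logger.close()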
@@ -21,7 +21,6 @@ from transformers import (
from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import (
    FinetuneSpec,
    MicroBatchSpec,
    SaveLoadMeta,
    TrainEngine,
    WeightUpdateMeta,

@@ -319,15 +318,15 @@ class FSDPEngine(TrainEngine):
        assert self.lr_scheduler is not None
        self.lr_scheduler.step()

    def _prepare_mb_list(
        self, input_: TensorDict, mb_spec: MicroBatchSpec
    ) -> MicroBatchList:
    def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
        assert "attention_mask" in input_ and "input_ids" in input_
        if isinstance(input_, dict):
            input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
        input_ = amend_position_ids(input_)
        packed_input = pack_tensor_dict(input_)
        mb_list = split_packed_tensor_dict_into_mb_list(
            packed_input,
            mb_spec,
            self.config.mb_spec,
        )
        mb_list = pad_mb_list(mb_list, pad_value=0.0)
        # NOTE: We unsqueeze here because huggingface transformer models requires

@@ -338,7 +337,6 @@ class FSDPEngine(TrainEngine):
    def train_batch(
        self,
        input_: TensorDict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict[str, float]:

@@ -349,7 +347,7 @@ class FSDPEngine(TrainEngine):
        assert self.lr_scheduler is not None

        self.optimizer.zero_grad()
        mb_list = self._prepare_mb_list(input_, mb_spec)
        mb_list = self._prepare_mb_list(input_)

        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32

@@ -398,12 +396,12 @@ class FSDPEngine(TrainEngine):
    def eval_batch(
        self,
        input_: TensorDict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:
        """Evaluate on a batch."""
        mb_list = self._prepare_mb_list(input_, mb_spec)
        input_ = input_.to(self.device)
        mb_list = self._prepare_mb_list(input_)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
        )

@@ -431,14 +429,14 @@ class FSDPEngine(TrainEngine):
    def forward(
        self,
        input_: TensorDict,
        mb_spec: MicroBatchSpec,
        output_seqlens: List[int] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
    ) -> Any | None:
        """Forward pass with optional post-processing."""
        input_ = input_.to(self.device)
        cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
        mb_list = self._prepare_mb_list(input_, mb_spec)
        mb_list = self._prepare_mb_list(input_)

        if output_seqlens is None:
            output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
@@ -9,12 +9,7 @@ import torch.distributed as dist
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

from arealite.api.cli_args import (
    EngineConfig,
    MicroBatchSpec,
    ParallelismConfig,
    TrainingArgs,
)
from arealite.api.cli_args import EngineConfig, ParallelismConfig, TrainingArgs
from arealite.api.engine_api import TrainEngine
from arealite.api.io_struct import FinetuneSpec
from arealite.api.llm_client_api import LLMClient

@@ -150,7 +145,6 @@ class HFEngine(TrainEngine):
    def train_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> Dict:

@@ -192,7 +186,6 @@ class HFEngine(TrainEngine):
    def eval_batch(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
        loss_weight_fn: Callable[[Dict], float],
    ) -> torch.Tensor | None:

@@ -221,7 +214,6 @@ class HFEngine(TrainEngine):
    def forward(
        self,
        input_: Dict,
        mb_spec: MicroBatchSpec,
        output_seqlens: List[int] | None = None,
        post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
        aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
@@ -0,0 +1,94 @@
from typing import Dict

import torch
import torch.utils.data
from tensordict import TensorDict

from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import TrainEngine
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.utils.functional import gather_logprobs
from realhf.base import stats_tracker


class LMEngine:
    def __init__(self, engine: TrainEngine):
        self.engine = engine

    def train_lm(self, data: TensorDict):
        self.engine.train()
        return self.engine.train_batch(
            input_=data,
            loss_fn=compute_packed_sft_loss,
            loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
        )

    def evaluate_lm(self, data):
        self.engine.eval()
        self.engine.eval_batch(
            input_=data,
            loss_fn=compute_packed_sft_loss,
            loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
        )


class FSDPLMEngine(FSDPEngine):
    def __init__(self, config: TrainEngineConfig):
        super().__init__(config)
        self.lm_engine = LMEngine(self)

    def train_lm(self, data):
        return self.lm_engine.train_lm(data)

    def evaluate_lm(self, data):
        return self.lm_engine.evaluate_lm(data)


def compute_packed_sft_loss(
    logits: torch.Tensor, input_: Dict[str, torch.Tensor]
) -> torch.Tensor:
    packed_input_ids: torch.Tensor = input_["input_ids"]
    cu_seqlens: torch.Tensor = input_["cu_seqlens"]
    prompt_mask = input_["prompt_mask"].bool()

    logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
    prompt_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)
    logprobs = torch.where(prompt_mask, 0, logprobs)

    loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()

    with torch.no_grad():
        seqlogp = torch.zeros(
            cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
        )
        for i in range(cu_seqlens.shape[0] - 1):
            m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
            logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
            assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
                cu_seqlens,
                logprobs.shape,
            )
            seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
                m.numel() - m.count_nonzero()
            )

    ## Logging stats
    stats_tracker.denominator(
        n_seqs=torch.ones(
            cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
        ),
        n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
        n_valid_tokens=prompt_mask.logical_not(),
        prompt_tokens=prompt_mask,
    )
    stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
    stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
    vocab_min_logits = logits.detach().min(-1).values.float()
    vocab_max_logits = logits.detach().max(-1).values.float()
    stats_tracker.stat(
        vocab_min_logits=vocab_min_logits,
        vocab_max_logits=vocab_max_logits,
        denominator="n_tokens",
    )

    return loss
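Note (not part of the diff): compute_packed_sft_loss above operates on packed (concatenated) sequences. It rolls input_ids left by one so each position's logits line up with its next token, masks positions whose target is still a prompt token, and divides by the number of response targets. A toy illustration of that alignment for one packed sequence with a 3-token prompt and a 3-token response:

import torch

input_ids = torch.tensor([11, 12, 13, 21, 22, 23])            # 3 prompt tokens, then 3 response tokens
prompt_mask = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool)

targets = torch.roll(input_ids, shifts=-1, dims=-1)           # next-token target at each position
mask = torch.roll(prompt_mask, shifts=-1, dims=-1)            # True where the *target* is a prompt token

print(targets.tolist())  # [12, 13, 21, 22, 23, 11]
print(mask.tolist())     # [True, True, False, False, False, True]

# Positions 0-1 predict prompt tokens and are zeroed out; positions 2-4 predict the
# response tokens 21, 22, 23 and contribute to the loss; the last position wraps around
# to token 11 and is masked because the roll carries the leading prompt flag to the end.
n_valid = mask.logical_not().count_nonzero()                  # 3 response targets
# loss = -(masked logprobs).sum() / n_valid, exactly as in the function above.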
@@ -81,22 +81,10 @@ def engine():
@torch.no_grad()
def test_forward_microbatch(engine, mock_input):
    engine.eval()
    x2 = (
        engine.forward(
            input_=mock_input,
            mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
        )
        .squeeze(0)
        .mean(-1)
    )
    x1 = (
        engine.forward(
            input_=mock_input,
            mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
        )
        .squeeze(0)
        .mean(-1)
    )
    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
    x2 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
    engine.config.mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100)
    x1 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
    input_ids = mock_input["input_ids"]
    assert x1.shape[:1] == input_ids.shape[:1]
    assert x2.shape[:1] == input_ids.shape[:1]

@@ -106,9 +94,9 @@ def test_forward_microbatch(engine, mock_input):
@torch.no_grad()
def test_eval_batch(engine, mock_input):
    engine.eval()
    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
    eval_result = engine.eval_batch(
        input_=mock_input,
        mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
        loss_fn=mock_loss_fn,
        loss_weight_fn=lambda x: x["cu_seqlens"][-1],
    )

@@ -120,9 +108,9 @@ def test_eval_batch(engine, mock_input):

def test_train_batch(engine, mock_input):
    engine.train()
    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
    train_result = engine.train_batch(
        input_=mock_input,
        mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
        loss_fn=mock_loss_fn,
        loss_weight_fn=lambda x: x["cu_seqlens"][-1],
    )

@@ -144,18 +132,13 @@ def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
        base_model_path=None,
    )

    old = engine.forward(
        input_=mock_input,
        mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
    )
    engine.config.mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100)
    old = engine.forward(input_=mock_input)
    engine.save(save_load_meta)

    for name, param in engine.model.named_parameters():
        param.zero_()

    engine.load(save_load_meta)
    new = engine.forward(
        input_=mock_input,
        mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
    )
    new = engine.forward(input_=mock_input)
    assert torch.allclose(old, new)
@@ -1,156 +0,0 @@
import time
from typing import Dict

import torch
import torch.distributed as dist
import torch.utils.data

from arealite.api.io_struct import SaveLoadMeta
from arealite.api.trainer_api import Trainer
from arealite.utils.functional import gather_logprobs
from arealite.utils.logging import record_timing
from realhf.api.core.data_api import tabulate_stats
from realhf.base import logging, stats_tracker


def compute_packed_sft_loss(
    logits: torch.Tensor,
    input_: Dict[str, torch.Tensor],
) -> torch.Tensor:
    packed_input_ids: torch.Tensor = input_["input_ids"]
    cu_seqlens: torch.Tensor = input_["cu_seqlens"]
    prompt_mask = input_["prompt_mask"].bool()
    logits = logits.float()

    logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1, dims=-1))
    prompt_mask = torch.roll(prompt_mask, shifts=-1, dims=-1)
    logprobs = torch.where(prompt_mask, 0, logprobs)

    loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero()

    with torch.no_grad():
        seqlogp = torch.zeros(
            cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
        )
        for i in range(cu_seqlens.shape[0] - 1):
            m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
            logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
            assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
                cu_seqlens,
                logprobs.shape,
            )
            seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / (
                m.numel() - m.count_nonzero()
            )

    ## Logging stats
    stats_tracker.denominator(
        n_seqs=torch.ones(
            cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device
        ),
        n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device),
        n_valid_tokens=prompt_mask.logical_not(),
        prompt_tokens=prompt_mask,
    )
    stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs")
    stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens")
    vocab_min_logits = logits.detach().min(-1).values.float()
    vocab_max_logits = logits.detach().max(-1).values.float()
    stats_tracker.stat(
        vocab_min_logits=vocab_min_logits,
        vocab_max_logits=vocab_max_logits,
        denominator="n_tokens",
    )

    return loss


class SFTTrainer(Trainer):

    def train(self):
        total_epochs = self.config.exp_ctrl.total_train_epochs
        steps_per_epoch = len(self.train_dataloader)

        self.log(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
        global_step = 0
        start_time = time.monotonic()
        for epoch in range(total_epochs):
            for step, data in enumerate(self.train_dataloader):
                self.engine.train()
                timing_stats = {}
                with record_timing("timeperf/train_step", timing_stats):
                    with stats_tracker.scope("sft"):
                        stats = self.engine.train_batch(
                            input_=data,
                            loss_fn=compute_packed_sft_loss,
                            loss_weight_fn=lambda x: x["prompt_mask"]
                            .logical_not()
                            .count_nonzero(),
                            mb_spec=self.config.mb_spec,
                        )
                        self.engine.step_lr_scheduler()
                        stats_tracker.scalar(**stats)

                if self.save_ctl.check(
                    epochs=int(step == steps_per_epoch - 1), steps=1
                ):
                    self.log("Saving model ...")

                    with record_timing("timeperf/save", timing_stats):
                        save_path = self.get_save_checkpoint_path(
                            self.config, epoch, step, global_step
                        )
                        meta = SaveLoadMeta(
                            path=save_path,
                            weight_format="hf",
                            with_optim=False,
                            tokenizer=self.tokenizer,
                            base_model_path=self.config.tokenizer_path,
                        )
                        self.engine.save(meta)

                if self.eval_ctl.check(
                    epochs=int(step == steps_per_epoch - 1), steps=1
                ):
                    if dist.get_rank() == 0:
                        self.log("Running evaluation ...")
                    with record_timing("timeperf/eval", timing_stats):
                        self.evaluate()

                stats = stats_tracker.export()
                stats.update(timing_stats)
                self.log_wandb_tensorboard(global_step, stats)

                self.log(
                    f"Epoch {epoch+1}/{total_epochs} "
                    f"Step {step+1}/{steps_per_epoch} "
                    f"Train step {global_step + 1}/{total_epochs * steps_per_epoch} done."
                )
                self.log(
                    f"Detailed time stats: \n{tabulate_stats(timing_stats, floatfmt='.2f')}"
                )
                self.log(f"SFT training stats:\n{tabulate_stats(stats)}")
                global_step += 1

        self.log(
            f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}."
        )

        self.close_wandb_tensorboard()

    def evaluate(self):
        if self.valid_dataloader is None:
            return
        self.engine.eval()
        for data in self.valid_dataloader:
            with stats_tracker.scope("sft-eval"):
                # No need to log anything. Logging will be handled outside
                # via stats_tracker.export().
                self.engine.eval_batch(
                    input_=data,
                    loss_fn=compute_packed_sft_loss,
                    loss_weight_fn=lambda x: x["prompt_mask"]
                    .logical_not()
                    .count_nonzero(),
                    mb_spec=self.config.mb_spec,
                )
@@ -277,7 +277,7 @@ def pad_and_stack_tensors_along_first_dim(tensor_list: List[torch.Tensor]):

@dataclass
class MicroBatchList:
    data: Dict[str, Any]
    data: TensorDict
    mb_spec: MicroBatchSpec
    mbs: List[TensorDict]
    forward_indices: List[int]

@@ -379,7 +379,7 @@ def pad_packed_tensor_dict(
    data: TensorDict,
    pad_to_length: int,
    pad_value: float = 0.0,
) -> Tuple[Dict[str, Any], int]:
) -> Tuple[TensorDict, int]:
    """Pad a packed tensor dict to a specified length.
    This function assumes that the input data contains "cu_seqlens" and "max_seqlen" key,
    and all other tensors of shape [total_length, ] will be padded to `pad_to_length`.

@@ -391,7 +391,7 @@ def pad_packed_tensor_dict(
        pad_to_length (int): The length to pad the tensors to. All tensors

    Returns:
        Dict[str, torch.Tensor]: Dictionary with padded tensors and modified "cu_seqlens" and
        TensorDict: Dictionary with padded tensors and modified "cu_seqlens" and
        "max_seqlen".
        int: The pad length.
    """

@@ -420,7 +420,7 @@ def pad_packed_tensor_dict(
            padded_data[key] = padded_tensor
        else:
            padded_data[key] = value
    return padded_data, pad_length
    return TensorDict(padded_data, batch_size=data.batch_size), pad_length


def pad_mb_list(

@@ -453,13 +453,17 @@ def unsqueeze_packed_tensor_dict(data: TensorDict) -> TensorDict:
    assert "max_seqlen" in data, "Input data must contain 'max_seqlen' key."

    total_length = data["cu_seqlens"][-1].item()
    new_data = {}
    for key, value in data.items():
        if key == "cu_seqlens" or key == "max_seqlen":
            continue
        if (
            key not in ["cu_seqlens", "max_seqlen"]
            and torch.is_tensor(value)
            and value.numel() == total_length
        ):
            new_data[key] = value.unsqueeze(dim=0)
        else:
        if torch.is_tensor(value) and value.numel() == total_length:
            data[key] = value.unsqueeze(dim=0)
    return data
            new_data[key] = value
    return TensorDict(new_data, batch_size=data.batch_size)


def unsqueeze_mb_list(

@@ -472,7 +476,6 @@ def unsqueeze_mb_list(
        new_mbs.append(unsqueeze_packed_tensor_dict(mb))
        if mb_list.padded_mbs is not None:
            new_padded_mbs.append(unsqueeze_packed_tensor_dict(mb_list.padded_mbs[i]))
    mb_list.mbs = new_mbs
    mb_list.padded_mbs = new_padded_mbs if mb_list.padded_mbs is not None else None
    return mb_list
@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING, Any, Callable

from arealite.api.cli_args import EvaluatorConfig
from arealite.api.io_struct import FinetuneSpec
from realhf.base import timeutil

if TYPE_CHECKING:
    from tensordict import TensorDict
    from torchdata.stateful_dataloader import StatefulDataLoader


class Evaluator:

    def __init__(self, config: EvaluatorConfig, ft_spec: FinetuneSpec):
        self.config = config
        self.ft_sepc = ft_spec
        self.freq_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.freq_epochs,
            freq_step=config.freq_steps,
            freq_sec=config.freq_secs,
        )

    def evaluate(
        self,
        valid_dataloader: "StatefulDataLoader",
        evaluate_fn: Callable[["TensorDict"], Any],
        epoch: int,
        step: int,
        global_step: int,
    ):
        if not self.freq_ctl.check(
            epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
        ):
            return
        for data in valid_dataloader:
            evaluate_fn(data)
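Note (not part of the diff): Evaluator is a thin frequency gate; evaluate() returns immediately unless the configured epoch/step/second trigger fires, and otherwise just runs the supplied callable over the validation loader. A minimal sketch of the intended wiring; the dataloader and engine are assumed to exist elsewhere:

from arealite.api.cli_args import EvaluatorConfig
from arealite.api.io_struct import FinetuneSpec
from arealite.utils.evaluator import Evaluator

ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=128, train_batch_size=32)
evaluator = Evaluator(EvaluatorConfig(freq_steps=1), ft_spec)  # evaluate after every step

# Inside the training loop:
# evaluator.evaluate(valid_dataloader, engine.evaluate_lm, epoch, step, global_step)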
@@ -0,0 +1,9 @@
import getpass
import os


def get_user_tmp():
    user = getpass.getuser()
    user_tmp = os.path.join("/home", user, ".cache", "realhf")
    os.makedirs(user_tmp, exist_ok=True)
    return user_tmp
@@ -1,9 +0,0 @@
import time
from contextlib import contextmanager


@contextmanager
def record_timing(name, timing_stats):
    start_time = time.perf_counter()
    yield
    timing_stats[name] = time.perf_counter() - start_time
@@ -0,0 +1,68 @@
import getpass
import os

from transformers import PreTrainedTokenizerFast

from arealite.api.cli_args import SaverConfig
from arealite.api.engine_api import TrainEngine
from arealite.api.io_struct import FinetuneSpec, SaveLoadMeta
from realhf.base import timeutil


class Saver:

    def __init__(self, config: SaverConfig, ft_spec: FinetuneSpec, for_recover: bool):
        self.config = config
        self.ft_sepc = ft_spec
        self.for_recover = for_recover
        self.freq_ctl = timeutil.EpochStepTimeFreqCtl(
            freq_epoch=config.freq_epochs,
            freq_step=config.freq_steps,
            freq_sec=config.freq_secs,
        )

    @staticmethod
    def get_save_checkpoint_path(
        config: SaverConfig,
        epoch: int,
        step: int,
        globalstep: int,
        name: str = "default",
    ):
        path = os.path.join(
            f"{config.fileroot}/checkpoints/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}",
            name,
            f"epoch{epoch}epochstep{step}globalstep{globalstep}",
        )
        os.makedirs(path, exist_ok=True)
        return path

    def save(
        self,
        engine: TrainEngine,
        epoch: int,
        step: int,
        global_step: int,
        name: str = "default",
        tokenizer: PreTrainedTokenizerFast | None = None,
        base_model_path: str | None = None,
    ):
        if not self.freq_ctl.check(
            epochs=int(step == self.ft_sepc.steps_per_epoch - 1), steps=1
        ):
            return
        path = self.get_save_checkpoint_path(epoch, step, global_step, name)
        weight_format = "hf"
        with_optim = False
        if self.for_recover:
            weight_format = "dcp"
            with_optim = True

        meta = SaveLoadMeta(
            path=path,
            weight_format=weight_format,
            with_optim=with_optim,
            tokenizer=tokenizer,
            base_model_path=base_model_path,
        )
        engine.save(meta)
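Note (not part of the diff): Saver.save() is meant to be called unconditionally every step; the _Timer frequencies decide whether anything is written, and for_recover switches between an HF-format export and a DCP checkpoint with optimizer state. Pairing the `checkpointer` field of BaseExperimentConfig with for_recover=True is an assumption here, not shown in this diff; config, ft_spec, and engine are assumed to exist:

from arealite.utils.saver import Saver

saver = Saver(config.saver, ft_spec, for_recover=False)               # periodic HF weight exports
checkpointer = Saver(config.checkpointer, ft_spec, for_recover=True)  # assumed: DCP + optimizer state for recovery

# Both are invoked every step; each fires only when its own freq_epochs/steps/secs trigger:
# saver.save(engine, epoch, step, global_step, tokenizer=tokenizer, base_model_path=config.tokenizer_path)
# checkpointer.save(engine, epoch, step, global_step)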
@@ -0,0 +1,111 @@
import getpass
import os
import time
from typing import Dict

import torch.distributed as dist
import wandb
from tensorboardX import SummaryWriter

from arealite.api.cli_args import StatsLoggerConfig
from arealite.api.io_struct import FinetuneSpec
from realhf.api.core.data_api import tabulate_stats
from realhf.base import logging


class StatsLogger:

    def __init__(self, config: StatsLoggerConfig, ft_spec: FinetuneSpec):
        self.logger = logging.getLogger("StatsLogger", "system")
        self.config = config
        self.ft_spec = ft_spec
        self.init()

    def init(self):
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        self.start_time = time.perf_counter()
        # wandb init, connect to remote wandb host
        if self.config.wandb.mode != "disabled":
            wandb.login()
        wandb.init(
            mode=self.config.wandb.mode,
            entity=self.config.wandb.entity,
            project=self.config.wandb.project or self.config.experiment_name,
            name=self.config.wandb.name or self.config.trial_name,
            job_type=self.config.wandb.job_type,
            group=self.config.wandb.group
            or f"{self.config.experiment_name}_{self.config.trial_name}",
            notes=self.config.wandb.notes,
            tags=self.config.wandb.tags,
            config=self.config.wandb.config,
            dir=self.get_log_path(self.config),
            force=True,
            id=f"{self.config.experiment_name}_{self.config.trial_name}_train",
            resume="allow",
            settings=wandb.Settings(start_method="fork"),
        )
        # tensorboard logging
        self.summary_writer = None
        if self.config.tensorboard.path is not None:
            self.summary_writer = SummaryWriter(log_dir=self.config.tensorboard.path)

    def close(self):
        if dist.is_initialized() and dist.get_rank() != 0:
            return
        self.info(
            f"Training completes! Total time elapsed {time.monotonic() - self.start_time:.2f}."
        )
        wandb.finish()
        if self.summary_writer is not None:
            self.summary_writer.close()

    def commit(self, epoch: int, step: int, global_step: int, data: Dict):
        if dist.is_initialized() and dist.get_rank() != 0:
            return
        self.info(
            f"Epoch {epoch+1}/{self.ft_spec.total_train_epochs} "
            f"Step {step+1}/{self.ft_spec.steps_per_epoch} "
            f"Train step {global_step + 1}/{self.ft_spec.total_train_steps} done."
        )
        self.info("Stats:")
        self.print_stats(data)
        wandb.log(data, step=global_step)
        if self.summary_writer is not None:
            for key, val in data.items():
                self.summary_writer.add_scalar(f"{key}", val, global_step)

    def print_stats(self, stats: Dict[str, float]):
        self.info("\n" + tabulate_stats(stats))

    @staticmethod
    def get_log_path(config: StatsLoggerConfig):
        path = f"{config.fileroot}/logs/{getpass.getuser()}/{config.experiment_name}/{config.trial_name}"
        os.makedirs(path, exist_ok=True)
        return path

    def info(self, msg: str, *args, **kwargs):
        if dist.is_initialized() and dist.get_rank() > 0:
            return
        self.logger.info(msg, *args, **kwargs)

    def debug(self, msg: str, *args, **kwargs):
        if dist.is_initialized() and dist.get_rank() > 0:
            return
        self.logger.debug(msg, *args, **kwargs)

    def critical(self, msg: str, *args, **kwargs):
        if dist.is_initialized() and dist.get_rank() > 0:
            return
        self.logger.critical(msg, *args, **kwargs)

    def warning(self, msg: str, *args, **kwargs):
        if dist.is_initialized() and dist.get_rank() > 0:
            return
        self.logger.warning(msg, *args, **kwargs)

    def error(self, msg: str, *args, **kwargs):
        if dist.is_initialized() and dist.get_rank() > 0:
            return
        self.logger.error(msg, *args, **kwargs)
@@ -6,13 +6,19 @@ cluster:
  name_resolve:
    type: nfs
    nfs_record_root: /tmp/areal/name_resolve
seed: 1
total_train_epochs: 1
tokenizer_path: ${model.path}

model:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /storage/openpsi/models/Qwen__Qwen3-1.7B/
  init_from_scratch: false
  pad_mbs_to_max_tokens: true
  gradient_checkpointing: false
  bf16: true
  mb_spec:
    max_tokens_per_mb: 4096
  optimizer:
    type: adam
    lr: 2e-5

@@ -24,30 +30,46 @@ model:
  gradient_clipping: 1.0
  backend: fsdp

trainer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: disabled
  seed: 1
  exp_ctrl:
    total_train_epochs: 1
    eval_freq_steps: 1
  tokenizer_path: ${model.path}
  mb_spec:
    max_tokens_per_mb: 4096

train_dataset:
  type: gsm8k-sft
  batch_size: 128
  shuffle: true
  pin_memory: true
  num_workers: 4

valid_dataset:
  type: gsm8k-sft
  batch_size: 128
  shuffle: true
  pin_memory: true
  num_workers: 4
  num_workers: 4

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

checkpointer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 1
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: disabled
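Note (not part of the diff): the `trainer:` block is gone from this YAML; its keys either move to the top level (seed, total_train_epochs, tokenizer_path) or into the saver/checkpointer/evaluator/stats_logger sections, mirroring BaseExperimentConfig. A small sanity-check sketch; the local file path is hypothetical:

from omegaconf import OmegaConf

cfg = OmegaConf.load("gsm8k_sft.yaml")                  # hypothetical local copy of the config above
assert "trainer" not in cfg
print(cfg.total_train_epochs)                           # 1, was trainer.exp_ctrl.total_train_epochs
print(cfg.model.mb_spec.max_tokens_per_mb)              # 4096, was trainer.mb_spec.max_tokens_per_mb
print(cfg.saver.freq_epochs, cfg.evaluator.freq_steps)  # 1, 1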
@@ -1,18 +1,19 @@
import os
import sys

import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import DataCollatorWithPadding

from arealite.api.cli_args import SFTConfig, load_expr_config
from arealite.api.io_struct import FinetuneSpec
from arealite.engine.fsdp_engine import FSDPEngine
from arealite.trainer.sft import SFTTrainer
from arealite.engine.sft.lm_engine import FSDPLMEngine
from arealite.utils.data import pad_sequences_to_tensors
from arealite.utils.evaluator import Evaluator
from arealite.utils.saver import Saver
from arealite.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import stats_tracker


def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):

@@ -42,10 +43,9 @@ def main_sft():

    rank = int(os.getenv("RANK"))
    world_size = int(os.getenv("WORLD_SIZE"))
    tokenizer = load_hf_tokenizer(config.trainer.tokenizer_path)
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # Create dataset and dataloaders
    assert config.train_dataset == "gsm8k-sft"
    train_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("train", tokenizer, rank, world_size),
        batch_size=config.train_dataset.batch_size // world_size,

@@ -54,7 +54,6 @@ def main_sft():
        collate_fn=pad_sequences_to_tensors,
        drop_last=config.train_dataset.drop_last,
    )
    assert config.valid_dataset == "gsm8k-sft"
    valid_dataloader = StatefulDataLoader(
        get_gsm8k_dataset("test", tokenizer, rank, world_size),
        batch_size=config.valid_dataset.batch_size // world_size,

@@ -66,22 +65,52 @@ def main_sft():

    # Initialize engine
    ft_spec = FinetuneSpec(
        total_train_epochs=config.trainer.exp_ctrl.total_train_epochs,
        dataset_size=len(train_dataloader),
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )
    engine = FSDPEngine(config=config.model)
    engine = FSDPLMEngine(config=config.model)
    engine.initialize(None, ft_spec)

    # Run training.
    trainer = SFTTrainer(
        config=config.trainer,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        engine=engine,
        inf_engine=None,
    )
    trainer.train()
    saver = Saver(config.saver, ft_spec, for_recover=False)
    logger = StatsLogger(config.stats_logger, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)

    logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}")
    global_step = 0
    for epoch in range(total_epochs):
        for step, data in enumerate(train_dataloader):
            with (
                stats_tracker.record_timing("train_step"),
                stats_tracker.scope("sft"),
            ):
                stats = engine.train_lm(data)
                engine.step_lr_scheduler()
                stats_tracker.scalar(**stats)

            with stats_tracker.record_timing("save"):
                saver.save(engine, epoch, step, global_step)

            with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
                # No need to log anything. Logging will be handled outside
                # via stats_tracker.export().
                evaluator.evaluate(
                    valid_dataloader,
                    engine.evaluate_lm,
                    epoch,
                    step,
                    global_step,
                )

            logger.commit(epoch, step, global_step, stats_tracker.export())
            global_step += 1

    engine.destroy()
    logger.close()


if __name__ == "__main__":
@@ -1,4 +1,6 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict


@@ -49,6 +51,17 @@ class DistributedStatsTracker:
            return key
        return "/".join(self.scope_stack + [key])

    @contextmanager
    def record_timing(self, key):
        start_time = time.perf_counter()
        try:
            yield
        finally:
            # NOTE: timing records are fixed under the "timeperf" scope
            full_key = f"timeperf/{key}"
            self._set_reduce_type(full_key, ReduceType.SCALAR)
            self.stats[full_key].append(time.perf_counter() - start_time)

    def denominator(self, **kwargs):
        for key, value in kwargs.items():
            if not isinstance(value, torch.Tensor) or value.dtype != torch.bool:

@@ -252,3 +265,4 @@ denominator = DEFAULT_TRACKER.denominator
export = DEFAULT_TRACKER.export
scope = DEFAULT_TRACKER.scope
scalar = DEFAULT_TRACKER.scalar
record_timing = DEFAULT_TRACKER.record_timing
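Note (not part of the diff): the new tracker-level record_timing replaces the deleted helper in arealite/utils/logging.py; timings are stored under a fixed "timeperf/" prefix and come back through export() like any other scalar. A minimal sketch, assuming export() reports scalars under their full key:

import time

from realhf.base import stats_tracker

with stats_tracker.record_timing("train_step"):
    time.sleep(0.01)  # stand-in for real work

stats = stats_tracker.export()
print(stats.get("timeperf/train_step"))  # elapsed seconds, recorded as a scalar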