rm mb_spec

bowei.fw 2025-07-09 22:16:31 +08:00
parent 32077b02ed
commit a78fd2dd24
9 changed files with 32 additions and 78 deletions

View File

@@ -187,7 +186,6 @@ class TrainEngine(abc.ABC):
     def train_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> Dict[str, float]:
@@ -197,7 +196,6 @@ class TrainEngine(abc.ABC):
     def eval_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:
@@ -207,7 +205,6 @@ class TrainEngine(abc.ABC):
     def forward(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         output_seqlens: List[List[int]] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
@@ -323,7 +320,7 @@ Extended engines (such as Actor in PPO) provide convenient organization and call
 
 class Actor(Engine):
     @torch.no_grad()
-    def compute_logps(self, input_: Dict[str, Tensor], mb_spec: MicroBatchSpec) -> torch.Tensor:
+    def compute_logps(self, input_: Dict[str, Tensor]) -> torch.Tensor:
         ... # unpad
         logps = self.forward(xxx)
         ... # pad back
@@ -332,8 +329,7 @@ class Actor(Engine):
     def compute_advantages_and_returns(self, input_: Dict) -> Dict:
         pass
 
-    def ppo_update(self, input_: Dict,
-                   mb_spec: MicroBatchSpec) -> List[Dict[str, float]]:
+    def ppo_update(self, input_: Dict) -> List[Dict[str, float]]:
         ...
         all_stats = []
         for _ in range(self.ppo_n_minibatches):
@@ -344,11 +340,10 @@ class Actor(Engine):
 
 class Critic(Engine):
     @torch.no_grad()
-    def compute_values(self, input_: Dict, mb_spec: MicroBatchSpec) -> torch.Tensor:
+    def compute_values(self, input_: Dict) -> torch.Tensor:
         pass
 
-    def ppo_update(self, input_: Dict,
-                   mb_spec: MicroBatchSpec) -> List[Dict[str, float]]:
+    def ppo_update(self, input_: Dict) -> List[Dict[str, float]]:
         ...
         all_stats = []
         for _ in range(self.ppo_n_minibatches):
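The net effect on callers: the microbatch spec is set once on the engine rather than threaded through every call. A minimal usage sketch, where `engine`, `batch`, and `sft_loss` are hypothetical stand-ins for a configured TrainEngine, a packed input dict, and a loss function (none are defined in this diff):

engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=4096)
stats = engine.train_batch(
    input_=batch,
    loss_fn=sft_loss,  # (logits, mb) -> scalar loss
    loss_weight_fn=lambda mb: mb["attention_mask"].count_nonzero(),
)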

View File

@@ -120,6 +120,8 @@ class TrainEngineConfig:
         default=False,
         metadata={"help": "Initialize critic/reward model from LM checkpoint"},
     )
+    # Runtime microbatch limit
+    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
 
     # Training Backend Configuration
     gradient_checkpointing: bool = field(
@@ -533,7 +535,6 @@ class BaseExperimentConfig:
         },
     )
     tokenizer_path: str = field(default="")
-    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
     train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
     valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
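MicroBatchSpec itself is not shown in this commit. A plausible shape, inferred only from `default_factory=MicroBatchSpec` here and the keyword arguments used in the tests below (`n_mbs`, `max_tokens_per_mb`); the actual fields and defaults in arealite/api/cli_args.py may differ:

from dataclasses import dataclass

@dataclass
class MicroBatchSpec:
    # Inferred from usage elsewhere in this commit; real defaults may differ.
    n_mbs: int = 1  # minimum number of microbatches to split a batch into
    max_tokens_per_mb: int | None = None  # cap on tokens per microbatch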

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 import torch
 from tensordict import TensorDict
 
-from arealite.api.cli_args import MicroBatchSpec
 from arealite.api.io_struct import (
     FinetuneSpec,
     LLMRequest,
@@ -79,7 +78,6 @@ class TrainEngine(abc.ABC):
     def train_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> Dict[str, float]:
@@ -90,7 +88,6 @@ class TrainEngine(abc.ABC):
     def eval_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:
@@ -101,7 +98,6 @@ class TrainEngine(abc.ABC):
     def forward(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         output_seqlens: List[List[int]] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
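forward keeps its per-microbatch post_hook and aggregate_fn hooks, so a caller can reduce each microbatch's logits before aggregation. A hedged sketch; logprob_hook, engine, and batch are hypothetical, not part of the codebase:

import torch

def logprob_hook(logits: torch.Tensor, mb: dict) -> torch.Tensor:
    # Keep only per-token log-probs of the input ids, so the full logits
    # tensor never needs to be materialized for the whole batch.
    logps = torch.log_softmax(logits.float(), dim=-1)
    return torch.gather(logps, -1, mb["input_ids"].unsqueeze(-1)).squeeze(-1)

out = engine.forward(input_=batch, post_hook=logprob_hook, aggregate_fn=torch.cat)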

View File

@@ -21,7 +21,6 @@ from transformers import (
 from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import (
     FinetuneSpec,
-    MicroBatchSpec,
     SaveLoadMeta,
     TrainEngine,
     WeightUpdateMeta,
@@ -319,9 +318,7 @@ class FSDPEngine(TrainEngine):
         assert self.lr_scheduler is not None
         self.lr_scheduler.step()
 
-    def _prepare_mb_list(
-        self, input_: TensorDict, mb_spec: MicroBatchSpec
-    ) -> MicroBatchList:
+    def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
         assert "attention_mask" in input_ and "input_ids" in input_
         if isinstance(input_, dict):
             input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
@@ -329,7 +326,7 @@ class FSDPEngine(TrainEngine):
         packed_input = pack_tensor_dict(input_)
         mb_list = split_packed_tensor_dict_into_mb_list(
             packed_input,
-            mb_spec,
+            self.config.mb_spec,
         )
         mb_list = pad_mb_list(mb_list, pad_value=0.0)
         # NOTE: We unsqueeze here because huggingface transformer models requires
@@ -340,7 +337,6 @@ class FSDPEngine(TrainEngine):
     def train_batch(
         self,
         input_: TensorDict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> Dict[str, float]:
@@ -351,7 +347,7 @@ class FSDPEngine(TrainEngine):
 
         assert self.lr_scheduler is not None
         self.optimizer.zero_grad()
-        mb_list = self._prepare_mb_list(input_, mb_spec)
+        mb_list = self._prepare_mb_list(input_)
         total_loss_weight = torch.tensor(
             sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
         )
@@ -400,13 +396,12 @@ class FSDPEngine(TrainEngine):
     def eval_batch(
         self,
         input_: TensorDict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:
         """Evaluate on a batch."""
         input_ = input_.to(self.device)
-        mb_list = self._prepare_mb_list(input_, mb_spec)
+        mb_list = self._prepare_mb_list(input_)
         total_loss_weight = torch.tensor(
             sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
         )
@@ -434,7 +429,6 @@ class FSDPEngine(TrainEngine):
     def forward(
         self,
         input_: TensorDict,
-        mb_spec: MicroBatchSpec,
         output_seqlens: List[int] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
@@ -442,7 +436,7 @@ class FSDPEngine(TrainEngine):
         """Forward pass with optional post-processing."""
         input_ = input_.to(self.device)
         cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
-        mb_list = self._prepare_mb_list(input_, mb_spec)
+        mb_list = self._prepare_mb_list(input_)
         if output_seqlens is None:
             output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
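_prepare_mb_list delegates the splitting to split_packed_tensor_dict_into_mb_list, which is unchanged by this commit; only its spec argument now comes from self.config.mb_spec. For intuition, a greedy token-capped split over sequence lengths might look like the sketch below (an assumption for illustration only; the real helper operates on packed TensorDicts and may balance differently):

from typing import List

def split_seqlens(seqlens: List[int], n_mbs: int, max_tokens_per_mb: int) -> List[List[int]]:
    # Greedily pack sequences until the token cap would be exceeded.
    groups, cur, cur_tokens = [], [], 0
    for s in seqlens:
        if cur and cur_tokens + s > max_tokens_per_mb:
            groups.append(cur)
            cur, cur_tokens = [], 0
        cur.append(s)
        cur_tokens += s
    if cur:
        groups.append(cur)
    # Treat n_mbs as a lower bound on the number of microbatches.
    while len(groups) < n_mbs and any(len(g) > 1 for g in groups):
        g = max(groups, key=len)
        groups.remove(g)
        groups.extend([g[: len(g) // 2], g[len(g) // 2 :]])
    return groups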

View File

@@ -9,12 +9,7 @@ import torch.distributed as dist
 import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from arealite.api.cli_args import (
-    EngineConfig,
-    MicroBatchSpec,
-    ParallelismConfig,
-    TrainingArgs,
-)
+from arealite.api.cli_args import EngineConfig, ParallelismConfig, TrainingArgs
 from arealite.api.engine_api import TrainEngine
 from arealite.api.io_struct import FinetuneSpec
 from arealite.api.llm_client_api import LLMClient
@@ -150,7 +145,6 @@ class HFEngine(TrainEngine):
     def train_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> Dict:
@@ -192,7 +186,6 @@ class HFEngine(TrainEngine):
     def eval_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:
@@ -221,7 +214,6 @@ class HFEngine(TrainEngine):
     def forward(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         output_seqlens: List[int] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),

View File

@@ -4,7 +4,7 @@ import torch
 import torch.utils.data
 from tensordict import TensorDict
 
-from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
+from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import TrainEngine
 from arealite.engine.fsdp_engine import FSDPEngine
 from arealite.utils.functional import gather_logprobs
@@ -15,22 +15,20 @@ class LMEngine:
     def __init__(self, engine: TrainEngine):
         self.engine = engine
 
-    def train_lm(self, data: TensorDict, mb_spec: MicroBatchSpec):
+    def train_lm(self, data: TensorDict):
         self.engine.train()
         return self.engine.train_batch(
             input_=data,
             loss_fn=compute_packed_sft_loss,
             loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
-            mb_spec=mb_spec,
         )
 
-    def evaluate_lm(self, data, mb_spec: MicroBatchSpec):
+    def evaluate_lm(self, data):
         self.engine.eval()
         self.engine.eval_batch(
             input_=data,
             loss_fn=compute_packed_sft_loss,
             loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
-            mb_spec=mb_spec,
         )
 
 
@@ -39,11 +37,11 @@ class FSDPLMEngine(FSDPEngine):
         super().__init__(config)
         self.lm_engine = LMEngine(self)
 
-    def train_lm(self, data, mb_spec):
-        return self.lm_engine.train_lm(data, mb_spec)
+    def train_lm(self, data):
+        return self.lm_engine.train_lm(data)
 
-    def evaluate_lm(self, data, mb_spec):
-        return self.lm_engine.evaluate_lm(data, mb_spec)
+    def evaluate_lm(self, data):
+        return self.lm_engine.evaluate_lm(data)
 
 
 def compute_packed_sft_loss(
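With the slimmer wrappers above, an SFT step reads as follows; a sketch assuming an initialized FSDPLMEngine and hypothetical TensorDict batches (train_data, valid_data) carrying input_ids, attention_mask, and prompt_mask:

engine = FSDPLMEngine(config)   # config.mb_spec now carries the microbatch limit
stats = engine.train_lm(train_data)   # no per-call MicroBatchSpec anymore
engine.evaluate_lm(valid_data)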

View File

@@ -81,22 +81,10 @@ def engine():
 @torch.no_grad()
 def test_forward_microbatch(engine, mock_input):
     engine.eval()
-    x2 = (
-        engine.forward(
-            input_=mock_input,
-            mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
-        )
-        .squeeze(0)
-        .mean(-1)
-    )
-    x1 = (
-        engine.forward(
-            input_=mock_input,
-            mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
-        )
-        .squeeze(0)
-        .mean(-1)
-    )
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
+    x2 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100)
+    x1 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
     input_ids = mock_input["input_ids"]
     assert x1.shape[:1] == input_ids.shape[:1]
     assert x2.shape[:1] == input_ids.shape[:1]
@@ -106,9 +94,9 @@ def test_forward_microbatch(engine, mock_input):
 @torch.no_grad()
 def test_eval_batch(engine, mock_input):
     engine.eval()
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
     eval_result = engine.eval_batch(
         input_=mock_input,
-        mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
         loss_fn=mock_loss_fn,
         loss_weight_fn=lambda x: x["cu_seqlens"][-1],
     )
@@ -120,9 +108,9 @@ def test_eval_batch(engine, mock_input):
 
 def test_train_batch(engine, mock_input):
     engine.train()
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
     train_result = engine.train_batch(
         input_=mock_input,
-        mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
         loss_fn=mock_loss_fn,
         loss_weight_fn=lambda x: x["cu_seqlens"][-1],
     )
@@ -144,18 +132,13 @@ def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
         base_model_path=None,
     )
 
-    old = engine.forward(
-        input_=mock_input,
-        mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
-    )
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100)
+    old = engine.forward(input_=mock_input)
     engine.save(save_load_meta)
 
     for name, param in engine.model.named_parameters():
         param.zero_()
     engine.load(save_load_meta)
 
-    new = engine.forward(
-        input_=mock_input,
-        mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
-    )
+    new = engine.forward(input_=mock_input)
    assert torch.allclose(old, new)
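Because the spec is now mutable engine state, the tests simply overwrite engine.config.mb_spec before each call. If isolation between tests mattered, a hypothetical helper (not part of this commit) could scope the override:

from contextlib import contextmanager

@contextmanager
def mb_spec_override(engine, spec):
    # Swap the engine's microbatch spec, restoring the previous one on exit.
    prev = engine.config.mb_spec
    engine.config.mb_spec = spec
    try:
        yield
    finally:
        engine.config.mb_spec = prev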

View File

@@ -9,8 +9,6 @@ cluster:
 seed: 1
 total_train_epochs: 1
 tokenizer_path: ${model.path}
-mb_spec:
-  max_tokens_per_mb: 4096
 
 model:
   experiment_name: ${experiment_name}
@@ -19,6 +17,8 @@ model:
   init_from_scratch: false
   gradient_checkpointing: false
   bf16: true
+  mb_spec:
+    max_tokens_per_mb: 4096
   optimizer:
     type: adam
     lr: 2e-5

View File

@@ -1,4 +1,3 @@
-import functools
 import os
 import sys
 
@@ -89,7 +88,7 @@ def main_sft():
             stats_tracker.record_timing("train_step"),
             stats_tracker.scope("sft"),
         ):
-            stats = engine.train_lm(data, config.mb_spec)
+            stats = engine.train_lm(data)
             engine.step_lr_scheduler()
 
         stats_tracker.scalar(**stats)
@@ -99,13 +98,9 @@ def main_sft():
         with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
             # No need to log anything. Logging will be handled outside
             # via stats_tracker.export().
-            evaluate_fn = functools.partial(
-                engine.evaluate_lm,
-                mb_spec=config.mb_spec,
-            )
             evaluator.evaluate(
                 valid_dataloader,
-                evaluate_fn,
+                engine.evaluate_lm,
                 epoch,
                 step,
                 global_step,