mirror of https://github.com/inclusionAI/AReaL
rm mb_spec

commit a78fd2dd24 · parent 32077b02ed

This commit removes the per-call mb_spec: MicroBatchSpec argument from the TrainEngine API (train_batch, eval_batch, forward) and its wrappers, and instead stores the microbatch specification on TrainEngineConfig as a runtime field that engines read via self.config.mb_spec.
@@ -187,7 +187,6 @@ class TrainEngine(abc.ABC):
     def train_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> Dict[str, float]:

@@ -197,7 +196,6 @@ class TrainEngine(abc.ABC):
     def eval_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:

@@ -207,7 +205,6 @@ class TrainEngine(abc.ABC):
     def forward(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         output_seqlens: List[List[int]] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,

@@ -323,7 +320,7 @@ Extended engines (such as Actor in PPO) provide convenient organization and call
 class Actor(Engine):
 
     @torch.no_grad()
-    def compute_logps(self, input_: Dict[str, Tensor], mb_spec: MicroBatchSpec) -> torch.Tensor:
+    def compute_logps(self, input_: Dict[str, Tensor]) -> torch.Tensor:
         ... # unpad
         logps = self.forward(xxx)
         ... # pad back

@@ -332,8 +329,7 @@ class Actor(Engine):
     def compute_advantages_and_returns(self, input_: Dict) -> Dict:
         pass
 
-    def ppo_update(self, input_: Dict,
-                   mb_spec: MicroBatchSpec) -> List[Dict[str, float]]:
+    def ppo_update(self, input_: Dict) -> List[Dict[str, float]]:
         ...
         all_stats = []
         for _ in range(self.ppo_n_minibatches):

@@ -344,11 +340,10 @@ class Actor(Engine):
 class Critic(Engine):
 
     @torch.no_grad()
-    def compute_values(self, input_: Dict, mb_spec: MicroBatchSpec) -> torch.Tensor:
+    def compute_values(self, input_: Dict) -> torch.Tensor:
         pass
 
-    def ppo_update(self, input_: Dict,
-                   mb_spec: MicroBatchSpec) -> List[Dict[str, float]]:
+    def ppo_update(self, input_: Dict) -> List[Dict[str, float]]:
         ...
         all_stats = []
         for _ in range(self.ppo_n_minibatches):
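With mb_spec removed from the signatures above, microbatching is no longer a per-call argument: callers set it once on the engine configuration and every subsequent call picks it up implicitly. A minimal sketch of the resulting calling convention, assuming an already-constructed engine and batch (the names engine, batch, and sft_loss are illustrative, not from this commit):

```python
from arealite.api.cli_args import MicroBatchSpec

# Configure the runtime microbatch limit once, on the engine config.
engine.config.mb_spec = MicroBatchSpec(max_tokens_per_mb=4096)

# Calls no longer take mb_spec; the engine reads self.config.mb_spec internally.
stats = engine.train_batch(
    input_=batch,
    loss_fn=sft_loss,
    loss_weight_fn=lambda mb: mb["prompt_mask"].logical_not().count_nonzero(),
)
```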
@@ -120,6 +120,8 @@ class TrainEngineConfig:
         default=False,
         metadata={"help": "Initialize critic/reward model from LM checkpoint"},
     )
+    # Runtime microbatch limit
+    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
 
     # Training Backend Configuration
     gradient_checkpointing: bool = field(

@@ -533,7 +535,6 @@ class BaseExperimentConfig:
         },
     )
     tokenizer_path: str = field(default="")
-    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
 
     train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
     valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
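TrainEngineConfig gains mb_spec with default_factory=MicroBatchSpec, so the spec must be constructible with no arguments. A sketch of a compatible dataclass, with field names taken from the MicroBatchSpec(n_mbs=..., max_tokens_per_mb=...) calls in the tests further down; the defaults here are assumptions, not the repository's actual values:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MicroBatchSpec:
    # Assumed defaults; only the field names are evidenced by this commit.
    n_mbs: int = 1                           # minimum number of microbatches per batch
    max_tokens_per_mb: Optional[int] = None  # token budget per microbatch; None = no limit
```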
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 import torch
 from tensordict import TensorDict
 
-from arealite.api.cli_args import MicroBatchSpec
 from arealite.api.io_struct import (
     FinetuneSpec,
     LLMRequest,

@@ -79,7 +78,6 @@ class TrainEngine(abc.ABC):
     def train_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> Dict[str, float]:

@@ -90,7 +88,6 @@ class TrainEngine(abc.ABC):
     def eval_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:

@@ -101,7 +98,6 @@ class TrainEngine(abc.ABC):
     def forward(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         output_seqlens: List[List[int]] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
@@ -21,7 +21,6 @@ from transformers import (
 from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import (
     FinetuneSpec,
-    MicroBatchSpec,
     SaveLoadMeta,
     TrainEngine,
     WeightUpdateMeta,

@@ -319,9 +318,7 @@ class FSDPEngine(TrainEngine):
         assert self.lr_scheduler is not None
         self.lr_scheduler.step()
 
-    def _prepare_mb_list(
-        self, input_: TensorDict, mb_spec: MicroBatchSpec
-    ) -> MicroBatchList:
+    def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
         assert "attention_mask" in input_ and "input_ids" in input_
         if isinstance(input_, dict):
             input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])

@@ -329,7 +326,7 @@ class FSDPEngine(TrainEngine):
         packed_input = pack_tensor_dict(input_)
         mb_list = split_packed_tensor_dict_into_mb_list(
             packed_input,
-            mb_spec,
+            self.config.mb_spec,
         )
         mb_list = pad_mb_list(mb_list, pad_value=0.0)
         # NOTE: We unsqueeze here because huggingface transformer models requires

@@ -340,7 +337,6 @@ class FSDPEngine(TrainEngine):
     def train_batch(
         self,
         input_: TensorDict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> Dict[str, float]:

@@ -351,7 +347,7 @@ class FSDPEngine(TrainEngine):
         assert self.lr_scheduler is not None
 
         self.optimizer.zero_grad()
-        mb_list = self._prepare_mb_list(input_, mb_spec)
+        mb_list = self._prepare_mb_list(input_)
 
         total_loss_weight = torch.tensor(
             sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32

@@ -400,13 +396,12 @@ class FSDPEngine(TrainEngine):
     def eval_batch(
         self,
         input_: TensorDict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:
         """Evaluate on a batch."""
         input_ = input_.to(self.device)
-        mb_list = self._prepare_mb_list(input_, mb_spec)
+        mb_list = self._prepare_mb_list(input_)
         total_loss_weight = torch.tensor(
             sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32
         )

@@ -434,7 +429,6 @@ class FSDPEngine(TrainEngine):
     def forward(
         self,
         input_: TensorDict,
-        mb_spec: MicroBatchSpec,
         output_seqlens: List[int] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,

@@ -442,7 +436,7 @@ class FSDPEngine(TrainEngine):
         """Forward pass with optional post-processing."""
         input_ = input_.to(self.device)
         cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
-        mb_list = self._prepare_mb_list(input_, mb_spec)
+        mb_list = self._prepare_mb_list(input_)
 
         if output_seqlens is None:
             output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
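Inside FSDPEngine, split_packed_tensor_dict_into_mb_list now pulls the spec from self.config.mb_spec rather than from a method argument. As a toy illustration of the splitting policy the spec implies (grouping packed sequences under a per-microbatch token budget), not AReaL's actual implementation:

```python
from typing import List


def split_by_token_budget(seqlens: List[int], max_tokens_per_mb: int) -> List[List[int]]:
    """Greedily group sequence indices so each group stays within the token budget."""
    groups: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx, n in enumerate(seqlens):
        if current and current_tokens + n > max_tokens_per_mb:
            groups.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n
    if current:
        groups.append(current)
    return groups


# split_by_token_budget([60, 50, 30, 90], max_tokens_per_mb=100) -> [[0], [1, 2], [3]]
```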
@@ -9,12 +9,7 @@ import torch.distributed as dist
 import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from arealite.api.cli_args import (
-    EngineConfig,
-    MicroBatchSpec,
-    ParallelismConfig,
-    TrainingArgs,
-)
+from arealite.api.cli_args import EngineConfig, ParallelismConfig, TrainingArgs
 from arealite.api.engine_api import TrainEngine
 from arealite.api.io_struct import FinetuneSpec
 from arealite.api.llm_client_api import LLMClient

@@ -150,7 +145,6 @@ class HFEngine(TrainEngine):
     def train_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> Dict:

@@ -192,7 +186,6 @@ class HFEngine(TrainEngine):
     def eval_batch(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor],
         loss_weight_fn: Callable[[Dict], float],
     ) -> torch.Tensor | None:

@@ -221,7 +214,6 @@ class HFEngine(TrainEngine):
     def forward(
         self,
         input_: Dict,
-        mb_spec: MicroBatchSpec,
         output_seqlens: List[int] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
@@ -4,7 +4,7 @@ import torch
 import torch.utils.data
 from tensordict import TensorDict
 
-from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
+from arealite.api.cli_args import TrainEngineConfig
 from arealite.api.engine_api import TrainEngine
 from arealite.engine.fsdp_engine import FSDPEngine
 from arealite.utils.functional import gather_logprobs

@@ -15,22 +15,20 @@ class LMEngine:
     def __init__(self, engine: TrainEngine):
         self.engine = engine
 
-    def train_lm(self, data: TensorDict, mb_spec: MicroBatchSpec):
+    def train_lm(self, data: TensorDict):
         self.engine.train()
         return self.engine.train_batch(
             input_=data,
             loss_fn=compute_packed_sft_loss,
             loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
-            mb_spec=mb_spec,
         )
 
-    def evaluate_lm(self, data, mb_spec: MicroBatchSpec):
+    def evaluate_lm(self, data):
         self.engine.eval()
         self.engine.eval_batch(
             input_=data,
             loss_fn=compute_packed_sft_loss,
             loss_weight_fn=lambda x: x["prompt_mask"].logical_not().count_nonzero(),
-            mb_spec=mb_spec,
         )
 
 
@@ -39,11 +37,11 @@ class FSDPLMEngine(FSDPEngine):
         super().__init__(config)
         self.lm_engine = LMEngine(self)
 
-    def train_lm(self, data, mb_spec):
-        return self.lm_engine.train_lm(data, mb_spec)
+    def train_lm(self, data):
+        return self.lm_engine.train_lm(data)
 
-    def evaluate_lm(self, data, mb_spec):
-        return self.lm_engine.evaluate_lm(data, mb_spec)
+    def evaluate_lm(self, data):
+        return self.lm_engine.evaluate_lm(data)
 
 
 def compute_packed_sft_loss(
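After this change the SFT wrappers take only the data; the microbatch limit travels with the engine's config. A usage sketch, assuming config is a TrainEngineConfig whose mb_spec was populated (for example from the YAML below):

```python
engine = FSDPLMEngine(config)   # config.mb_spec carries the microbatch limit
stats = engine.train_lm(data)   # was: engine.train_lm(data, mb_spec)
engine.evaluate_lm(valid_data)  # was: engine.evaluate_lm(valid_data, mb_spec)
```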
@@ -81,22 +81,10 @@ def engine():
 @torch.no_grad()
 def test_forward_microbatch(engine, mock_input):
     engine.eval()
-    x2 = (
-        engine.forward(
-            input_=mock_input,
-            mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
-        )
-        .squeeze(0)
-        .mean(-1)
-    )
-    x1 = (
-        engine.forward(
-            input_=mock_input,
-            mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
-        )
-        .squeeze(0)
-        .mean(-1)
-    )
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
+    x2 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100)
+    x1 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
     input_ids = mock_input["input_ids"]
     assert x1.shape[:1] == input_ids.shape[:1]
     assert x2.shape[:1] == input_ids.shape[:1]

@@ -106,9 +94,9 @@ def test_forward_microbatch(engine, mock_input):
 @torch.no_grad()
 def test_eval_batch(engine, mock_input):
     engine.eval()
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
     eval_result = engine.eval_batch(
         input_=mock_input,
-        mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
         loss_fn=mock_loss_fn,
         loss_weight_fn=lambda x: x["cu_seqlens"][-1],
     )

@@ -120,9 +108,9 @@ def test_eval_batch(engine, mock_input):
 
 def test_train_batch(engine, mock_input):
     engine.train()
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)
     train_result = engine.train_batch(
         input_=mock_input,
-        mb_spec=MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100),
         loss_fn=mock_loss_fn,
         loss_weight_fn=lambda x: x["cu_seqlens"][-1],
     )

@@ -144,18 +132,13 @@ def test_hf_save_load_weights(tmp_path_factory, engine, mock_input):
         base_model_path=None,
     )
 
-    old = engine.forward(
-        input_=mock_input,
-        mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
-    )
+    engine.config.mb_spec = MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100)
+    old = engine.forward(input_=mock_input)
     engine.save(save_load_meta)
 
     for name, param in engine.model.named_parameters():
         param.zero_()
 
     engine.load(save_load_meta)
-    new = engine.forward(
-        input_=mock_input,
-        mb_spec=MicroBatchSpec(n_mbs=1, max_tokens_per_mb=100),
-    )
+    new = engine.forward(input_=mock_input)
     assert torch.allclose(old, new)
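Note that the updated tests vary the split by mutating engine.config.mb_spec in place between calls. Where a temporary override is preferable, a small context manager can restore the previous spec on exit; this helper is hypothetical, not part of the commit:

```python
from contextlib import contextmanager


@contextmanager
def override_mb_spec(engine, spec):
    """Temporarily replace engine.config.mb_spec, restoring the old value on exit."""
    previous = engine.config.mb_spec
    engine.config.mb_spec = spec
    try:
        yield engine
    finally:
        engine.config.mb_spec = previous


# with override_mb_spec(engine, MicroBatchSpec(n_mbs=2, max_tokens_per_mb=100)):
#     x2 = engine.forward(input_=mock_input).squeeze(0).mean(-1)
```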
@@ -9,8 +9,6 @@ cluster:
 seed: 1
 total_train_epochs: 1
 tokenizer_path: ${model.path}
-mb_spec:
-  max_tokens_per_mb: 4096
 
 model:
   experiment_name: ${experiment_name}

@@ -19,6 +17,8 @@ model:
   init_from_scratch: false
   gradient_checkpointing: false
   bf16: true
+  mb_spec:
+    max_tokens_per_mb: 4096
   optimizer:
     type: adam
     lr: 2e-5
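In the experiment YAML, mb_spec moves from the top level into the model section, mirroring the move from BaseExperimentConfig to TrainEngineConfig above. The consumer-side difference, assuming the model section maps onto TrainEngineConfig (the attribute paths here are illustrative):

```python
# Before this commit: the experiment config carried the spec and call sites forwarded it.
#     stats = engine.train_lm(data, config.mb_spec)
# After: the engine config owns it, so the engine reads it internally.
engine = FSDPLMEngine(config.model)  # engine.config.mb_spec comes from the YAML
stats = engine.train_lm(data)
```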
@@ -1,4 +1,3 @@
-import functools
 import os
 import sys
 

@@ -89,7 +88,7 @@ def main_sft():
         stats_tracker.record_timing("train_step"),
         stats_tracker.scope("sft"),
     ):
-        stats = engine.train_lm(data, config.mb_spec)
+        stats = engine.train_lm(data)
         engine.step_lr_scheduler()
         stats_tracker.scalar(**stats)
 

@@ -99,13 +98,9 @@ def main_sft():
     with stats_tracker.record_timing("eval"), stats_tracker.scope("sft-eval"):
         # No need to log anything. Logging will be handled outside
         # via stats_tracker.export().
-        evaluate_fn = functools.partial(
-            engine.evaluate_lm,
-            mb_spec=config.mb_spec,
-        )
         evaluator.evaluate(
             valid_dataloader,
-            evaluate_fn,
+            engine.evaluate_lm,
             epoch,
             step,
             global_step,