This commit is contained in:
root 2025-06-26 19:36:59 +08:00
parent 6ec4493cb1
commit 84ff7597da
9 changed files with 30 additions and 39 deletions

View File

@ -46,10 +46,7 @@ class DatasetFactory:
from realhf.api.core.data_api import load_hf_tokenizer
tokenizer = load_hf_tokenizer(tokenizer_path)
return process_gsm8k_sft_dataset(
dataset,
tokenizer=tokenizer
)
return process_gsm8k_sft_dataset(dataset, tokenizer=tokenizer)
if config.preprocessor.type == "areal":
tokenizer_path = self.args.rollout.llm_client.tokenizer_path
assert self.args.rollout.llm_client.tokenizer_path is not None

View File

@ -3,7 +3,7 @@
import abc
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union, Callable, List, Dict
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
import torch.distributed as dist
from datasets import Dataset
@ -26,6 +26,7 @@ if TYPE_CHECKING:
# distributed sampler
# process group init
def _collate_fn(data_list: List[Dict]) -> Dict:
keys = data_list[0].keys()
result = {}
@ -33,6 +34,7 @@ def _collate_fn(data_list: List[Dict]) -> Dict:
result[k] = [d[k] for d in data_list]
return result
class Trainer(abc.ABC):
def __init__(
self,
@ -107,7 +109,7 @@ class TrainerFactory:
config: TrainerConfig,
train_dataset: Dataset,
valid_dataset: Optional[Dataset] = None,
rollout_controller: Optional["RolloutController"] = None
rollout_controller: Optional["RolloutController"] = None,
) -> Trainer:
if config.type == "grpo":
from arealite.impl.trainer.grpo import SpmdGRPOTrainer

View File

@ -26,16 +26,16 @@ def main():
"--config", help="The path of the main configuration file", required=True
)
args, overrides = parser.parse_known_args()
# Initialize hydra config
config_file = Path(args.config).absolute()
assert config_file.exists()
relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
hydra_init(config_path=str(relpath.parent), job_name="app")
cfg = hydra_compose(
config_name=str(relpath.name).rstrip('.yaml'), overrides=overrides
config_name=str(relpath.name).rstrip(".yaml"), overrides=overrides
)
# Merge with the default configuration
default_cfg = OmegaConf.structured(TrainingArgs)
cfg = OmegaConf.merge(default_cfg, cfg)
@ -70,7 +70,7 @@ def main():
cfg.trainer,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
rollout_controller=rollout_controller
rollout_controller=rollout_controller,
)
trainer.train()

View File

@ -1,4 +1,4 @@
from typing import List, Dict
from typing import Dict, List
import torch
from datasets import Dataset
@ -22,6 +22,7 @@ def process_gsm8k_rl_dataset(dataset: Dataset, tokenizer, reward_mode):
lambda x: tokenizer(x["question"], return_attention_mask=False), batched=True
)
def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
def _tokenize(example, idx):
# Add query_id column
@ -39,12 +40,8 @@ def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
)["input_ids"]
seq_len = len(tokenized_seq)
prompt_len = len(tokenized_prompt)
return {
"seq": tokenized_seq,
"prompt_len": prompt_len,
"seq_len": seq_len
}
return {"seq": tokenized_seq, "prompt_len": prompt_len, "seq_len": seq_len}
dataset = dataset.map(
lambda example, idx: _tokenize(example, idx),

View File

@ -1,9 +1,10 @@
import time
import os
from typing import Dict, List, Optional, Callable
import time
from typing import Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import torch.utils.data
from datasets import Dataset
from arealite.api.cli_args import TrainerConfig, TrainingArgs
@ -20,24 +21,23 @@ from arealite.utils import (
)
from realhf.api.core.data_api import load_hf_tokenizer, tabulate_stats
from realhf.api.core.model_api import FinetuneSpec
from realhf.base import logging, stats_tracker, timeutil, constants
import torch.utils.data
from realhf.base import constants, logging, stats_tracker, timeutil
logger = logging.getLogger("SFT Trainer")
def get_save_checkpoint_path(
args: TrainingArgs, epoch: int, step: int, globalstep: int
args: TrainingArgs, epoch: int, step: int, globalstep: int
):
path = os.path.join(
constants.get_save_path(args),
"model",
f"epoch{epoch}epochstep{step}globalstep{globalstep}"
f"epoch{epoch}epochstep{step}globalstep{globalstep}",
)
os.makedirs(path, exist_ok=True)
return path
def compute_packed_sft_loss(
logits: torch.Tensor,
input_: Dict[str, torch.Tensor],
@ -102,11 +102,7 @@ class SFTTrainer(Trainer):
rollout_controller: Optional[RolloutController] = None,
):
super().__init__(
args,
trainer_config,
train_dataset,
valid_dataset,
rollout_controller
args, trainer_config, train_dataset, valid_dataset, rollout_controller
)
self.config = config = trainer_config.sft
@ -147,9 +143,7 @@ class SFTTrainer(Trainer):
input_lens = data["seq_len"]
input_lens = torch.tensor(input_lens, dtype=torch.int)
input_ids = [
torch.tensor(seq, dtype=torch.long) for seq in tokenized_seqs
]
input_ids = [torch.tensor(seq, dtype=torch.long) for seq in tokenized_seqs]
prompt_mask = []
for input_len, prompt_len in zip(input_lens, prompt_lens):
@ -200,7 +194,7 @@ class SFTTrainer(Trainer):
timing_stats = {}
with record_timing("timeperf/data_processing", timing_stats):
packed_input_data = self._get_packed_input(data)
with record_timing("timeperf/train_step", timing_stats):
with stats_tracker.scope("sft"):
stats = self.model.train_batch(

View File

@ -152,6 +152,7 @@ def test_train_batch(tmp_path_factory, engine, mock_input):
engine.load_optimizer_state(path)
@torch.no_grad()
def test_save_load_weights(tmp_path_factory, engine, mock_input):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

View File

@ -1,13 +1,14 @@
"""Test script for FSDP Engine implementation."""
from typing import Dict
import os
from typing import Dict
import torch
from datasets import load_dataset
from arealite.api.cli_args import (
DatasetConfig,
DatasetPreprocessor,
EngineBackendConfig,
EngineConfig,
ModelFamily,
@ -15,7 +16,6 @@ from arealite.api.cli_args import (
SFTTrainerConfig,
TrainerConfig,
TrainingArgs,
DatasetPreprocessor
)
from arealite.api.dataset_api import DatasetFactory
from arealite.api.trainer_api import TrainerFactory

View File

@ -4,7 +4,6 @@
# Pad/unpad operations are modified from flash-attention under BSD-3 license.
# Copyright (c) 2023, Tri Dao.
import socket
import time
from contextlib import contextmanager
from dataclasses import dataclass
@ -19,7 +18,7 @@ from einops import rearrange, repeat
from tensorboardX import SummaryWriter
from arealite.api.cli_args import MicroBatchSpec, TrainingArgs
from realhf.base import datapack, constants
from realhf.base import constants, datapack
############### Dict and list operations begin ###############
@ -608,8 +607,10 @@ def gather_logprobs(
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
return log_probs_labels
############### Tensor computations end ###############
def init_stats_logging(args: TrainingArgs):
"""
Initialize wandb and/or tensorboard according to config.
@ -672,4 +673,3 @@ def record_timing(name, timing_stats):
start_time = time.perf_counter()
yield
timing_stats[name] = time.perf_counter() - start_time

View File

@ -46,7 +46,7 @@ def main(args):
yaml.dump(
config_dict,
f,
default_flow_style=False
default_flow_style=False,
sort_keys=False,
)