mirror of https://github.com/inclusionAI/AReaL

commit 84ff7597da ("format")
parent 6ec4493cb1
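The hunks below are consistent with an import-sorting and auto-formatting pass over the arealite package (imports reordered, trailing commas and blank lines normalized, multi-line calls collapsed, single quotes changed to double quotes). The commit message only says "format", so the exact tooling is an assumption; a minimal sketch of how such a pass might be reproduced locally, assuming isort and black are the formatters in use:

    # Hypothetical reproduction of the formatting pass; "isort" and "black" are
    # assumptions, not confirmed by the commit itself.
    import subprocess

    subprocess.run(["isort", "arealite"], check=True)  # sort and group imports
    subprocess.run(["black", "arealite"], check=True)  # normalize layout and trailing commas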
@@ -46,10 +46,7 @@ class DatasetFactory:
             from realhf.api.core.data_api import load_hf_tokenizer

             tokenizer = load_hf_tokenizer(tokenizer_path)
-            return process_gsm8k_sft_dataset(
-                dataset,
-                tokenizer=tokenizer
-            )
+            return process_gsm8k_sft_dataset(dataset, tokenizer=tokenizer)
         if config.preprocessor.type == "areal":
             tokenizer_path = self.args.rollout.llm_client.tokenizer_path
             assert self.args.rollout.llm_client.tokenizer_path is not None
@@ -3,7 +3,7 @@
 import abc
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union, Callable, List, Dict
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

 import torch.distributed as dist
 from datasets import Dataset
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
 # distributed sampler
 # process group init

+
 def _collate_fn(data_list: List[Dict]) -> Dict:
     keys = data_list[0].keys()
     result = {}
@@ -33,6 +34,7 @@ def _collate_fn(data_list: List[Dict]) -> Dict:
         result[k] = [d[k] for d in data_list]
     return result

+
 class Trainer(abc.ABC):
     def __init__(
         self,
@@ -107,7 +109,7 @@ class TrainerFactory:
         config: TrainerConfig,
         train_dataset: Dataset,
         valid_dataset: Optional[Dataset] = None,
-        rollout_controller: Optional["RolloutController"] = None
+        rollout_controller: Optional["RolloutController"] = None,
     ) -> Trainer:
         if config.type == "grpo":
             from arealite.impl.trainer.grpo import SpmdGRPOTrainer
@@ -26,16 +26,16 @@ def main():
         "--config", help="The path of the main configuration file", required=True
     )
     args, overrides = parser.parse_known_args()

     # Initialize hydra config
     config_file = Path(args.config).absolute()
     assert config_file.exists()
     relpath = Path(os.path.relpath(str(config_file), Path(__file__).parent.absolute()))
     hydra_init(config_path=str(relpath.parent), job_name="app")
     cfg = hydra_compose(
-        config_name=str(relpath.name).rstrip('.yaml'), overrides=overrides
+        config_name=str(relpath.name).rstrip(".yaml"), overrides=overrides
     )

     # Merge with the default configuration
     default_cfg = OmegaConf.structured(TrainingArgs)
     cfg = OmegaConf.merge(default_cfg, cfg)
@@ -70,7 +70,7 @@ def main():
         cfg.trainer,
         train_dataset=train_dataset,
         valid_dataset=valid_dataset,
-        rollout_controller=rollout_controller
+        rollout_controller=rollout_controller,
     )
     trainer.train()

@@ -1,4 +1,4 @@
-from typing import List, Dict
+from typing import Dict, List

 import torch
 from datasets import Dataset
@@ -22,6 +22,7 @@ def process_gsm8k_rl_dataset(dataset: Dataset, tokenizer, reward_mode):
         lambda x: tokenizer(x["question"], return_attention_mask=False), batched=True
     )

+
 def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
     def _tokenize(example, idx):
         # Add query_id column
@@ -39,12 +40,8 @@ def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer):
         )["input_ids"]
         seq_len = len(tokenized_seq)
         prompt_len = len(tokenized_prompt)

-        return {
-            "seq": tokenized_seq,
-            "prompt_len": prompt_len,
-            "seq_len": seq_len
-        }
+        return {"seq": tokenized_seq, "prompt_len": prompt_len, "seq_len": seq_len}

     dataset = dataset.map(
         lambda example, idx: _tokenize(example, idx),
@@ -1,9 +1,10 @@
-import time
 import os
-from typing import Dict, List, Optional, Callable
+import time
+from typing import Callable, Dict, List, Optional

 import torch
 import torch.distributed as dist
+import torch.utils.data
 from datasets import Dataset

 from arealite.api.cli_args import TrainerConfig, TrainingArgs
@@ -20,24 +21,23 @@ from arealite.utils import (
 )
 from realhf.api.core.data_api import load_hf_tokenizer, tabulate_stats
 from realhf.api.core.model_api import FinetuneSpec
-from realhf.base import logging, stats_tracker, timeutil, constants
-
-import torch.utils.data
+from realhf.base import constants, logging, stats_tracker, timeutil

 logger = logging.getLogger("SFT Trainer")


 def get_save_checkpoint_path(
-        args: TrainingArgs, epoch: int, step: int, globalstep: int
+    args: TrainingArgs, epoch: int, step: int, globalstep: int
 ):
     path = os.path.join(
         constants.get_save_path(args),
         "model",
-        f"epoch{epoch}epochstep{step}globalstep{globalstep}"
+        f"epoch{epoch}epochstep{step}globalstep{globalstep}",
     )
     os.makedirs(path, exist_ok=True)
     return path


 def compute_packed_sft_loss(
     logits: torch.Tensor,
     input_: Dict[str, torch.Tensor],
@@ -102,11 +102,7 @@ class SFTTrainer(Trainer):
         rollout_controller: Optional[RolloutController] = None,
     ):
         super().__init__(
-            args,
-            trainer_config,
-            train_dataset,
-            valid_dataset,
-            rollout_controller
+            args, trainer_config, train_dataset, valid_dataset, rollout_controller
         )

         self.config = config = trainer_config.sft
@@ -147,9 +143,7 @@ class SFTTrainer(Trainer):
         input_lens = data["seq_len"]

         input_lens = torch.tensor(input_lens, dtype=torch.int)
-        input_ids = [
-            torch.tensor(seq, dtype=torch.long) for seq in tokenized_seqs
-        ]
+        input_ids = [torch.tensor(seq, dtype=torch.long) for seq in tokenized_seqs]

         prompt_mask = []
         for input_len, prompt_len in zip(input_lens, prompt_lens):
@@ -200,7 +194,7 @@ class SFTTrainer(Trainer):
         timing_stats = {}
         with record_timing("timeperf/data_processing", timing_stats):
             packed_input_data = self._get_packed_input(data)

         with record_timing("timeperf/train_step", timing_stats):
             with stats_tracker.scope("sft"):
                 stats = self.model.train_batch(
@@ -152,6 +152,7 @@ def test_train_batch(tmp_path_factory, engine, mock_input):

     engine.load_optimizer_state(path)

+
 @torch.no_grad()
 def test_save_load_weights(tmp_path_factory, engine, mock_input):
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
@@ -1,13 +1,14 @@
 """Test script for FSDP Engine implementation."""

-from typing import Dict
 import os
+from typing import Dict

 import torch
 from datasets import load_dataset

 from arealite.api.cli_args import (
     DatasetConfig,
+    DatasetPreprocessor,
     EngineBackendConfig,
     EngineConfig,
     ModelFamily,
@@ -15,7 +16,6 @@ from arealite.api.cli_args import (
     SFTTrainerConfig,
     TrainerConfig,
     TrainingArgs,
-    DatasetPreprocessor
 )
 from arealite.api.dataset_api import DatasetFactory
 from arealite.api.trainer_api import TrainerFactory
@@ -4,7 +4,6 @@
 # Pad/unpad operations are modified from flash-attention under BSD-3 license.
 # Copyright (c) 2023, Tri Dao.
-
 import socket
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -19,7 +18,7 @@ from einops import rearrange, repeat
 from tensorboardX import SummaryWriter

 from arealite.api.cli_args import MicroBatchSpec, TrainingArgs
-from realhf.base import datapack, constants
+from realhf.base import constants, datapack

 ############### Dict and list operations begin ###############
@@ -608,8 +607,10 @@ def gather_logprobs(
     log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
     return log_probs_labels

+
 ############### Tensor computations end ###############

+
 def init_stats_logging(args: TrainingArgs):
     """
     Initialize wandb and/or tensorboard according to config.
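For readers unfamiliar with the pattern in the gather_logprobs hunk above, here is a short self-contained illustration of selecting each position's log-probability at its label index (an example sketch with made-up shapes, not the repository's own function):

    import torch

    logits = torch.randn(4, 10)                    # [seq_len, vocab_size]
    labels = torch.randint(0, 10, (4,))            # [seq_len]
    log_probs = torch.log_softmax(logits, dim=-1)  # normalize over the vocabulary
    # gather picks log_probs[i, labels[i]] for every position i.
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    print(log_probs_labels.shape)                  # torch.Size([4])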
@@ -672,4 +673,3 @@ def record_timing(name, timing_stats):
     start_time = time.perf_counter()
     yield
     timing_stats[name] = time.perf_counter() - start_time
-
@@ -46,7 +46,7 @@ def main(args):
         yaml.dump(
             config_dict,
             f,
-            default_flow_style=False
+            default_flow_style=False,
+            sort_keys=False,
         )

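As a side note on the options in the last hunk: default_flow_style=False makes PyYAML emit block-style YAML, and sort_keys=False preserves the insertion order of keys instead of sorting them alphabetically. A minimal illustration, assuming PyYAML is installed (the config_dict contents here are made up for the example):

    import yaml

    config_dict = {"trainer": {"type": "sft"}, "seed": 1}
    print(yaml.dump(config_dict, default_flow_style=False, sort_keys=False))
    # trainer:
    #   type: sft
    # seed: 1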