mirror of https://github.com/inclusionAI/AReaL
PullRequest: 4 Fix the dataloader shuffle and random seed issue.
Merge branch fw/fix-dataloading-not-shuffle of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/4
Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>

* fw/fix-dataloading-not-shuffle
* .
* .
* .
This commit is contained in:
parent d04f17031a
commit e7c4a49adc
@@ -16,12 +16,6 @@ class DatasetAbstraction:
     args: Dict[str, Any] = dataclasses.field(default_factory=dict)


-@dataclasses.dataclass
-class DataLoaderAbstraction:
-    type_: str = "default"
-    args: Dict[str, Any] = dataclasses.field(default_factory=dict)
-
-
 @dataclasses.dataclass
 class ModelWrapperAbstraction:
     type_: str
@@ -187,10 +181,6 @@ class StandaloneModelShardAbstraction:
     model: ModelAbstraction
     backend: ModelBackendAbstraction
     # evaluation
-    eval_datasets: Optional[List[DatasetAbstraction]] = None
-    eval_dataloader: Optional[DataLoaderAbstraction] = dataclasses.field(
-        default_factory=lambda: DataLoaderAbstraction(
-            "packed_eval", args=dict(batch_size=128)
-        )
-    )
+    eval_dataset: Optional[DatasetAbstraction] = None
+    eval_bs: int = 128
     should_instantiate: bool = True
@@ -881,64 +881,3 @@ def make_dataset(
     logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s")

     return dataset
-
-
-ALL_DATALOADER_CLASSES = {}
-
-
-def register_dataloader(name, dataloader_cls):
-    assert name not in ALL_DATALOADER_CLASSES
-    ALL_DATALOADER_CLASSES[name] = dataloader_cls
-
-
-def make_dataloader(
-    cfg: Union[str, config_api.DataLoaderAbstraction], dataset: torch.utils.data.Dataset, seed_offset: Optional[int] = None
-) -> torch.utils.data.DataLoader:
-    if isinstance(cfg, str):
-        cfg = config_api.DataLoaderAbstraction(type_=cfg)
-    dataloader_cls = ALL_DATALOADER_CLASSES[cfg.type_]
-    if seed_offset is None:
-        return dataloader_cls(dataset, **cfg.args)
-    else:
-        return dataloader_cls(dataset, **cfg.args, seed_offset=seed_offset)
-
-
-def PackedDataLoader(dataset, *args, seed_offset: int = 0, **kwargs):
-    if not isinstance(getattr(dataset, "util", None), DatasetUtility):
-        raise ValueError("Dataset must have a `util` attribute of type DatasetUtility.")
-    g = torch.Generator()
-    g.manual_seed(dataset.util.seed + dist.get_rank() + seed_offset)
-
-    def seed_worker(worker_id):
-        worker_seed = torch.initial_seed() % 2**32
-        np.random.seed(worker_seed)
-        random.seed(worker_seed)
-
-    return torch.utils.data.DataLoader(
-        dataset,
-        *args,
-        collate_fn=SequenceSample.gather,
-        # NOTE: This is *NOT* the actual batch size for training.
-        # It is just a proper size to load data to workers.
-        batch_size=10240,
-        shuffle=True,
-        generator=g,
-        worker_init_fn=seed_worker,
-        **kwargs,
-    )
-
-
-def PackedEvalDataLoader(dataset, *args, **kwargs):
-    if not isinstance(getattr(dataset, "util", None), DatasetUtility):
-        raise ValueError("Dataset must have a `util` attribute of type DatasetUtility.")
-    return torch.utils.data.DataLoader(
-        dataset,
-        *args,
-        collate_fn=SequenceSample.gather,
-        shuffle=False,
-        **kwargs,
-    )
-
-
-register_dataloader("packed", PackedDataLoader)
-register_dataloader("packed_eval", PackedEvalDataLoader)
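
The deleted PackedDataLoader above follows the standard PyTorch recipe for reproducible shuffling: a manually seeded torch.Generator fixes the shuffle order, and a worker_init_fn reseeds NumPy and `random` inside each loader worker. A self-contained sketch of that recipe, using a toy list dataset instead of the repo's packed SequenceSample data:

import random

import numpy as np
import torch


def seed_worker(worker_id):
    # DataLoader workers receive a derived torch seed; mirror it into
    # NumPy and `random` so augmentations inside workers are reproducible.
    # (This hook only fires when num_workers > 0.)
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


g = torch.Generator()
g.manual_seed(42)  # toy seed; the deleted code used dataset.util.seed + rank + offset

loader = torch.utils.data.DataLoader(
    list(range(10)),  # toy dataset
    batch_size=4,
    shuffle=True,  # the order is now fully determined by `g`
    generator=g,
    worker_init_fn=seed_worker,
)
print([batch.tolist() for batch in loader])
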
@@ -15,7 +15,6 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import realhf.api.core.dfg as dfg
 from realhf.api.core.config import (
-    DataLoaderAbstraction,
     DatasetAbstraction,
     ModelAbstraction,
     ModelName,
@@ -114,17 +113,13 @@ class WorkerInformation:

 @dataclasses.dataclass
 class ModelWorker:
-    seed: int
+    base_seed: int
     shards: List[StandaloneModelShardAbstraction]
     # dataset, for source model workers
     tokenizer_name_or_path: Optional[str] = None
     datasets: Optional[List[Union[str, DatasetAbstraction]]] = None
-    dataloader: Union[str, DataLoaderAbstraction] = "packed"
     use_dataset_cache: bool = False
     dataset_cahce_root: str = constants.DATASET_CACHE_PATH
-    # cuda & cudnn config
-    cudnn_benchmark: bool = False
-    cudnn_deterministic: bool = False
     cuda_cache_cleanliness: bool = True
     cuda_cache_clear_freq: int = 10
     torch_cache_mysophobia: bool = False
@@ -210,12 +205,13 @@ class ExperimentSaveEvalControl:

 @dataclasses.dataclass
 class MasterWorker:
+    base_seed: int
     exp_ctrl: ExperimentSaveEvalControl
     # main components
     n_model_workers: int
     model_rpcs: List[dfg.MFCDef] = None
     model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = None
-    msid2mwid: Dict[ModelShardID, int] = None
+    msid2mwid: Dict[ModelShardID | str, int] = None
     data_transfer_pairs: List[Tuple[str, str]] = None
     sync_param_pairs: List[Tuple[str, str]] = None
     worker_info: Optional[WorkerInformation] = None
@@ -263,7 +259,11 @@ class ExperimentConfig:

     def __post_init__(self):
         self.master_worker = [
-            MasterWorker(exp_ctrl=self.exp_ctrl, n_model_workers=len(self.model_worker))
+            MasterWorker(
+                base_seed=self.model_worker[0].base_seed,
+                exp_ctrl=self.exp_ctrl,
+                n_model_workers=len(self.model_worker),
+            )
         ]

     def lazy_init(self):
@@ -2,17 +2,29 @@
 # Copyright 2024 Wei Fu & Zhiyu Mei
 # Licensed under the Apache License, Version 2.0 (the "License").

+import os
 import random

 import numpy as np
 import torch
 import transformers

+_SEED = None
+
+
 def set_random_seed(seed):
+    global _SEED
+    _SEED = seed
+    os.environ["PYTHONHASHSEED"] = str(seed)
     transformers.set_seed(seed)
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)
+
+
+def get_seed() -> int:
+    global _SEED
+    return _SEED
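
A minimal sketch of how this process-wide seed registry is consumed (the concrete seed values below are hypothetical; the real call sites are the worker hunks later in this diff): each process calls set_random_seed exactly once with a derived seed, and later components such as the vLLM backend read it back through get_seed() instead of threading a seed field through their configs.

from realhf.base import seeding

base_seed = 1      # hypothetical experiment-level seed
worker_index = 3   # hypothetical worker rank

# Called once during worker configuration; seeds Python, NumPy,
# torch, and transformers RNGs and records the value in _SEED.
seeding.set_random_seed(base_seed + worker_index)

# Any component in the same process can later recover the seed.
assert seeding.get_seed() == base_seed + worker_index
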
@@ -18,7 +18,6 @@ from omegaconf import MISSING, OmegaConf

 import realhf.base.logging as logging
 from realhf.api.core.config import (
-    DataLoaderAbstraction,
     DatasetAbstraction,
     ModelAbstraction,
     ModelBackendAbstraction,
@@ -256,22 +255,17 @@ class CommonExperimentConfig(Experiment):
         return NotImplementedError(f"datasets is not implemented in {self.__class__}")

     @property
-    def eval_datasets(self) -> List[DatasetAbstraction]:
-        """A list of dataset configurations used for evaluation.
+    def eval_dataset(self) -> DatasetAbstraction | None:
+        """The dataset configuration used for evaluation.

         Can be None if runtime evaluation is not needed.
         """
         return None

     @property
-    def eval_dataloader(self) -> DataLoaderAbstraction:
-        """The dataloader configuration used for evaluation.
-
-        Reserved to changed the evaluation batch size. Training does not
-        require this property because the batch size is handled in MFC
-        definitions.
-        """
-        return DataLoaderAbstraction("packed_eval", args=dict(batch_size=128))
+    def eval_bs(self) -> int:
+        """The batch size for runtime evaluation."""
+        return 128

     @property
     def tokenizer_name_or_path(self) -> str:
@@ -553,7 +547,7 @@

         for i, j in itertools.product(range(self.n_nodes), range(self.n_gpus_per_node)):
             mw = ModelWorker(
-                seed=self.seed,
+                base_seed=self.seed,
                 shards=[],
                 datasets=self.datasets,
                 torch_cache_mysophobia=self.torch_cache_mysophobia,
@@ -611,7 +605,6 @@
                     backend=ModelBackendAbstraction(
                         "vllm",
                         args=dict(
-                            seed=self.seed,
                             model_path=model_cfg.path,
                             **vllm_dict_args,
                         ),
@@ -691,7 +684,6 @@
             backend = ModelBackendAbstraction(
                 "vllm",
                 args=dict(
-                    seed=self.seed,
                     model_path=model_cfg.path,
                     **vllm_dict_args,
                 ),
@@ -713,8 +705,8 @@
                         ),
                         model=model,
                         backend=backend,
-                        eval_datasets=self.eval_datasets,
-                        eval_dataloader=self.eval_dataloader,
+                        eval_dataset=self.eval_dataset,
+                        eval_bs=self.eval_bs,
                     )
                 )
                 shard_counter[model_name] += 1
@@ -5,7 +5,6 @@
 import dataclasses

 from realhf.api.core.config import (
-    DataLoaderAbstraction,
     DatasetAbstraction,
     ModelInterfaceAbstraction,
     ModelInterfaceType,
@@ -82,22 +81,18 @@ class SFTConfig(CommonExperimentConfig):
         ]

     @property
-    def eval_datasets(self):
-        return [
-            DatasetAbstraction(
-                "prompt_answer",
-                args=dict(
-                    max_length=self.dataset.max_seqlen,
-                    dataset_path=self.dataset.valid_path,
-                ),
-            )
-        ]
+    def eval_dataset(self):
+        return DatasetAbstraction(
+            "prompt_answer",
+            args=dict(
+                max_length=self.dataset.max_seqlen,
+                dataset_path=self.dataset.valid_path,
+            ),
+        )

     @property
-    def eval_dataloader(self):
-        return DataLoaderAbstraction(
-            "packed_eval", args=dict(batch_size=self.dataset.valid_bs_n_seqs)
-        )
+    def eval_bs(self) -> int:
+        return self.dataset.valid_bs_n_seqs

     @property
     def tokenizer_name_or_path(self):
@@ -31,7 +31,7 @@ except ModuleNotFoundError:

 from realhf.api.core import data_api, model_api
 from realhf.api.quickstart.model import vLLMConfig
-from realhf.base import constants, logging
+from realhf.base import constants, logging, seeding

 logger = logging.getLogger("vLLM backend")
@@ -165,7 +165,6 @@ class vLLMGenerationEngine(model_api.PipelinableEngine, LLM):

 @dataclasses.dataclass
 class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
-    seed: int = 0
     model_path: str = ""

     def _initialize(
@@ -187,7 +186,7 @@
             skip_tokenizer_init=False,
             trust_remote_code=True,
             max_model_len=self.max_model_len,
-            seed=self.seed,
+            seed=seeding.get_seed(),
             dtype=torch.float16,
             kv_cache_dtype=self.kv_cache_type,
             device=constants.current_device(),
@@ -152,10 +152,7 @@ class SFTInterface(model_api.ModelInterface):
         res = module.eval_batch(
             input_=x.to_device(device),
             loss_fn=compute_packed_sft_loss,
-            mb_spec=MicroBatchSpec(
-                n_mbs=constants.pipe_parallel_world_size(),
-                balanced_seqs=True,
-            ),
+            mb_spec=MicroBatchSpec(),
         )

         if res is not None:
@@ -0,0 +1,117 @@
+import dataclasses
+from typing import *
+
+from realhf.api.core.data_api import SequenceSample
+from realhf.api.core.dfg import MFCDef, ModelInterfaceType
+from realhf.api.core.model_api import ReaLModelConfig
+from realhf.base.monitor import (
+    caculuate_llama_forward_flops,
+    calculate_llama_gen_flops,
+    calculate_llama_train_flops,
+)
+
+
+@dataclasses.dataclass
+class FlopsCounter:
+    train_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
+    train_bs: List[int] = dataclasses.field(default_factory=list)
+    train_seqlens: List[List[int]] = dataclasses.field(default_factory=list)
+
+    inf_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
+    inf_bs: List[int] = dataclasses.field(default_factory=list)
+    inf_seqlens: List[List[int]] = dataclasses.field(default_factory=list)
+
+    gen_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
+    gen_bs: List[int] = dataclasses.field(default_factory=list)
+    prompt_lens: List[List[int]] = dataclasses.field(default_factory=list)
+    gen_len: List[int] = dataclasses.field(default_factory=list)
+
+    def clear(self):
+        self.train_bs.clear()
+        self.train_seqlens.clear()
+
+        self.inf_bs.clear()
+        self.inf_seqlens.clear()
+
+        self.gen_bs.clear()
+        self.prompt_lens.clear()
+        self.gen_len.clear()
+
+        self.train_configs.clear()
+        self.inf_configs.clear()
+        self.gen_configs.clear()
+
+    def add_rpc(
+        self, rpc: MFCDef, sample: SequenceSample, model_config: ReaLModelConfig
+    ):
+        # Record the data amount for each interface to compute FLOPs.
+        # Since the user may arbitrarily specify input/output keys,
+        # we can only try to find the most probable key name for computing FLOPs.
+        # If such keys do not exist, we will use the key with the longest
+        # sequence length in this model function call.
+        acc_seqlens = {
+            k: sum(sum(x) for x in slens) for k, slens in sample.seqlens.items()
+        }
+        seqlen_key = max(sample.seqlens, key=acc_seqlens.get)
+        flops_seqlens = [sum(x) for x in sample.seqlens[seqlen_key]]
+        if rpc.interface_type == ModelInterfaceType.GENERATE:
+            self.gen_configs.append(model_config)
+            self.gen_bs.append(sample.bs)
+            self.gen_len.append(
+                rpc.interface_impl.args["generation_config"]["min_new_tokens"]
+            )
+            self.prompt_lens.append(flops_seqlens)
+        elif rpc.interface_type == ModelInterfaceType.TRAIN_STEP:
+            self.train_configs.append(model_config)
+            self.train_bs.append(sample.bs)
+            self.train_seqlens.append(flops_seqlens)
+        elif rpc.interface_type == ModelInterfaceType.INFERENCE:
+            self.inf_configs.append(model_config)
+            self.inf_bs.append(sample.bs)
+            self.inf_seqlens.append(flops_seqlens)
+
+    def get_flops(self) -> int:
+        flops = 0
+        for train_bs, train_seqlens, real_config in zip(
+            self.train_bs,
+            self.train_seqlens,
+            self.train_configs,
+        ):
+            flops += calculate_llama_train_flops(
+                checkpoint_activations_factor=4,
+                batch_size=train_bs,
+                seqlens=train_seqlens,
+                num_layers=real_config.n_layers,
+                hidden_size=real_config.hidden_dim,
+                intermediate_size=real_config.intermediate_dim,
+                vocab_size=real_config.vocab_size,
+            )
+        for inf_bs, inf_seqlens, real_config in zip(
+            self.inf_bs,
+            self.inf_seqlens,
+            self.inf_configs,
+        ):
+            flops += caculuate_llama_forward_flops(
+                batch_size=inf_bs,
+                seqlens=inf_seqlens,
+                num_layers=real_config.n_layers,
+                hidden_size=real_config.hidden_dim,
+                intermediate_size=real_config.intermediate_dim,
+                vocab_size=real_config.vocab_size,
+            )
+        for gen_bs, prompt_lens, gen_len, real_config in zip(
+            self.gen_bs,
+            self.prompt_lens,
+            self.gen_len,
+            self.gen_configs,
+        ):
+            flops += calculate_llama_gen_flops(
+                batch_size=gen_bs,
+                prompt_lens=prompt_lens,
+                gen_len=gen_len,
+                num_layers=real_config.n_layers,
+                hidden_size=real_config.hidden_dim,
+                intermediate_size=real_config.intermediate_dim,
+                vocab_size=real_config.vocab_size,
+            )
+        return flops
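
The key-selection heuristic in add_rpc picks whichever input key carries the most tokens and uses its per-sequence totals as the FLOP sequence lengths. A toy illustration of that heuristic with plain dicts (the key names are made up; sample.seqlens maps each key to a list of per-sequence seqlen lists):

# Toy stand-in for sample.seqlens.
seqlens = {
    "packed_prompts": [[128], [96]],
    "packed_input_ids": [[512, 256], [384]],
}

# Total token count per key, exactly as in FlopsCounter.add_rpc.
acc_seqlens = {k: sum(sum(x) for x in slens) for k, slens in seqlens.items()}
seqlen_key = max(seqlens, key=acc_seqlens.get)         # -> "packed_input_ids"
flops_seqlens = [sum(x) for x in seqlens[seqlen_key]]  # -> [768, 384]
print(seqlen_key, flops_seqlens)
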
@@ -11,6 +11,7 @@ import getpass
 import itertools
 import os
 import pprint
+import random
 import re
 import time
 import uuid
@@ -40,6 +41,7 @@ from realhf.base import (
     logging,
     name_resolve,
     names,
+    seeding,
     timeutil,
     topology,
 )
@@ -48,12 +50,8 @@ from realhf.base.asyncio_utils import (
     setup_run_until_complete,
     teardown_run_util_complete,
 )
-from realhf.base.monitor import (
-    caculuate_llama_forward_flops,
-    calculate_llama_gen_flops,
-    calculate_llama_train_flops,
-)
 from realhf.system.buffer import AsyncIOSequenceBuffer
+from realhf.system.flops_counter import FlopsCounter

 logger = logging.getLogger("master worker", "system")
 blogger = logging.getLogger("benchmark")
@@ -107,37 +105,6 @@ def _attach_param_realloc_hooks(
     return payload


-@dataclasses.dataclass
-class InterfaceDataAmount:
-    train_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
-    train_bs: List[int] = dataclasses.field(default_factory=list)
-    train_seqlens: List[List[int]] = dataclasses.field(default_factory=list)
-
-    inf_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
-    inf_bs: List[int] = dataclasses.field(default_factory=list)
-    inf_seqlens: List[List[int]] = dataclasses.field(default_factory=list)
-
-    gen_configs: List[ReaLModelConfig] = dataclasses.field(default_factory=list)
-    gen_bs: List[int] = dataclasses.field(default_factory=list)
-    prompt_lens: List[List[int]] = dataclasses.field(default_factory=list)
-    gen_len: List[int] = dataclasses.field(default_factory=list)
-
-    def clear(self):
-        self.train_bs.clear()
-        self.train_seqlens.clear()
-
-        self.inf_bs.clear()
-        self.inf_seqlens.clear()
-
-        self.gen_bs.clear()
-        self.prompt_lens.clear()
-        self.gen_len.clear()
-
-        self.train_configs.clear()
-        self.inf_configs.clear()
-        self.gen_configs.clear()
-
-
 @dataclasses.dataclass
 class RPCCorountineControl:
     ## Shared resources ##
@@ -157,9 +124,7 @@ class RPCCorountineControl:

     # for training data management and data cleaning after each step
     ids_to_clear: Set[int] = dataclasses.field(default_factory=set)
-    data_amount: InterfaceDataAmount = dataclasses.field(
-        default_factory=InterfaceDataAmount
-    )
+    flops_counter: FlopsCounter = dataclasses.field(default_factory=FlopsCounter)

     should_save: bool = False
     should_eval: bool = False
@@ -408,31 +373,7 @@ async def model_rpc_request_func(

     buf_indices, sample = await buffer.get_batch_for_rpc(rpc)

-    # Record the data amount for each interface to compute FLOPs.
-    # Since the user may arbitrarily specify input/output keys,
-    # we can only try to find the most probable key name for computing FLOPs.
-    # If such keys do not exist, we will use the key with the longest
-    # sequence length in this model function call.
-    acc_seqlens = {
-        k: sum(sum(x) for x in slens) for k, slens in sample.seqlens.items()
-    }
-    seqlen_key = max(sample.seqlens, key=acc_seqlens.get)
-    flops_seqlens = [sum(x) for x in sample.seqlens[seqlen_key]]
-    if rpc.interface_type == dfg.ModelInterfaceType.GENERATE:
-        ctrl.data_amount.gen_configs.append(model_configs[rpc.model_name])
-        ctrl.data_amount.gen_bs.append(sample.bs)
-        ctrl.data_amount.gen_len.append(
-            rpc.interface_impl.args["generation_config"]["min_new_tokens"]
-        )
-        ctrl.data_amount.prompt_lens.append(flops_seqlens)
-    elif rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP:
-        ctrl.data_amount.train_configs.append(model_configs[rpc.model_name])
-        ctrl.data_amount.train_bs.append(sample.bs)
-        ctrl.data_amount.train_seqlens.append(flops_seqlens)
-    elif rpc.interface_type == dfg.ModelInterfaceType.INFERENCE:
-        ctrl.data_amount.inf_configs.append(model_configs[rpc.model_name])
-        ctrl.data_amount.inf_bs.append(sample.bs)
-        ctrl.data_amount.inf_seqlens.append(flops_seqlens)
+    ctrl.flops_counter.add_rpc(rpc, sample, model_configs[rpc.model_name])

     this_rpc_consumed_seqs += sample.bs
@@ -599,6 +540,8 @@ class MasterWorker(worker_base.Worker):
     def _configure(self, config: config_pkg.MasterWorker):
         self.config = config

+        seeding.set_random_seed(self.config.base_seed + self.config.n_model_workers)
+
         self.__model_topos: Dict[ModelName, topology.PipeModelDataParallelTopology] = (
             config.model_topos
         )
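
The derived seeds now form a simple collision-free scheme: model worker i gets base_seed + i (see the model worker hunk below), and the master worker takes base_seed + n_model_workers, one past the last model worker. A toy check of that arithmetic (the worker count is illustrative):

base_seed = 1          # hypothetical experiment seed
n_model_workers = 4    # hypothetical cluster size

model_worker_seeds = [base_seed + i for i in range(n_model_workers)]
master_seed = base_seed + n_model_workers

# The master's seed can never collide with a model worker's seed.
assert master_seed not in model_worker_seeds
print(model_worker_seeds, master_seed)
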
@@ -1248,12 +1191,9 @@
                 filtered_data.append(x)
             all_data = filtered_data

-            # Reorder loaded (meta-)data and store them into the buffer.
-            # NOTE: The reordered indices prioritize longer sequences for detecting OOM errors early.
-            # reorder_indices, _ = datapack.reorder_to_balanced_batches(
-            #     np.array(seqlens), src_rpc.n_seqs
-            # )
-            # all_data: List[data_api.SequenceSample] = [all_data[i] for i in reorder_indices]
+            # We load data in a round-robin manner across different DP ranks,
+            # so we also need to shuffle the data to fuse different dataset splits.
+            random.shuffle(all_data)

             blogger.info(
                 f"Master worker loaded {len(all_data)} pieces of data. "
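
Round-robin loading across DP ranks interleaves the underlying dataset splits in a rigid pattern, which the added random.shuffle breaks up. A toy picture of the effect (the rank contents are illustrative):

import random

random.seed(0)  # in the master worker this RNG is seeded via seeding above

# Two hypothetical DP ranks, each holding a different dataset split.
rank0 = ["a0", "a1", "a2"]
rank1 = ["b0", "b1", "b2"]

# Round-robin interleaving: a fixed a/b/a/b pattern.
all_data = [x for pair in zip(rank0, rank1) for x in pair]
print(all_data)  # ['a0', 'b0', 'a1', 'b1', 'a2', 'b2']

random.shuffle(all_data)
print(all_data)  # the two splits are now fused in random order
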
@@ -1287,52 +1227,10 @@
             flops = None
             tflops_per_gpu = float("inf")
         else:
-            flops = 0
-            for train_bs, train_seqlens, real_config in zip(
-                self.__rpc_ctrl.data_amount.train_bs,
-                self.__rpc_ctrl.data_amount.train_seqlens,
-                self.__rpc_ctrl.data_amount.train_configs,
-            ):
-                flops += calculate_llama_train_flops(
-                    checkpoint_activations_factor=4,
-                    batch_size=train_bs,
-                    seqlens=train_seqlens,
-                    num_layers=real_config.n_layers,
-                    hidden_size=real_config.hidden_dim,
-                    intermediate_size=real_config.intermediate_dim,
-                    vocab_size=real_config.vocab_size,
-                )
-            for inf_bs, inf_seqlens, real_config in zip(
-                self.__rpc_ctrl.data_amount.inf_bs,
-                self.__rpc_ctrl.data_amount.inf_seqlens,
-                self.__rpc_ctrl.data_amount.inf_configs,
-            ):
-                flops += caculuate_llama_forward_flops(
-                    batch_size=inf_bs,
-                    seqlens=inf_seqlens,
-                    num_layers=real_config.n_layers,
-                    hidden_size=real_config.hidden_dim,
-                    intermediate_size=real_config.intermediate_dim,
-                    vocab_size=real_config.vocab_size,
-                )
-            for gen_bs, prompt_lens, gen_len, real_config in zip(
-                self.__rpc_ctrl.data_amount.gen_bs,
-                self.__rpc_ctrl.data_amount.prompt_lens,
-                self.__rpc_ctrl.data_amount.gen_len,
-                self.__rpc_ctrl.data_amount.gen_configs,
-            ):
-                flops += calculate_llama_gen_flops(
-                    batch_size=gen_bs,
-                    prompt_lens=prompt_lens,
-                    gen_len=gen_len,
-                    num_layers=real_config.n_layers,
-                    hidden_size=real_config.hidden_dim,
-                    intermediate_size=real_config.intermediate_dim,
-                    vocab_size=real_config.vocab_size,
-                )
+            flops = self.__rpc_ctrl.flops_counter.get_flops()
             tflops = flops / (e2e_time * (10**12))
             tflops_per_gpu = flops / (e2e_time * self.config.n_model_workers * (10**12))
-        self.__rpc_ctrl.data_amount.clear()
+        self.__rpc_ctrl.flops_counter.clear()
         #########################################

         epoch = self.__rpc_ctrl.step_info.epoch + 1
@@ -105,7 +105,6 @@ class NoRequestToHandle(Exception):

 class ModelWorker(worker_base.Worker):
     _setup_counter = -1
-    _seed_offset = 0

     def _configure(self, cfg: system_api.ModelWorker):
         self._setup_counter += 1
@@ -126,10 +125,7 @@

         self.__worker_index = cfg.worker_info.worker_index

-        torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
-        torch.backends.cudnn.deterministic = cfg.cudnn_deterministic
-
-        seeding.set_random_seed(cfg.seed)
+        seeding.set_random_seed(cfg.base_seed + self.__worker_index)

         # Reveal process group identity of this worker to world.
         gpu_utils.reveal_pg_identity(
@@ -308,7 +304,8 @@
             datasets = [
                 data_api.make_dataset(
                     d,
-                    self.config.seed,
+                    # NOTE: we must use the same seed to ensure the same dataset split
+                    self.config.base_seed,
                     self.__dataset_dp_rank,
                     self.__dataset_dp_size,
                     self.config.tokenizer_name_or_path,
@@ -327,10 +324,20 @@
         else:
             self.__dataset = torch.utils.data.ConcatDataset(datasets)

+        g = torch.Generator()
+        g.manual_seed(seeding.get_seed())
+        self.__dataloader = torch.utils.data.DataLoader(
+            self.__dataset,
+            collate_fn=data_api.SequenceSample.gather,
+            # NOTE: This is *NOT* the actual batch size for training.
+            # It is just a proper size to load data to workers.
+            batch_size=10240,
+            shuffle=True,
+            generator=g,
+        )
+
         self.__raw_samples = []
-        for tmp_sample in data_api.make_dataloader(
-            self.config.dataloader, self.__dataset
-        ):
+        for tmp_sample in self.__dataloader:
             self.__raw_samples += tmp_sample.meta().unpack()

         self.__models: Dict[ModelName, model_api.Model] = dict()
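
Seeding the generator from seeding.get_seed() rather than relying on the DataLoader's default RNG makes the shuffle order a pure function of the process seed, so a restarted worker with the same seed replays the same order. A minimal demonstration with a toy dataset:

import torch


def shuffle_order(seed):
    g = torch.Generator()
    g.manual_seed(seed)
    loader = torch.utils.data.DataLoader(
        list(range(8)), batch_size=8, shuffle=True, generator=g
    )
    (batch,) = list(loader)  # one batch containing the whole epoch
    return batch.tolist()


# Same seed -> identical shuffle order across processes and restarts.
assert shuffle_order(7) == shuffle_order(7)
# Different derived seeds (e.g. other workers) -> other orders.
print(shuffle_order(7), shuffle_order(8))
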
@@ -406,30 +413,27 @@
                     interface_impl[0]
                 )

-            if s.eval_datasets is not None and s.eval_dataloader is not None:
-                eval_datasets = [
-                    data_api.make_dataset(
-                        d,
-                        self.config.seed,
-                        s.id.dp_rank,
-                        s.id.topo.get_dim("data"),
-                        self.__models[s.id.model_name].tokenizer,
-                        self.config.worker_info.experiment_name,
-                        self.config.worker_info.trial_name,
-                        cache_root=(
-                            None
-                            if not self.config.use_dataset_cache
-                            else self.config.dataset_cahce_root
-                        ),
-                    )
-                    for d in s.eval_datasets
-                ]
-                if len(eval_datasets) > 1:
-                    eval_dataset = torch.utils.data.ConcatDataset(eval_datasets)
-                else:
-                    eval_dataset = eval_datasets[0]
-                eval_dataloader = data_api.make_dataloader(
-                    s.eval_dataloader, eval_dataset
-                )
+            if s.eval_dataset is not None:
+                eval_dataset = data_api.make_dataset(
+                    s.eval_dataset,
+                    # NOTE: we must use the same seed to ensure the same dataset split
+                    self.config.base_seed,
+                    s.id.dp_rank,
+                    s.id.topo.get_dim("data"),
+                    self.__models[s.id.model_name].tokenizer,
+                    self.config.worker_info.experiment_name,
+                    self.config.worker_info.trial_name,
+                    cache_root=(
+                        None
+                        if not self.config.use_dataset_cache
+                        else self.config.dataset_cahce_root
+                    ),
+                )
+                eval_dataloader = torch.utils.data.DataLoader(
+                    eval_dataset,
+                    batch_size=s.eval_bs,
+                    collate_fn=data_api.SequenceSample.gather,
+                    shuffle=False,
+                )
             else:
                 eval_dataloader = None
@@ -601,10 +605,17 @@
                         dataset_indices_path,
                         self.__dataset.active_indices,
                     )
-                self.__dataloader = data_api.make_dataloader(
-                    self.config.dataloader, self.__dataset, seed_offset=self._seed_offset,
+                g = torch.Generator()
+                g = g.set_state(self.__dataloader.generator.get_state())
+                self.__dataloader = torch.utils.data.DataLoader(
+                    self.__dataset,
+                    collate_fn=data_api.SequenceSample.gather,
+                    # NOTE: This is *NOT* the actual batch size for training.
+                    # It is just a proper size to load data to workers.
+                    batch_size=10240,
+                    shuffle=True,
+                    generator=g,
                 )
-                self._seed_offset += 1
                 self.__data_generator = enumerate(self.__dataloader)

                 # Fetch.
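
Carrying the exhausted generator's state into the next epoch's loader is what actually fixes the "not shuffled" symptom: re-seeding a fresh generator with the same seed each epoch would replay epoch one's order forever, while reusing the advanced state yields a new permutation per epoch with no _seed_offset bookkeeping. A minimal sketch of both behaviors with a toy dataset:

import torch

data = list(range(6))


def one_epoch(g):
    loader = torch.utils.data.DataLoader(data, batch_size=6, shuffle=True, generator=g)
    (batch,) = list(loader)  # drawing the permutation advances `g`
    return batch.tolist()


# Buggy pattern: a freshly re-seeded generator repeats the same order.
epochs = []
for _ in range(2):
    g = torch.Generator()
    g.manual_seed(1)
    epochs.append(one_epoch(g))
assert epochs[0] == epochs[1]  # identical epochs: the original bug

# Fixed pattern (as in the diff): reuse the advanced generator state.
g = torch.Generator()
g.manual_seed(1)
first = one_epoch(g)
g2 = torch.Generator()
g2.set_state(g.get_state())
second = one_epoch(g2)
print(first, second)  # almost surely two different orders
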
@@ -156,7 +156,7 @@ class NameResolvingRequestClient:

     def request(
         self,
-        handlers: List[str] | None = None,
+        handlers: List[str | int] | None = None,
         handle_type: str | None = None,
         datas: List[Any] | None = None,
         payloads: List[Payload] | None = None,
@@ -222,7 +222,7 @@

     def call(
         self,
-        handlers: List[str] | None = None,
+        handlers: List[str | int] | None = None,
         handle_type: str | None = None,
         datas: List[Any] | None = None,
         payloads: List[Payload] | None = None,