AReaL/realhf/api/core/data_api.py

909 lines
32 KiB
Python

# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import json
import os
import random
import time
from contextlib import contextmanager
# NOTE: We don't sue wildcard importing here because the type
# `Sequence` has a very similar name to `SequenceSample`.
# We don't want to confuse them.
from typing import (
Any,
Callable,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data
import transformers
# NOTE: We only use pandatic dataclasses for SequenceSample
# such that it will perform automatic checks.
from pydantic import Field
from pydantic import dataclasses as pdclasses
from pydantic import field_validator, model_validator
from realhf.api.cli_args import MicroBatchSpec
from realhf.api.core import config as config_api
from realhf.base import constants, datapack, logging, seeding
from realhf.base.cluster import spec as cluster_spec
logger = logging.getLogger("api.data")
RL_TASKS = ["math", "code", "rlhf", "stem"]
def load_hf_tokenizer(
model_name_or_path: str,
fast_tokenizer=True,
padding_side: Optional[str] = None,
) -> transformers.PreTrainedTokenizerFast:
kwargs = {}
if padding_side is not None:
kwargs["padding_side"] = padding_side
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name_or_path,
fast_tokenizer=fast_tokenizer,
trust_remote_code=True,
force_download=True,
**kwargs,
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer
@pdclasses.dataclass
class SequenceSplitSpec:
partitions: Optional[List[Tuple[int, int]]] = None
sizes: Optional[List[int]] = None
@model_validator(mode="after")
def _validate_partitions(self) -> "SequenceSplitSpec":
if self.partitions is not None:
bound = 0
for start, end in self.partitions:
if start >= end:
raise ValueError(f"Partition {start}-{end} is empty.")
if start != bound:
raise ValueError(f"Partition {start}-{end} is not contiguous.")
bound = end
if self.sizes is None and self.partitions is None:
raise ValueError("Either sizes or partitions must be provided.")
elif self.sizes is not None and self.partitions is not None:
if len(self.sizes) != len(self.partitions):
raise ValueError("Sizes and partitions are not the consistent.")
if self.sizes != [end - start for start, end in self.partitions]:
raise ValueError("Sizes and partitions are not the consistent.")
elif self.sizes is None:
self.sizes = [end - start for start, end in self.partitions]
elif self.partitions is None:
offsets = np.cumsum([0] + self.sizes)
self.partitions = [
(offsets[i], offsets[i + 1]) for i in range(len(self.sizes))
]
return self
@pdclasses.dataclass(config=dict(arbitrary_types_allowed=True))
class SequenceSample:
"""The data structure used to represent sequence data.
Each piece of data is assumed to have several "keys" (like a dictionary),
with each key potentially corresponding to multiple sequences.
For example, when running PPO, multiple responses can be generated for each prompt.
If there are 2 prompts, each with 3 responses, the batch might look like:
.. code-block:: console
>>> s = SequenceSample(...)
>>> s.keys
{'resp', 'prompt'}
>>> s.seqlens
{'prompt': [[13], [6]], 'resp': [[6, 17, 15], [13, 15, 13]]}
>>> s.data
{'prompt': torch.tensor([...]), 'resp': torch.tensor([...])}
Key points:
- Data with different batch indices can have varying lengths (e.g., the first prompt has a length of 13
while the second has a length of 6).
- A key (e.g., "response") can correspond to multiple sequences with different lengths.
Additionally, the number of sequences for each key can differ from the number of sequences for the data.
For example, the first prompt may have 2 responses, and the second may have 3.
- Regardless of the batch size or the number of sequences stored for each key,
the data is concatenated into a 1D tensor. The outer dimension represents the batch size,
and the inner dimension represents the number of sequences for the key.
This data structure facilitates easy gathering, splitting,
and transferring of non-padded batches between different GPUs.
:param keys: The keys of the data.
:type keys: Set[str]
:param trailing_shapes: The trailing shapes of the data,
excluding the first dimension, which must be the sequence length.
Used to construct the receiving buffer for data transfer.
:type trailing_shapes: Dict[str, torch.Size | Tuple | None]
:param dtypes: The types of the data. Used to construct
the receiving buffer for data transfer.
:type dtypes: Dict[str, torch.dtype | None]
:param ids: Unique identifiers for each piece of data.
Should be provided in the dataset implementation.
Used to append new data to the buffer after a model function call.
:type ids: List[Hashable]
:param seqlens: The sequence lengths of each sequence in the data. For a given key,
this should be a list of lists of integers. The outer list represents the batch size,
while the inner lists represent the sequence lengths for this key.
Python-native lists are used here because (1) pickling torch.Tensor or numpy array is inefficient,
and (2) the size of the inner lists can vary across the batch, making 2D arrays impractical.
:type seqlens: Dict[str, List[List[int]]]
:param data: The actual concatenated data. If this is None,
the sample is a metadata-only sample used by the master worker.
The specification of the data should be consistent with the seqlens,
dtypes, and trailing_shapes.
:type data: Optional[Dict[str, torch.Tensor | None]]
:param metadata: Metadata for the sample. It should be a
dictionary of lists, provided in the dataset implementation.
Note that adding metadata can slow down data transfer.
:type metadata: Dict[str, List[Any]]
"""
keys: Set[str]
trailing_shapes: Dict[str, torch.Size | Tuple | None]
dtypes: Dict[str, torch.dtype | None]
ids: List[Hashable]
seqlens: Dict[str, List[List[int]]]
data: Optional[Dict[str, torch.Tensor | None]] = None
metadata: Dict[str, List[Any]] = Field(default_factory=dict)
@field_validator("ids")
@classmethod
def _validate_ids(cls, ids: List[Hashable]) -> List[str]:
ids = list(map(str, ids))
if len(ids) != len(set(ids)):
raise ValueError(f"IDs contain duplicates: {ids}.")
return ids
@field_validator("trailing_shapes")
@classmethod
def _validate_trailing_shapes(
cls, trailing_shapes: Dict
) -> Dict[str, Tuple | None]:
for k, v in trailing_shapes.items():
if v is not None:
trailing_shapes[k] = tuple(v)
return trailing_shapes
@field_validator("keys")
@classmethod
def _validate_keys_type(cls, keys: Iterable) -> Set[str]:
keys_ = set(keys)
if len(keys_) != len(keys):
raise ValueError(f"Keys contain duplicates: {keys}.")
return keys_
@field_validator("seqlens")
@classmethod
def _validate_seqlens_device_dtype(
cls, seqlens: Dict[str, List[torch.Tensor]]
) -> Dict[str, List[torch.Tensor]]:
for k, lens in seqlens.items():
assert isinstance(lens, list)
assert all(isinstance(l, list) for l in lens)
for i, lens_ in enumerate(lens):
assert all(isinstance(l_, int) for l_ in lens_)
return seqlens
@model_validator(mode="after")
def _validate_list_length(self) -> "SequenceSample":
cond = True
l = len(self.ids)
cond &= all(len(lens) == l for lens in self.seqlens.values())
if not cond:
raise ValueError(
f"Lengths of ids({len(self.ids)})"
f"/seqlens({self.seqlens}) "
"are not the same."
)
return self
@model_validator(mode="after")
def _validate_keys(self) -> "SequenceSample":
cond = True
cond &= self.keys == set(self.seqlens.keys())
cond &= self.keys == set(self.trailing_shapes.keys())
cond &= self.keys == set(self.dtypes.keys())
if self.data is not None:
cond &= self.keys == set(self.data.keys())
if not cond:
err = (
f"Keys are mismatched. "
f"keys={self.keys}, "
f"seqlens keys={set(self.seqlens.keys())}, "
f"trailing_shapes keys={set(self.trailing_shapes.keys())}, "
f"dtypes keys={set(self.dtypes.keys())}"
)
if self.data is not None:
err += f", data keys={set(self.data.keys())}"
raise KeyError(err)
return self
@model_validator(mode="after")
def _validate_shapes(self) -> "SequenceSample":
if self.data is None:
return self
acc_seqlen = {
k: sum(sum(lens) for lens in lens_list)
for k, lens_list in self.seqlens.items()
}
for k, v in self.data.items():
if v is None:
continue
if v.shape != (acc_seqlen[k], *self.trailing_shapes[k]):
raise ValueError(
f"Key: {k}, Data shape {v.shape} does not match "
f"configured shape {(acc_seqlen[k], *self.trailing_shapes[k])}."
)
return self
@model_validator(mode="after")
def _validate_dtypes(self) -> "SequenceSample":
if self.data is None:
return self
for k, v in self.data.items():
if v is None:
continue
if v.dtype != self.dtypes[k]:
raise ValueError(
f"Data dtype {v.dtype} "
f"does not match configured "
f"dtype {self.dtypes[k]}."
)
return self
@classmethod
def gather(cls, samples: List["SequenceSample"], keys: Optional[List[str]] = None):
"""Gather a list of SequenceSample objects into a single batch.
:param samples: A list of SequenceSample objects to be gathered.
:type samples: List[SequenceSample]
:param keys: The keys to be gathered. Only a subset of keys can
be gathered. If None, the keys from the first sample will be
used.
:type keys: Optional[List[str]]
"""
if keys is None:
keys = samples[0].keys
else:
keys = set(keys)
seqlens = {k: sum([s.seqlens[k] for s in samples], []) for k in keys}
if samples[0].data is not None:
data = {
k: (
torch.cat([s.data[k] for s in samples], dim=0)
if samples[0].data[k] is not None
else None
)
for k in keys
}
else:
data = None
id_ = sum([s.ids for s in samples], [])
metadata = {
k: sum([s.metadata[k] for s in samples], []) for k in samples[0].metadata
}
with cls.disable_validation():
return cls(
keys=keys,
dtypes={key: samples[0].dtypes[key] for key in keys},
trailing_shapes={key: samples[0].trailing_shapes[key] for key in keys},
ids=id_,
seqlens=seqlens,
data=data,
metadata=metadata,
)
def _get_split_key(self) -> str:
acc_seqlen = {k: sum(sum(l) for l in lens) for k, lens in self.seqlens.items()}
return max(acc_seqlen, key=acc_seqlen.get)
def split_with_spec(self, spec: SequenceSplitSpec) -> List["SequenceSample"]:
"""Split the data according to the given spec."""
samples = []
data_offset = {k: 0 for k in self.keys}
for start, end in spec.partitions:
new_seqlens = {
k: lens_list[start:end] for k, lens_list in self.seqlens.items()
}
_data_len = {
k: sum(sum(lens) for lens in lens_list)
for k, lens_list in new_seqlens.items()
}
if self.data is not None:
new_data = {
k: (
v[data_offset[k] : _data_len[k] + data_offset[k]]
if v is not None
else None
)
for k, v in self.data.items()
}
else:
new_data = None
for k in self.keys:
data_offset[k] += _data_len[k]
new_id = self.ids[start:end]
for k, v in self.metadata.items():
if not isinstance(v, list):
raise ValueError(
f"Unknown how to split non-list metadata: ({k}, {v})."
)
with self.disable_validation():
samples.append(
SequenceSample(
dtypes=self.dtypes,
trailing_shapes=self.trailing_shapes,
keys=self.keys,
ids=new_id,
seqlens=new_seqlens,
data=new_data,
metadata={k: v[start:end] for k, v in self.metadata.items()},
)
)
return samples
def split_with_lengths(
self, mb_spec: MicroBatchSpec, lens: List[int]
) -> Tuple[List["SequenceSample"], List[int] | np.ndarray, List[int] | np.ndarray]:
group_indices = datapack.ffd_allocate(
lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs
)
group_indices = sorted([sorted(g) for g in group_indices])
forward_indices = datapack.flat2d(group_indices)
sample = SequenceSample.reorder(self, forward_indices)
backward_indices = np.zeros(self.bs, dtype=np.int64)
backward_indices[forward_indices] = np.arange(self.bs)
spec = SequenceSplitSpec(sizes=[len(group) for group in group_indices])
return sample.split_with_spec(spec), forward_indices, backward_indices
def split(
self, mb_spec: MicroBatchSpec
) -> Tuple[List["SequenceSample"], List[int] | np.ndarray, List[int] | np.ndarray]:
"""Split the data into `n_mbs` parts.
:param mb_spec: The configuration to split the data into.
`n_mbs` is the minimum number of micro-batches,
`max_tokens_per_mb` is the maximum number of tokens in each micro-batch.
If `max_tokens_per_mb` is a large value, defaults to balanced split.
:type mb_spec: MicroBatchSpec
"""
lens = [sum(lens) for lens in self.seqlens[self._get_split_key()]]
return self.split_with_lengths(mb_spec, lens)
def synced_data_parallel_split(
self, mb_spec: MicroBatchSpec
) -> List["SequenceSample"]:
mb_inputs, *_ = self.split(mb_spec)
all_n_mbs = [None for _ in range(constants.data_parallel_world_size())]
dist.all_gather_object(
all_n_mbs, len(mb_inputs), group=constants.data_parallel_group()
)
if all(mbs == len(mb_inputs) for mbs in all_n_mbs):
return mb_inputs
# This method is called when max_tokens_per_mb is given and during training.
# In this case, we evenly partition sequences across DP ranks,
# so the recursion will always terminate when n_mbs = bs // dp_size
return self.synced_data_parallel_split(
MicroBatchSpec.new(mb_spec, n_mbs=max(all_n_mbs))
)
@staticmethod
def reorder(
sample: "SequenceSample", indices: List[int] | np.ndarray
) -> "SequenceSample":
assert set(list(indices)) == set(range(sample.bs))
samples = sample.unpack()
return SequenceSample.gather([samples[i] for i in indices])
@staticmethod
def reorder_output(
x: torch.Tensor,
expected_seqlens: List[List[int]],
forward_indices: List[int] | np.ndarray,
backward_indices: List[int] | np.ndarray,
) -> torch.Tensor:
assert len(forward_indices) == len(backward_indices) == len(expected_seqlens)
actual_seqlens = [expected_seqlens[i] for i in forward_indices]
group_seqlens = [sum(s) for s in actual_seqlens]
assert x.shape[0] == sum(group_seqlens), (
x.shape[0],
group_seqlens,
len(group_seqlens),
sum(group_seqlens),
)
offsets = [0] + np.cumsum(group_seqlens, axis=0).tolist()
mbs = [x[s:e] for s, e in zip(offsets[:-1], offsets[1:])]
return torch.cat([mbs[i] for i in backward_indices])
def unpack(self):
"""Unpack a batch of data into individual pieces of data."""
partitions = [(i, i + 1) for i in range(self.bs)]
return self.split_with_spec(SequenceSplitSpec(partitions=partitions))
def cuda(self):
return self.to_device("cuda")
def cpu(self):
return self.to_device("cpu")
def to_device(self, device: torch.device):
"""Move the data to device inplace."""
if self.data is None:
return self
self.data = {
k: v.to(device) if v is not None else None for k, v in self.data.items()
}
return self
@property
def bs(self):
"""The batch size or the number of data pieces in the sample."""
return len(self.ids)
def meta(self) -> "SequenceSample":
"""Create a new SequenceSample that does not contain any data."""
with self.disable_validation():
return SequenceSample(
keys=self.keys,
trailing_shapes=self.trailing_shapes,
dtypes=self.dtypes,
ids=self.ids,
data=None,
seqlens=self.seqlens,
metadata=self.metadata,
)
def update_(self, other: "SequenceSample"):
"""Inplace update data from another SequenceSample.
Used to amend newly produced data after a model function call.
"""
self.keys = self.keys.union(other.keys)
self.trailing_shapes.update(other.trailing_shapes)
self.dtypes.update(other.dtypes)
assert self.ids == other.ids, (self.ids, other.ids)
if self.data is not None:
self.data.update(other.data)
self.seqlens.update(other.seqlens)
self.metadata.update(other.metadata)
@staticmethod
def shuffled(sample: "SequenceSample") -> "SequenceSample":
"""Create a shuffled sample.
Define it as a staticmethod because it is an out-of-place operation.
(Think about the difference between `sorted` and `l.sort()`).
"""
seed = seeding.get_shuffle_seed()
rng = np.random.RandomState(seed)
indices = np.arange(sample.bs)
rng.shuffle(indices)
return SequenceSample.reorder(sample, indices)
@staticmethod
def _resolve_seqlen_from_key(key, seqlens: List[int]) -> List[torch.Tensor]:
if key in [
"seq_no_eos_mask",
"greedy_seq_no_eos_mask",
"loss_mask",
"rewards",
"greedy_rewards",
"base_scores",
"task_ids",
]:
return [[1] for _ in seqlens]
elif key in [
"input_ids",
"packed_seq",
"seq",
"packed_logits_mask",
"logits_mask",
"prompt_mask",
"greedy_prompt_mask",
"packed_input_ids",
"greedy_packed_input_ids",
"values",
"packed_prompts",
]:
return [[seqlen] for seqlen in seqlens]
elif key in [
"packed_logprobs",
"prox_logp",
"logprobs",
"packed_ref_logprobs",
"ref_logprobs",
"old_logp",
"ref_logp",
"advantages",
"ppo_loss_mask",
"kl_rewards",
"returns",
]:
return [[seqlen - 1] for seqlen in seqlens]
else:
raise NotImplementedError(
f"Seqlen could not be resolved given key {key}. "
f"Please explicltly construct the `SequenceSample` object"
" without using the `from_default` method."
)
@classmethod
def from_default(
cls,
seqlens: List[int],
ids: List[Hashable],
data: Dict[str, torch.Tensor],
metadata: Optional[Dict[str, Any]] = None,
):
"""Construct a `SequenceSample` object from default parameters.
This helper function is intended for cases where each piece of data has
a single sequence length (e.g., a single response for each prompt).
The sequence lengths for different keys are resolved automatically
according to the rules in ``_resolve_seqlen_from_key``. While this function
can reduce boilerplate code, it may introduce potential bugs, so it should
be used with caution.
:param seqlens: The sequence lengths of each piece of data. This represents
the length of the main attribute (e.g., `packed_input_ids`). Sequence lengths
for other attributes (e.g., rewards and logprobs) are computed from this parameter.
It is **NOT** the actual length of rewards or logprobs even if it is the only key
in the data.
:type seqlens: List[int]
:param ids: Unique identifiers for each piece of data.
:type ids: List[Hashable]
:param data: The actual data.
:type data: Dict[str, torch.Tensor]
:param metadata: Metadata for the sample. Should be a dictionary where each value
is a list with a length equal to the number of sequence lengths.
:type metadata: Optional[Dict[str, Any]]
"""
if metadata is None:
metadata = {}
for k, v in metadata.items():
if not isinstance(v, list) or len(v) != len(seqlens):
raise ValueError(
f"Metadata `{k}` should be a list of length {len(seqlens)}: {v}."
)
keys = set(data.keys())
if isinstance(seqlens[0], list):
assert len(seqlens[0]) == 1
seqlens = [seqlen[0] for seqlen in seqlens]
else:
assert all(isinstance(seqlen, int) for seqlen in seqlens)
seqlens = {key: cls._resolve_seqlen_from_key(key, seqlens) for key in keys}
trailing_shapes = {
key: data[key].shape[1:] if data[key] is not None else None for key in keys
}
dtypes = {
key: data[key].dtype if data[key] is not None else None for key in keys
}
return cls(
keys=keys,
ids=ids,
seqlens=seqlens,
trailing_shapes=trailing_shapes,
dtypes=dtypes,
data=data,
metadata=metadata,
)
def select(self, keys: List[str]):
"""Select a subset of keys inside the SequenceSample."""
with self.disable_validation():
keys = set(keys)
return SequenceSample(
keys=keys,
dtypes={key: self.dtypes[key] for key in keys},
trailing_shapes={key: self.trailing_shapes[key] for key in keys},
ids=self.ids,
seqlens={key: self.seqlens[key] for key in keys},
data=(
None if self.data is None else {key: self.data[key] for key in keys}
),
metadata=self.metadata,
)
def remap_keys_(self, remap: Dict[str, str]):
"""Inplace remap keys of the data.
Useful for reusing the same interface implementation in
different algorithms, where the data can be named differently.
"""
for k in self.keys:
if k in remap:
new_k = remap[k]
self.seqlens[new_k] = self.seqlens.pop(k)
self.trailing_shapes[new_k] = self.trailing_shapes.pop(k)
self.dtypes[new_k] = self.dtypes.pop(k)
if self.data is not None:
self.data[new_k] = self.data.pop(k)
self.keys = set(remap.get(k, k) for k in self.keys)
@classmethod
@contextmanager
def disable_validation(cls):
"""Disable the expensive pydantic validation within this context.
Used to accelerate gather/split/transfer operations since we
have ensured that the data created in datasets and interfaces
are valid.
"""
original_init = cls.__init__
def no_validation_init(self, *args, **kwargs):
kwargs["keys"] = set(kwargs["keys"])
self.__dict__.update(kwargs)
cls.__init__ = no_validation_init
try:
yield
finally:
cls.__init__ = original_init
def as_json_compatible(self) -> Dict:
return dict(
ids=self.ids,
keys=list(self.keys),
trailing_shapes={
k: tuple(v) if v is not None else None
for k, v in self.trailing_shapes.items()
},
dtypes={k: str(v) if v is not None else v for k, v in self.dtypes.items()},
seqlens=self.seqlens,
data={
k: v.cpu().numpy().tolist() if v is not None else None
for k, v in self.data.items()
},
metadata=self.metadata,
)
@classmethod
def from_json_compatible(cls, data: Dict):
dtypes = {}
for k, dtype_str in data["dtypes"].items():
if dtype_str is not None:
dtypes[k] = getattr(torch, dtype_str.split(".")[1])
else:
dtypes[k] = None
return cls(
ids=data["ids"],
keys=set(data["keys"]),
trailing_shapes=data["trailing_shapes"],
dtypes=dtypes,
seqlens=data["seqlens"],
data={
k: torch.tensor(v, dtype=dtypes[k]) if v is not None else v
for k, v in data["data"].items()
},
metadata=data["metadata"],
)
@dataclasses.dataclass
class DataBatchMeta:
dp_rank: int
meta_sample: SequenceSample | None
@dataclasses.dataclass
class DatasetUtility:
seed: int
dp_rank: int
world_size: int
tokenizer: transformers.PreTrainedTokenizerFast
def __post_init__(self):
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
if self.tokenizer.eos_token_id is None:
raise ValueError("eos_token_id of tokenizer must be defined.")
def get_shuffle_indices(seed: int, size: int):
"""Generate shuffled indices given seed and (dataset) size."""
np_rng = np.random.RandomState(seed=seed)
dtype_ = np.uint32
if size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx)
return shuffle_idx
def load_shuffle_split_dataset(
util: DatasetUtility,
dataset_path: str,
dataset_builder: Optional[Callable[[], List[Dict[str, str]]]] = None,
):
if dataset_path is not None:
if dataset_path.endswith(".jsonl"):
with open(dataset_path, "r") as f:
data = [json.loads(ff) for ff in f]
else:
raise NotImplementedError(f"Unknown dataset extension: {dataset_path}")
else:
assert dataset_builder is not None
data = dataset_builder()
if any("id" not in d for d in data):
logger.warning(
f'Key "id" not found in the dataset. Use indices as dataset IDs.'
)
for idx, d in enumerate(data):
d["id"] = idx
# NOTE: in the original way of seperating data, there is a chance that some DP rank
# get no data, which will raise error in the dataset tokenizer.
assert (
len(data) >= util.world_size
), "Dataset size must not be smaller than data parallel world size."
bins = np.zeros(util.world_size, dtype=np.int64)
for idx, d in enumerate(data):
bins[idx % util.world_size] += 1
dp_indices = np.pad(np.cumsum(bins), (1, 0))
shuffle_indices = get_shuffle_indices(util.seed, len(data))
subset_indices = shuffle_indices[
dp_indices[util.dp_rank] : dp_indices[util.dp_rank + 1]
]
data: List[Dict[str, str]] = [data[i] for i in subset_indices]
return data
ALL_DATASET_CLASSES = {}
def register_dataset(name, dataset_cls):
assert name not in ALL_DATASET_CLASSES
assert "/" not in name
ALL_DATASET_CLASSES[name] = dataset_cls
def make_dataset(
cfg: Union[str, config_api.DatasetAbstraction],
seed: int,
dp_rank: int,
world_size: int,
tokenizer_or_tokenizer_name: Union[transformers.PreTrainedTokenizerFast, str],
experiment_name: str,
trial_name: str,
cache_root: Optional[str] = None,
) -> torch.utils.data.Dataset:
if isinstance(cfg, str):
cfg = config_api.DatasetAbstraction(type_=cfg)
if isinstance(tokenizer_or_tokenizer_name, str):
tokenizer = load_hf_tokenizer(tokenizer_or_tokenizer_name)
elif tokenizer_or_tokenizer_name is None:
raise RuntimeError("tokenizer_or_tokenizer_name cannot be None.")
else:
tokenizer = tokenizer_or_tokenizer_name
util = DatasetUtility(
seed,
dp_rank,
world_size,
tokenizer,
)
if cache_root is None:
dataset_cls = ALL_DATASET_CLASSES[cfg.type_]
return dataset_cls(util=util, **cfg.args)
# Create and check cache path.
if not cache_root.startswith(cluster_spec.fileroot) and not cache_root.startswith(
"/home"
):
raise ValueError(
f"Data cache path {cache_root} should be /home or under {cluster_spec.fileroot}."
)
if "_" in experiment_name or "_" in trial_name:
raise ValueError(f"Invalid experiment/trial name.")
output_path = os.path.join(
cache_root,
experiment_name,
trial_name,
cfg.type_,
f"seed{seed}",
f"world_size{world_size}",
f"rank{dp_rank}",
)
os.makedirs(output_path, exist_ok=True)
fname = "dataset.pt"
cache_found = os.path.isfile(os.path.join(output_path, fname))
tik = time.perf_counter()
if not cache_found:
logger.info(f"No data cache found for rank {dp_rank}. Create it from scratch.")
dataset = ALL_DATASET_CLASSES[cfg.type_](seed, dp_rank, world_size, **cfg.args)
torch.save(dataset, os.path.join(output_path, fname))
else:
logger.info(f"Rank {dp_rank} find existing data cache, load it.")
dataset = torch.load(os.path.join(output_path, fname))
logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s")
return dataset
def gather_stat(src: List[Dict]) -> Dict:
cnt, stats = {}, {}
for reply in src:
# FIXME: understand why the reply can be None
if not reply:
continue
for k, v in reply.items():
cnt[k] = cnt.get(k, 0) + 1
stats[k] = stats.get(k, 0) + v
res = {k: v / cnt for k, v, cnt in zip(stats.keys(), stats.values(), cnt.values())}
for k, c in cnt.items():
if c != len(src):
logger.warning(f"Gathered `{k}` is not present in every returned stats.")
for k, v in res.items():
if any(abs(v - x.get(k, None)) > 1e-4 for x in src):
logger.warning(
f"Gathered `{k}` is not all-reduced "
f"before returning: ({[x.get(k, None) for x in src]}, {v})."
)
return res
def tabulate_stats(data: Dict[str, float], col=4, floatfmt=".4e") -> str:
from tabulate import tabulate
items = list(data.items())
# Calculate how many rows we'll need
row_count = (len(items) + col - 1) // col
# Reorganize items in column-major order
column_major = []
for i in range(row_count):
row = []
for j in range(col):
index = i + j * row_count
if index < len(items):
row.extend(items[index])
column_major.append(row)
return tabulate(column_major, floatfmt=floatfmt, tablefmt="fancy_grid")