mirror of https://github.com/inclusionAI/AReaL
PullRequest: 5 Modify the micro-batch splitting logic
Merge branch fw/balanced-datapck of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/5 Signed-off-by: 温差 <xushusheng.xss@antgroup.com> * . * fw/fix-dataloading-not-shuffle * . * . * . * . * .
This commit is contained in:
parent ceee49454a
commit 46c5a10eb9
@@ -109,14 +109,10 @@ class MicroBatchSpec:
    :param max_tokens_per_mb: The maximum number of tokens per micro-
        batch.
    :type max_tokens_per_mb: Optional[int]
    :param balanced_seqs: Whether to balance the number of sequences per
        micro-batch. Only effective when max_tokens_per_mb is None.
    :type balanced_seqs: bool, optional
    """

    n_mbs: int = 1
    max_tokens_per_mb: int | None = None
    balanced_seqs: bool = False
    max_tokens_per_mb: int = int(1e12)

    @classmethod
    def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
@@ -124,7 +120,6 @@ class MicroBatchSpec:
        fields = dict(
            n_mbs=mb_spec.n_mbs,
            max_tokens_per_mb=mb_spec.max_tokens_per_mb,
            balanced_seqs=mb_spec.balanced_seqs,
        )
        fields.update(kwargs)
        return cls(**fields)
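A minimal usage sketch for the updated spec, assuming it stays a dataclass with the defaults shown in the hunk above (the import path below is an assumption):

    from realhf.api.core.data_api import MicroBatchSpec  # import path assumed

    # Default: a single micro-batch with an effectively unlimited token budget.
    spec = MicroBatchSpec()
    assert spec.n_mbs == 1 and spec.max_tokens_per_mb == int(1e12)

    # `new` copies an existing spec and overrides only the given fields.
    train_spec = MicroBatchSpec.new(spec, max_tokens_per_mb=16384)
    assert train_spec.n_mbs == 1 and train_spec.max_tokens_per_mb == 16384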
@@ -350,30 +345,6 @@ class SequenceSample:
        acc_seqlen = {k: sum(sum(l) for l in lens) for k, lens in self.seqlens.items()}
        return max(acc_seqlen, key=acc_seqlen.get)

    def get_split_spec(
        self, k: int, key: Optional[str] = None, min_size: int = 1
    ) -> SequenceSplitSpec:
        """Get the partition specification for splitting the data into `k`
        parts using a dynamic programming algorithm to achieve the most
        balanced partitioning.

        :param k: The number of parts to split the data into.
        :type k: int
        :param key: The key to be used for splitting. If None, the key
            with the largest total sequence length will be used.
        :type key: Optional[str]
        :param min_size: The minimum size of each partition.
        :type min_size: int
        :return: A SequenceSplitSpec object representing the
            partitioning specification.
        :rtype: SequenceSplitSpec
        """
        if key is None:
            key = self._get_split_key()
        lens = [sum(lens) for lens in self.seqlens[key]]
        partitions = datapack.min_abs_diff_partition(lens, k, min_size)
        return SequenceSplitSpec(partitions=partitions)

    def split_with_spec(self, spec: SequenceSplitSpec) -> List["SequenceSample"]:
        """Split the data according to the given spec."""
        samples = []
@@ -419,47 +390,9 @@ class SequenceSample:
        )
        return samples

    def split(
        self,
        k: int,
        key: Optional[str] = None,
        min_size: int = 1,
    ) -> List["SequenceSample"]:
        """Split the data into `k` parts.

        This method uses the specified key or the key with the largest total sequence length
        to split the data into `k` parts. The partitioning ensures that each part meets the
        minimum size requirement.

        :param k: The number of parts to split the data into.
        :type k: int
        :param key: The key to use for splitting. If None, the key with the largest
            total sequence length will be used.
        :type key: Optional[str]
        :param min_size: The minimum size of each partition.
        :type min_size: int
        :return: A list of `SequenceSample` objects, each representing a part of the split data.
        :rtype: List[SequenceSample]
        """
        spec = self.get_split_spec(k, key, min_size)
        return self.split_with_spec(spec)

    def divide_into_mbs(
        self, mb_spec: MicroBatchSpec
    def split_with_lengths(
        self, mb_spec: MicroBatchSpec, lens: List[int]
    ) -> Tuple[List["SequenceSample"], List[int] | np.ndarray, List[int] | np.ndarray]:
        if mb_spec.max_tokens_per_mb is None:
            return (
                self.split(
                    mb_spec.n_mbs,
                    min_size=(
                        1 if not mb_spec.balanced_seqs else self.bs // mb_spec.n_mbs
                    ),
                ),
                np.arange(self.bs),
                np.arange(self.bs),
            )

        lens = [sum(lens) for lens in self.seqlens[self._get_split_key()]]
        group_indices = datapack.ffd_allocate(
            lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs
        )
@@ -474,10 +407,24 @@ class SequenceSample:

        return sample.split_with_spec(spec), forward_indices, backward_indices

    def divide_into_mbs_balanced(
    def split(
        self, mb_spec: MicroBatchSpec
    ) -> Tuple[List["SequenceSample"], List[int] | np.ndarray, List[int] | np.ndarray]:
        """Split the data into `n_mbs` parts.

        :param mb_spec: The configuration used to split the data.
            `n_mbs` is the minimum number of micro-batches, and
            `max_tokens_per_mb` is the maximum number of tokens in each micro-batch.
            If `max_tokens_per_mb` is large enough, this defaults to a balanced split.
        :type mb_spec: MicroBatchSpec
        """
        lens = [sum(lens) for lens in self.seqlens[self._get_split_key()]]
        return self.split_with_lengths(mb_spec, lens)

    def synced_data_parallel_split(
        self, mb_spec: MicroBatchSpec
    ) -> List["SequenceSample"]:
        mb_inputs, *_ = self.divide_into_mbs(mb_spec)
        mb_inputs, *_ = self.split(mb_spec)
        all_n_mbs = [None for _ in range(constants.data_parallel_world_size())]
        dist.all_gather_object(
            all_n_mbs, len(mb_inputs), group=constants.data_parallel_group()
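A hedged usage sketch of the renamed `split` (signature and return value taken from the hunk above; `sample` is an existing SequenceSample and the variable names are illustrative):

    mb_spec = MicroBatchSpec(n_mbs=2, max_tokens_per_mb=4096)
    mb_inputs, fwd_indices, bwd_indices = sample.split(mb_spec)

    # At least n_mbs micro-batches are produced and each respects the token budget.
    # fwd_indices gives the packed sequence order; reordering the gathered result
    # with bwd_indices recovers the original order (see the test changes below).
    restored = SequenceSample.reorder(SequenceSample.gather(mb_inputs), bwd_indices)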
@@ -487,7 +434,7 @@ class SequenceSample:
        # This method is called when max_tokens_per_mb is given and during training.
        # In this case, we evenly partition sequences across DP ranks,
        # so the recursion will always terminate when n_mbs = bs // dp_size.
        return self.divide_into_mbs_balanced(
        return self.synced_data_parallel_split(
            MicroBatchSpec.new(mb_spec, n_mbs=max(all_n_mbs))
        )
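A minimal sketch of the synchronization step used by `synced_data_parallel_split` above, assuming an initialized torch.distributed data-parallel group: each rank splits locally, the micro-batch counts are gathered, and every rank re-splits with the maximum count so all ranks run the same number of pipeline passes. The helper name is hypothetical.

    import torch.distributed as dist

    def sync_num_micro_batches(num_local_mbs: int, dp_group) -> int:
        """Return the maximum micro-batch count across the data-parallel group."""
        all_n_mbs = [None for _ in range(dist.get_world_size(group=dp_group))]
        dist.all_gather_object(all_n_mbs, num_local_mbs, group=dp_group)
        return max(all_n_mbs)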
@@ -2,6 +2,7 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import heapq
import itertools
from typing import Any, List, Tuple, Union
@@ -148,28 +149,36 @@ def reorder_to_balanced_batches(
    return np.array(reordered_indices), max_diff


@numba.njit
# @numba.njit
def _ffd_allocate(
    values: np.ndarray, capacity: int, min_groups: int
) -> List[List[int]]:
    """A greedy allocation algorithm that partitions a list of numbers
    into k groups, where the summation of each group is less than capacity
    and k >= min_groups. We want to minimize k and make partitions as balanced
    as possible.

    1. Sort the numbers in reverse order.
    2. If the number of groups is less than min_groups, create a new group.
    3. If the new number fits into the smallest group, add it to that group.
    4. Otherwise, create a new group.
    """
    value_indices = np.argsort(-values)
    group_indices = []
    group_values = []
    group_indices: List[List[int]] = []
    group_values: List[Tuple[float, int]] = []
    group_cnt = 0
    for idx in value_indices:
        if len(group_values) < min_groups:
            group_values.append(values[idx])
            group_indices.append([idx])
            continue
        placed = False
        for i in range(len(group_values)):
            if group_values[i] + values[idx] <= capacity:
                group_values[i] += values[idx]
                group_indices[i].append(idx)
                placed = True
                break
        if not placed:
            group_values.append(values[idx])
        if (
            len(group_values) < min_groups
            or group_values[0][0] + values[idx] > capacity
        ):
            heapq.heappush(group_values, (float(values[idx]), group_cnt))
            group_indices.append([idx])
            group_cnt += 1
        else:
            v, group_idx = heapq.heappop(group_values)
            heapq.heappush(group_values, (float(v + values[idx]), group_idx))
            group_indices[group_idx].append(idx)
    return group_indices
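A hypothetical worked example of the heap-based allocation above, written as if it ran inside this module's __main__ block (so `ffd_allocate` and `np` are in scope); the concrete numbers are illustrative:

    lens = np.array([9, 8, 5, 4, 3, 2])
    groups = ffd_allocate(lens, 12, min_groups=2)

    # Invariants of the algorithm: every group total stays within the capacity,
    # at least min_groups groups exist, and each index is placed exactly once.
    assert all(sum(int(lens[i]) for i in g) <= 12 for g in groups)
    assert len(groups) >= 2
    assert sorted(i for g in groups for i in g) == list(range(len(lens)))

    # Tracing by hand: 9 and 8 open the first two groups, 5 opens a third because
    # 8 + 5 > 12, then 4, 3, and 2 each join the currently smallest group, giving
    # group totals of roughly [11, 11, 9].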
@@ -188,15 +197,15 @@ if __name__ == "__main__":

    for i in range(100):
        st = time.monotonic()
        nums = np.random.randint(512, 4000, size=(32768,))
        nums = np.random.randint(1024, 8192, size=(100,))
        # k = np.random.randint(2, 20)
        # min_size = np.random.randint(1, len(nums) // k)
        # res = min_abs_diff_partition(nums, k, min_size)
        # assert all(y - x >= min_size for x, y in res)
        max_tokens_per_mb = 655360
        n_groups = np.random.randint(10, 20)
        groups = ffd_allocate(nums, max_tokens_per_mb, n_groups)
        assert len(groups) >= n_groups
        max_tokens_per_mb = 163840
        min_n_groups = np.random.randint(1, 8)
        groups = ffd_allocate(nums, max_tokens_per_mb, min_n_groups)
        assert len(groups) >= min_n_groups
        import itertools

        indices = list(itertools.chain(*groups))
@@ -207,6 +216,8 @@ if __name__ == "__main__":

        print(
            len(groups),
            min_n_groups,
            [sum(nums[i] for i in group) for group in groups],
            max(group_percent),
            min(group_percent),
            np.mean(group_percent),
@@ -82,11 +82,6 @@ def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation):
    factor = 1
    if rpc.is_train() and rpc_alloc.parallel.pipeline_parallel_size > 1:
        factor = 2
    if mb_spec.balanced_seqs or (
        mb_spec.max_tokens_per_mb is not None and rpc.is_train()
    ):
        assert rpc.n_seqs % dp_size == 0, (rpc.n_seqs, dp_size)
        return

    assert (
        rpc.n_seqs
@@ -103,7 +103,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
            post_hook=post_hook,
            aggregate_fn=aggregate_fn,
        )
        mb_inputs, fwd_indices, bwd_indices = input_.divide_into_mbs(mb_spec)
        mb_inputs, fwd_indices, bwd_indices = input_.split(mb_spec)
        if constants.parallelism_rank() == 0:
            logger.info(
                f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -161,14 +161,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
        # NOTE: Interleaving mini-batches in the pipeline will not decrease
        # the memory usage, because we need to hold all KV-caches for different
        # mini-batches, so we split mini-batches in the outer loop.
        if mb_spec.max_tokens_per_mb is not None and constants.parallelism_rank() == 0:
            logger.warning(
                "Generation will ignore max_tokens_per_mb because the length is not predictable."
            )
        mb_spec = MicroBatchSpec.new(
            mb_spec, max_tokens_per_mb=None, balanced_seqs=True
        )
        mb_inputs, *_ = input_.divide_into_mbs(mb_spec)
        mb_inputs, *_ = input_.split(mb_spec)
        if constants.parallelism_rank() == 0:
            logger.info(
                f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -751,7 +751,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
        with megatron_ctx():
            self.engine.zero_grad()
        if constants.pipe_parallel_world_size() > 1:
            mb_inputs = input_.divide_into_mbs_balanced(
            mb_inputs = input_.synced_data_parallel_split(
                MicroBatchSpec.new(
                    mb_spec,
                    n_mbs=mb_spec.n_mbs * self.pipe_runner.default_train_mbs,
@@ -771,7 +771,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
                version_steps=version_steps,
            )

        mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
        mb_inputs = input_.synced_data_parallel_split(mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
        )
@@ -127,7 +127,7 @@ class MockTrainEngine(model_api.PipelinableEngine):
                version_steps=version_steps,
            )

        mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
        mb_inputs = input_.synced_data_parallel_split(mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
        )
@@ -808,7 +808,7 @@ class PipelineRunner:
        mb_spec = MicroBatchSpec.new(
            mb_spec, n_mbs=self.default_inf_mbs * mb_spec.n_mbs
        )
        mb_inputs, fwd_indices, bwd_indices = input_.divide_into_mbs(mb_spec)
        mb_inputs, fwd_indices, bwd_indices = input_.split(mb_spec)
        if constants.parallelism_rank() == 0:
            logger.info(
                f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -884,11 +884,8 @@ class PipelineRunner:
        # When the global batch is fixed, no matter how many micro-batches we
        # split, the all-together KV-cache memory usage will not be changed,
        # so it's useless to split micro-batches here.
        mb_spec = MicroBatchSpec(
            n_mbs=self.default_inf_mbs,
            balanced_seqs=True,
        )
        mb_inputs, *_ = input_.divide_into_mbs(mb_spec)
        mb_spec = MicroBatchSpec(n_mbs=self.default_inf_mbs)
        mb_inputs, *_ = input_.split(mb_spec)
        if constants.parallelism_rank() == 0:
            logger.info(
                f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -1009,7 +1006,7 @@ class PipelineRunner:
        mb_spec = MicroBatchSpec.new(
            mb_spec, n_mbs=mb_spec.n_mbs * self.default_train_mbs
        )
        mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
        mb_inputs = input_.synced_data_parallel_split(mb_spec)
        total_loss_weight = torch.tensor(
            sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
        )
@@ -615,9 +615,11 @@ class PPOActorInterface(model_api.ModelInterface):
        )
        # NOTE: We cannot randomly shuffle data here because
        # data must have the same shape across different pipeline stages.
        datas = input_.split(
            self.n_minibatches,
            min_size=input_.bs // self.n_minibatches,
        datas, *_ = input_.split(MicroBatchSpec(n_mbs=self.n_minibatches))
        logger.info(
            f"PPO minibatch split (size {self.n_minibatches}): "
            f"#seqs: {[s.bs for s in datas]}, "
            f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
        )

        if self.use_dense_reward:
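A small illustrative sketch of the new PPO minibatch split used above (`input_` is the training SequenceSample; the statistic names are illustrative and mirror the log line in the hunk):

    datas, *_ = input_.split(MicroBatchSpec(n_mbs=self.n_minibatches))
    # Per-minibatch statistics, as logged above: sequence counts and token counts.
    n_seqs = [s.bs for s in datas]
    n_tokens = [sum(sum(lens) for lens in s.seqlens[s._get_split_key()]) for s in datas]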
@@ -1091,9 +1093,11 @@ class PPOCriticInterface(model_api.ModelInterface):
        )
        # NOTE: We cannot randomly shuffle data here because
        # data must have the same shape across different pipeline stages.
        datas = input_.split(
            self.n_minibatches,
            min_size=input_.bs // self.n_minibatches,
        datas, *_ = input_.split(MicroBatchSpec(n_mbs=self.n_minibatches))
        logger.info(
            f"PPO minibatch split (size {self.n_minibatches}): "
            f"#seqs: {[s.bs for s in datas]}, "
            f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
        )

        # Logging.
@@ -381,21 +381,27 @@ async def model_rpc_request_func(

    # Dispatch data to different data parallel ranks.
    dp_size = topo.get_dim("data")
    pp_size = topo.get_dim("pipe")
    if rpc.mb_spec.balanced_seqs or (
        rpc.mb_spec.max_tokens_per_mb is not None
        and rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP
    ):
        # For a train RPC, we must ensure that all DP ranks have the same number
        # of micro-batches, so we must evenly distribute the sequences.
        assert sample.bs % dp_size == 0
        min_n_seqs_per_dp = sample.bs // dp_size
    if rpc.is_generate():
        # The workload of generation is decided by the batch size, not by the generated length.
        samples, forward_indices, _ = sample.split_with_lengths(
            mb_spec=data_api.MicroBatchSpec(n_mbs=dp_size),
            lens=[1 for _ in range(sample.bs)],
        )
    else:
        min_n_seqs_per_dp = pp_size * rpc.min_n_seqs_per_pass * rpc.mb_spec.n_mbs
        if rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP and pp_size > 1:
            min_n_seqs_per_dp *= 2
        split_spec = sample.get_split_spec(dp_size, min_size=int(min_n_seqs_per_dp))
        partitions = split_spec.partitions
        samples, forward_indices, _ = sample.split(
            data_api.MicroBatchSpec(n_mbs=dp_size)
        )
    blogger.info(
        f"DP split (DP size {dp_size}) for RPC {rpc.name}: "
        f"#seqs: {[s.bs for s in samples]}, "
        f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in samples]}"
    )
    sample = data_api.SequenceSample.gather(samples)
    buf_indices = [buf_indices[i] for i in forward_indices]

    partitions = data_api.SequenceSplitSpec(
        sizes=[s.bs for s in samples]
    ).partitions
    target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)}

    # Set data owner of produced data by this RPC, such that downstream RPCs can know
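A hedged sketch of the generate-RPC branch above: `sample` is the gathered SequenceSample and `dp_size` the data-parallel degree. Passing unit lengths makes the allocator balance the number of sequences per rank rather than the number of tokens:

    samples, forward_indices, _ = sample.split_with_lengths(
        mb_spec=data_api.MicroBatchSpec(n_mbs=dp_size),
        lens=[1 for _ in range(sample.bs)],
    )
    # For example, with sample.bs == 10 and dp_size == 4, the per-rank batch
    # sizes come out as 3, 3, 2, and 2 sequences.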
@@ -243,19 +243,6 @@ def test_gather_split(sample_type: str, dp: int):

    x = SequenceSample.gather(samples)

    # Test gather-split-gather consistency
    for k in x.keys:
        y = SequenceSample.gather(x.split(dp, key=k, min_size=1))
        recursive_assert_equal(x, y)

    # Test balanced split
    balanced_size = sum(batch_sizes) // dp
    for k in x.keys:
        splitted = x.split(dp, key=k, min_size=balanced_size)
        assert all(len(s.ids) >= balanced_size for s in splitted)
        y = SequenceSample.gather(splitted)
        recursive_assert_equal(x, y)

    # Test split to original samples
    spec = SequenceSplitSpec(sizes=batch_sizes)
    ss = x.split_with_spec(spec)
@@ -264,11 +251,10 @@ def test_gather_split(sample_type: str, dp: int):

    # Test split to the finest granularity
    total_bs = sum(batch_sizes)
    for k in x.keys:
        ss = x.split(total_bs, key=k, min_size=1)
        assert len(ss) == total_bs
        y = SequenceSample.gather(ss)
        recursive_assert_equal(x, y)
    ss, _, backward_indices = x.split(MicroBatchSpec(n_mbs=x.bs))
    assert len(ss) == total_bs
    y = SequenceSample.reorder(SequenceSample.gather(ss), backward_indices)
    recursive_assert_equal(x, y)

    # Test divide micro batch and merge back
    for seqlens in [
@@ -281,7 +267,7 @@ def test_gather_split(sample_type: str, dp: int):
            n_mbs=np.random.randint(1, 10),
            max_tokens_per_mb=np.random.randint(800, 1000),
        )
        mb_data, fwd_indices, bwd_indices = x.divide_into_mbs(mb_spec)
        mb_data, fwd_indices, bwd_indices = x.split(mb_spec)

        for id_x, id_y in zip(
            [x.ids[i] for i in fwd_indices],