PullRequest: 5 Modify the micro-batch splitting logic

Merge branch fw/balanced-datapck of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/5

Signed-off-by: 温差 <xushusheng.xss@antgroup.com>


* .
* fw/fix-dataloading-not-shuffle
* .
* .
* .
* .
* .
This commit is contained in:
博惟 2025-03-03 09:10:04 +08:00
parent ceee49454a
commit 46c5a10eb9
10 changed files with 96 additions and 157 deletions

View File

@@ -109,14 +109,10 @@ class MicroBatchSpec:
:param max_tokens_per_mb: The maximum number of tokens per micro-batch.
:type max_tokens_per_mb: Optional[int]
:param balanced_seqs: Whether to balance the number of sequences per
micro-batch. Only effective when max_tokens_per_mb is None.
:type balanced_seqs: bool, optional
"""
n_mbs: int = 1
max_tokens_per_mb: int | None = None
balanced_seqs: bool = False
max_tokens_per_mb: int = int(1e12)
@classmethod
def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
@@ -124,7 +120,6 @@ class MicroBatchSpec:
fields = dict(
n_mbs=mb_spec.n_mbs,
max_tokens_per_mb=mb_spec.max_tokens_per_mb,
balanced_seqs=mb_spec.balanced_seqs,
)
fields.update(kwargs)
return cls(**fields)
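A minimal usage sketch (illustrative, not part of this diff): with `balanced_seqs` removed, the spec is controlled only by `n_mbs` and `max_tokens_per_mb`; the huge default of `int(1e12)` effectively disables the token cap, so the split falls back to a balanced split into `n_mbs` parts.

    # Hypothetical example; field names follow the dataclass above.
    spec = MicroBatchSpec(n_mbs=4, max_tokens_per_mb=16384)
    # Derive a new spec from an existing one, overriding only selected fields.
    train_spec = MicroBatchSpec.new(spec, n_mbs=8)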
@@ -350,30 +345,6 @@ class SequenceSample:
acc_seqlen = {k: sum(sum(l) for l in lens) for k, lens in self.seqlens.items()}
return max(acc_seqlen, key=acc_seqlen.get)
def get_split_spec(
self, k: int, key: Optional[str] = None, min_size: int = 1
) -> SequenceSplitSpec:
"""Get the partition specification for splitting the data into `k`
parts using a dynamic programming algorithm to achieve the most
balanced partitioning.
:param k: The number of parts to split the data into.
:type k: int
:param key: The key to be used for splitting. If None, the key
with the largest total sequence length will be used.
:type key: Optional[str]
:param min_size: The minimum size of each partition.
:type min_size: int
:return: A SequenceSplitSpec object representing the
partitioning specification.
:rtype: SequenceSplitSpec
"""
if key is None:
key = self._get_split_key()
lens = [sum(lens) for lens in self.seqlens[key]]
partitions = datapack.min_abs_diff_partition(lens, k, min_size)
return SequenceSplitSpec(partitions=partitions)
def split_with_spec(self, spec: SequenceSplitSpec) -> List["SequenceSample"]:
"""Split the data according to the given spec."""
samples = []
@@ -419,47 +390,9 @@ class SequenceSample:
)
return samples
def split(
self,
k: int,
key: Optional[str] = None,
min_size: int = 1,
) -> List["SequenceSample"]:
"""Split the data into `k` parts.
This method uses the specified key or the key with the largest total sequence length
to split the data into `k` parts. The partitioning ensures that each part meets the
minimum size requirement.
:param k: The number of parts to split the data into.
:type k: int
:param key: The key to use for splitting. If None, the key with the largest
total sequence length will be used.
:type key: Optional[str]
:param min_size: The minimum size of each partition.
:type min_size: int
:return: A list of `SequenceSample` objects, each representing a part of the split data.
:rtype: List[SequenceSample]
"""
spec = self.get_split_spec(k, key, min_size)
return self.split_with_spec(spec)
def divide_into_mbs(
self, mb_spec: MicroBatchSpec
def split_with_lengths(
self, mb_spec: MicroBatchSpec, lens: List[int]
) -> Tuple[List["SequenceSample"], List[int] | np.ndarray, List[int] | np.ndarray]:
if mb_spec.max_tokens_per_mb is None:
return (
self.split(
mb_spec.n_mbs,
min_size=(
1 if not mb_spec.balanced_seqs else self.bs // mb_spec.n_mbs
),
),
np.arange(self.bs),
np.arange(self.bs),
)
lens = [sum(lens) for lens in self.seqlens[self._get_split_key()]]
group_indices = datapack.ffd_allocate(
lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs
)
@@ -474,10 +407,24 @@ class SequenceSample:
return sample.split_with_spec(spec), forward_indices, backward_indices
def divide_into_mbs_balanced(
def split(
self, mb_spec: MicroBatchSpec
) -> Tuple[List["SequenceSample"], List[int] | np.ndarray, List[int] | np.ndarray]:
"""Split the data into `n_mbs` parts.
:param mb_spec: The configuration to split the data into.
`n_mbs` is the minimum number of micro-batches,
`max_tokens_per_mb` is the maximum number of tokens in each micro-batch.
If `max_tokens_per_mb` is a large value, defaults to balanced split.
:type mb_spec: MicroBatchSpec
"""
lens = [sum(lens) for lens in self.seqlens[self._get_split_key()]]
return self.split_with_lengths(mb_spec, lens)
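A short sketch (assumed call pattern, following the signature above): `split` now returns the micro-batches together with forward/backward reordering indices, so callers that previously used `divide_into_mbs` unpack a 3-tuple.

    # Hypothetical example; `sample` is a SequenceSample instance.
    mb_inputs, fwd_indices, bwd_indices = sample.split(
        MicroBatchSpec(n_mbs=2, max_tokens_per_mb=4096)
    )
    # Each micro-batch holds at most 4096 tokens and there are at least 2
    # micro-batches; fwd/bwd indices allow restoring the original order.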
def synced_data_parallel_split(
self, mb_spec: MicroBatchSpec
) -> List["SequenceSample"]:
mb_inputs, *_ = self.divide_into_mbs(mb_spec)
mb_inputs, *_ = self.split(mb_spec)
all_n_mbs = [None for _ in range(constants.data_parallel_world_size())]
dist.all_gather_object(
all_n_mbs, len(mb_inputs), group=constants.data_parallel_group()
@@ -487,7 +434,7 @@ class SequenceSample:
# This method is called during training when max_tokens_per_mb is given.
# In this case, we evenly partition sequences across DP ranks,
# so the recursion will always terminate when n_mbs = bs // dp_size.
return self.divide_into_mbs_balanced(
return self.synced_data_parallel_split(
MicroBatchSpec.new(mb_spec, n_mbs=max(all_n_mbs))
)

View File

@@ -2,6 +2,7 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import heapq
import itertools
from typing import Any, List, Tuple, Union
@@ -148,28 +149,36 @@ def reorder_to_balanced_batches(
return np.array(reordered_indices), max_diff
@numba.njit
# @numba.njit
def _ffd_allocate(
values: np.ndarray, capacity: int, min_groups: int
) -> List[List[int]]:
"""A greedy allocation algorithm that partitions a list of numbers
into k groups, where the summation of each group is less than capacity
and k >= min_groups. We want to minimize k and make partitions as balanced
as possible.
1. Sort the numbers in reverse order.
2. If the number of groups is less than, create a new group.
3. If the new number fits into the smallest group, add it into the group.
4. Otherwise, create a new group.
"""
value_indices = np.argsort(-values)
group_indices = []
group_values = []
group_indices: List[List[int]] = []
group_values: List[Tuple[float, int]] = []
group_cnt = 0
for idx in value_indices:
if len(group_values) < min_groups:
group_values.append(values[idx])
group_indices.append([idx])
continue
placed = False
for i in range(len(group_values)):
if group_values[i] + values[idx] <= capacity:
group_values[i] += values[idx]
group_indices[i].append(idx)
placed = True
break
if not placed:
group_values.append(values[idx])
if (
len(group_values) < min_groups
or group_values[0][0] + values[idx] > capacity
):
heapq.heappush(group_values, (float(values[idx]), group_cnt))
group_indices.append([idx])
group_cnt += 1
else:
v, group_idx = heapq.heappop(group_values)
heapq.heappush(group_values, (float(v + values[idx]), group_idx))
group_indices[group_idx].append(idx)
return group_indices
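A small worked example of the heap-based allocation above (numbers chosen for illustration; the grouping follows from tracing the loop by hand): the heap always exposes the group with the smallest running sum, so each value either tops up that group or opens a new one.

    # Hypothetical check of the public wrapper, as used in __main__ below.
    groups = ffd_allocate(np.array([9, 7, 6, 5, 4, 3]), 12, 2)
    # - 9 and 7 open the first two groups (min_groups not yet reached);
    # - 6 does not fit into the smallest group (7 + 6 > 12), so a third group opens;
    # - 5, 4 and 3 each top up the currently smallest group.
    # Expected group sums: 12, 11 and 11, e.g. groups == [[0, 5], [1, 4], [2, 3]].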
@@ -188,15 +197,15 @@ if __name__ == "__main__":
for i in range(100):
st = time.monotonic()
nums = np.random.randint(512, 4000, size=(32768,))
nums = np.random.randint(1024, 8192, size=(100,))
# k = np.random.randint(2, 20)
# min_size = np.random.randint(1, len(nums) // k)
# res = min_abs_diff_partition(nums, k, min_size)
# assert all(y - x >= min_size for x, y in res)
max_tokens_per_mb = 655360
n_groups = np.random.randint(10, 20)
groups = ffd_allocate(nums, max_tokens_per_mb, n_groups)
assert len(groups) >= n_groups
max_tokens_per_mb = 163840
min_n_groups = np.random.randint(1, 8)
groups = ffd_allocate(nums, max_tokens_per_mb, min_n_groups)
assert len(groups) >= min_n_groups
import itertools
indices = list(itertools.chain(*groups))
@@ -207,6 +216,8 @@ if __name__ == "__main__":
print(
len(groups),
min_n_groups,
[sum(nums[i] for i in group) for group in groups],
max(group_percent),
min(group_percent),
np.mean(group_percent),

View File

@@ -82,11 +82,6 @@ def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation):
factor = 1
if rpc.is_train() and rpc_alloc.parallel.pipeline_parallel_size > 1:
factor = 2
if mb_spec.balanced_seqs or (
mb_spec.max_tokens_per_mb is not None and rpc.is_train()
):
assert rpc.n_seqs % dp_size == 0, (rpc.n_seqs, dp_size)
return
assert (
rpc.n_seqs

View File

@@ -103,7 +103,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
post_hook=post_hook,
aggregate_fn=aggregate_fn,
)
mb_inputs, fwd_indices, bwd_indices = input_.divide_into_mbs(mb_spec)
mb_inputs, fwd_indices, bwd_indices = input_.split(mb_spec)
if constants.parallelism_rank() == 0:
logger.info(
f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -161,14 +161,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
# NOTE: Interleaving mini-batches in the pipeline will not decrease
# memory usage, because we must hold the KV-caches of all
# mini-batches simultaneously, so we split mini-batches in the outer loop.
if mb_spec.max_tokens_per_mb is not None and constants.parallelism_rank() == 0:
logger.warning(
"Generation will ignore max_tokens_per_mb because the length is not predictable."
)
mb_spec = MicroBatchSpec.new(
mb_spec, max_tokens_per_mb=None, balanced_seqs=True
)
mb_inputs, *_ = input_.divide_into_mbs(mb_spec)
mb_inputs, *_ = input_.split(mb_spec)
if constants.parallelism_rank() == 0:
logger.info(
f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "

View File

@@ -751,7 +751,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
with megatron_ctx():
self.engine.zero_grad()
if constants.pipe_parallel_world_size() > 1:
mb_inputs = input_.divide_into_mbs_balanced(
mb_inputs = input_.synced_data_parallel_split(
MicroBatchSpec.new(
mb_spec,
n_mbs=mb_spec.n_mbs * self.pipe_runner.default_train_mbs,
@@ -771,7 +771,7 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
version_steps=version_steps,
)
mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
mb_inputs = input_.synced_data_parallel_split(mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)

View File

@@ -127,7 +127,7 @@ class MockTrainEngine(model_api.PipelinableEngine):
version_steps=version_steps,
)
mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
mb_inputs = input_.synced_data_parallel_split(mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)

View File

@@ -808,7 +808,7 @@ class PipelineRunner:
mb_spec = MicroBatchSpec.new(
mb_spec, n_mbs=self.default_inf_mbs * mb_spec.n_mbs
)
mb_inputs, fwd_indices, bwd_indices = input_.divide_into_mbs(mb_spec)
mb_inputs, fwd_indices, bwd_indices = input_.split(mb_spec)
if constants.parallelism_rank() == 0:
logger.info(
f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -884,11 +884,8 @@ class PipelineRunner:
# When the global batch is fixed, no matter how many micro-batches we
# split into, the total KV-cache memory usage will not change,
# so it's useless to split micro-batches here.
mb_spec = MicroBatchSpec(
n_mbs=self.default_inf_mbs,
balanced_seqs=True,
)
mb_inputs, *_ = input_.divide_into_mbs(mb_spec)
mb_spec = MicroBatchSpec(n_mbs=self.default_inf_mbs)
mb_inputs, *_ = input_.split(mb_spec)
if constants.parallelism_rank() == 0:
logger.info(
f"MB spec: {mb_spec}, #mbs={len(mb_inputs)}, "
@@ -1009,7 +1006,7 @@ class PipelineRunner:
mb_spec = MicroBatchSpec.new(
mb_spec, n_mbs=mb_spec.n_mbs * self.default_train_mbs
)
mb_inputs = input_.divide_into_mbs_balanced(mb_spec)
mb_inputs = input_.synced_data_parallel_split(mb_spec)
total_loss_weight = torch.tensor(
sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32
)

View File

@@ -615,9 +615,11 @@ class PPOActorInterface(model_api.ModelInterface):
)
# NOTE: We cannot randomly shuffle data here because
# data must have the same shape across different pipeline stages.
datas = input_.split(
self.n_minibatches,
min_size=input_.bs // self.n_minibatches,
datas, *_ = input_.split(MicroBatchSpec(n_mbs=self.n_minibatches))
logger.info(
f"PPO minibatch split (size {self.n_minibatches}): "
f"#seqs: {[s.bs for s in datas]}, "
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
)
if self.use_dense_reward:
@@ -1091,9 +1093,11 @@ class PPOCriticInterface(model_api.ModelInterface):
)
# NOTE: We cannot randomly shuffle data here because
# data must have the same shape across different pipeline stages.
datas = input_.split(
self.n_minibatches,
min_size=input_.bs // self.n_minibatches,
datas, *_ = input_.split(MicroBatchSpec(n_mbs=self.n_minibatches))
logger.info(
f"PPO minibatch split (size {self.n_minibatches}): "
f"#seqs: {[s.bs for s in datas]}, "
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
)
# Logging.

View File

@@ -381,21 +381,27 @@ async def model_rpc_request_func(
# Dispatch data to different data parallel ranks.
dp_size = topo.get_dim("data")
pp_size = topo.get_dim("pipe")
if rpc.mb_spec.balanced_seqs or (
rpc.mb_spec.max_tokens_per_mb is not None
and rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP
):
# For a train RPC, we must assure that all DP ranks have the same number
# of micro-batches, so we must evenly distribute the sequences.
assert sample.bs % dp_size == 0
min_n_seqs_per_dp = sample.bs // dp_size
if rpc.is_generate():
# The workload of generation is determined by the batch size, not by the generated length.
samples, forward_indices, _ = sample.split_with_lengths(
mb_spec=data_api.MicroBatchSpec(n_mbs=dp_size),
lens=[1 for _ in range(sample.bs)],
)
else:
min_n_seqs_per_dp = pp_size * rpc.min_n_seqs_per_pass * rpc.mb_spec.n_mbs
if rpc.interface_type == dfg.ModelInterfaceType.TRAIN_STEP and pp_size > 1:
min_n_seqs_per_dp *= 2
split_spec = sample.get_split_spec(dp_size, min_size=int(min_n_seqs_per_dp))
partitions = split_spec.partitions
samples, forward_indices, _ = sample.split(
data_api.MicroBatchSpec(n_mbs=dp_size)
)
blogger.info(
f"DP split (DP size {dp_size}) for RPC {rpc.name}: "
f"#seqs: {[s.bs for s in samples]}, "
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in samples]}"
)
sample = data_api.SequenceSample.gather(samples)
buf_indices = [buf_indices[i] for i in forward_indices]
partitions = data_api.SequenceSplitSpec(
sizes=[s.bs for s in samples]
).partitions
target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)}
# Set data owner of produced data by this RPC, such that downstream RPCs can know
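A hedged sketch of the effect of the generation branch above: passing unit lengths makes the FFD-based split balance by sequence count, which matches the observation that generation cost scales with batch size rather than prompt length.

    # Hypothetical illustration; `sample` holds 10 sequences and dp_size == 4.
    samples, forward_indices, _ = sample.split_with_lengths(
        mb_spec=data_api.MicroBatchSpec(n_mbs=4),
        lens=[1] * sample.bs,
    )
    # Expected per-rank batch sizes: roughly [3, 3, 2, 2], regardless of how
    # long each prompt actually is.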

View File

@@ -243,19 +243,6 @@ def test_gather_split(sample_type: str, dp: int):
x = SequenceSample.gather(samples)
# Test gather-split-gather consistency
for k in x.keys:
y = SequenceSample.gather(x.split(dp, key=k, min_size=1))
recursive_assert_equal(x, y)
# Test balanced split
balanced_size = sum(batch_sizes) // dp
for k in x.keys:
splitted = x.split(dp, key=k, min_size=balanced_size)
assert all(len(s.ids) >= balanced_size for s in splitted)
y = SequenceSample.gather(splitted)
recursive_assert_equal(x, y)
# Test split to original samples
spec = SequenceSplitSpec(sizes=batch_sizes)
ss = x.split_with_spec(spec)
@@ -264,11 +251,10 @@ def test_gather_split(sample_type: str, dp: int):
# Test split to the finest granularity
total_bs = sum(batch_sizes)
for k in x.keys:
ss = x.split(total_bs, key=k, min_size=1)
assert len(ss) == total_bs
y = SequenceSample.gather(ss)
recursive_assert_equal(x, y)
ss, _, backward_indices = x.split(MicroBatchSpec(n_mbs=x.bs))
assert len(ss) == total_bs
y = SequenceSample.reorder(SequenceSample.gather(ss), backward_indices)
recursive_assert_equal(x, y)
# Test divide micro batch and merge back
for seqlens in [
@@ -281,7 +267,7 @@ def test_gather_split(sample_type: str, dp: int):
n_mbs=np.random.randint(1, 10),
max_tokens_per_mb=np.random.randint(800, 1000),
)
mb_data, fwd_indices, bwd_indices = x.divide_into_mbs(mb_spec)
mb_data, fwd_indices, bwd_indices = x.split(mb_spec)
for id_x, id_y in zip(
[x.ids[i] for i in fwd_indices],