PullRequest: 43 Reduce GPU memory used by data transfer.

Merge branch mzy/fix-data-transfer-oom of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/43

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* add oom observe logs
* tested
* format and clear code
* .
* format
* remove logging
* .
* add comments
晓雷 2025-03-18 15:20:57 +08:00
parent 122bf6f214
commit 4ac9595295
4 changed files with 32 additions and 15 deletions

View File

@@ -1,6 +1,7 @@
import json
import time
from parser import extract_answer
from grader import math_equal

View File

@@ -12,7 +12,7 @@ import torch.distributed as dist
from realhf import SequenceSample
from realhf.api.core.config import ModelName, ModelShardID
from realhf.base import constants
from realhf.base import constants, logging
from realhf.base.topology import ProcessTopology, new_or_get_group
from realhf.impl.model.comm.global_comm import filter_match_mwids
from realhf.system.redistributor import RedistribStep
@@ -21,6 +21,8 @@ BCAST_GROUPS = {}
GATHER_GROUPS = {}
SCATTER_GROUPS = {}
logger = logging.getLogger("data_manager", "system")
class DataManager:
@ -325,8 +327,21 @@ class DataManager:
)
if dist.get_rank() == step.root:
scatter_list = []
for ids in step.ids:
# Scatter destinations include all DP, TP, and PP ranks
# and data is duplicated among TP/PP groups
# We allocate new memory for DP ranks, but use the same pointer
# for all TP and PP ranks to save memory.
scatter_clusters = []
for idx, ids in enumerate(step.ids):
for _ids, idx_list in scatter_clusters:
if set(ids) == set(_ids):
idx_list.append(idx)
break
else:
scatter_clusters.append((ids, [idx]))
scatter_list = [None for _ in range(len(step.ids))]
before_pad = []
for ids, idx_list in scatter_clusters:
for i in ids:
self.storage[i].to_device(constants.current_device())
samples = [self.storage[i] for i in ids]
@@ -337,11 +352,17 @@
for key in step.keys
]
)
scatter_list.append(data)
maxlen = max([x.shape[0] for x in scatter_list])
scatter_list = [self._pad_data(x, maxlen) for x in scatter_list]
if step.root not in step.dsts:
before_pad.append(data)
maxlen = max([x.shape[0] for x in before_pad])
after_pad = [self._pad_data(x, maxlen) for x in before_pad]
for (ids, idx_list), data in zip(scatter_clusters, after_pad):
for idx in idx_list:
scatter_list[idx] = data
assert all([torch.is_tensor(t) for t in scatter_list])
if step.root not in step.dsts:
idx = bisect.bisect(step.dsts, step.root)
scatter_list.insert(idx, buf)
else:
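
The core of the change above is to stop allocating one scatter buffer per destination rank: destinations whose requested id sets match (the TP/PP replicas of the same DP shard) are clustered, one padded tensor is built per cluster, and that same tensor object is placed in the scatter list for every rank in the cluster. The sketch below illustrates that clustering-and-aliasing pattern in isolation; the names `build_scatter_list`, `_pad_first_dim`, and the plain `storage` dict of tensors are hypothetical stand-ins, not the actual `DataManager`/`SequenceSample` API used in the diff.

```python
import torch
from typing import Dict, List, Sequence, Tuple


def _pad_first_dim(t: torch.Tensor, length: int, value: float = 0.0) -> torch.Tensor:
    """Pad `t` along dim 0 up to `length` so all scatter payloads share a shape."""
    if t.shape[0] == length:
        return t
    pad = t.new_full((length - t.shape[0], *t.shape[1:]), value)
    return torch.cat([t, pad], dim=0)


def build_scatter_list(
    dst_ids: Sequence[Sequence[int]],
    storage: Dict[int, torch.Tensor],
) -> List[torch.Tensor]:
    """Return one tensor per scatter destination, allocating memory only once
    per *unique* set of sample ids; destinations that request the same ids
    (e.g. TP/PP ranks within one DP group) receive the same tensor object."""
    # 1. Cluster destinations by the (unordered) set of ids they request.
    clusters: List[Tuple[Sequence[int], List[int]]] = []
    for idx, ids in enumerate(dst_ids):
        for seen_ids, idx_list in clusters:
            if set(ids) == set(seen_ids):
                idx_list.append(idx)
                break
        else:
            clusters.append((ids, [idx]))

    # 2. Materialize one concatenated payload per cluster, not per destination.
    payloads = [torch.cat([storage[i] for i in ids], dim=0) for ids, _ in clusters]

    # 3. Pad to a common length so dist.scatter sees equal shapes everywhere.
    maxlen = max(p.shape[0] for p in payloads)
    payloads = [_pad_first_dim(p, maxlen) for p in payloads]

    # 4. Fan each shared payload back out to every destination slot in its
    #    cluster; duplicated destinations alias the same underlying storage.
    scatter_list: List[torch.Tensor] = [torch.empty(0)] * len(dst_ids)
    for (_, idx_list), payload in zip(clusters, payloads):
        for idx in idx_list:
            scatter_list[idx] = payload
    return scatter_list


if __name__ == "__main__":
    # A DP=2, TP=2 layout: four destinations but only two distinct id sets,
    # so only two payload tensors are actually allocated.
    storage = {i: torch.full((3, 4), float(i)) for i in range(4)}
    dsts = [[0, 1], [0, 1], [2, 3], [2, 3]]
    out = build_scatter_list(dsts, storage)
    assert out[0] is out[1] and out[2] is out[3]
    assert out[0].shape == out[2].shape
```

In this toy layout, four destinations collapse to two allocations, which is where the GPU memory saving comes from. The padding step plays the role of `self._pad_data` in the diff; the real code additionally inserts the root's own receive buffer `buf` into the list (via `bisect`) when the root is not among the destinations before the scatter is issued.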

View File

@@ -3,7 +3,6 @@
import asyncio
import dataclasses
import itertools
import os
from collections import defaultdict
from typing import *

View File

@@ -17,11 +17,7 @@ import torch.distributed as dist
from realhf.api.core.config import ModelName, ModelShardID
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, testing, topology
from realhf.base.testing import (
LocalMultiProcessTest,
PipeDataModelParallelTopology,
init_global_constants,
)
from realhf.base.testing import LocalMultiProcessTest, init_global_constants
from realhf.system.data_manager import DataManager
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner
@@ -143,7 +139,7 @@ def _test_data_transfer(
):
from_model_name = ModelName("data_transfer_test", 0)
from_topo = PipeDataModelParallelTopology(
from_topo = topology.PipeDataModelParallelTopology(
num_pp=from_pp_dp_mp[0],
num_mp=from_pp_dp_mp[-1],
num_dp=from_pp_dp_mp[1],
@@ -152,7 +148,7 @@
gradient_accumulation_fusion=True,
)
to_model_name = ModelName("data_transfer_test", 1)
to_topo = PipeDataModelParallelTopology(
to_topo = topology.PipeDataModelParallelTopology(
num_pp=to_pp_dp_mp[0],
num_mp=to_pp_dp_mp[-1],
num_dp=to_pp_dp_mp[1],