mirror of https://github.com/inclusionAI/AReaL
PullRequest: 43 Reduce GPU memory used by data transfer.
Merge branch mzy/fix-data-transfer-oom of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/43

Signed-off-by: 博惟 <bowei.fw@antgroup.com>

* add oom observe logs
* tested
* format and clear code
* .
* format
* remove logging
* .
* add comments
This commit is contained in:
parent
122bf6f214
commit
4ac9595295
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import time
|
||||
from parser import extract_answer
|
||||
|
||||
from grader import math_equal
|
||||
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.distributed as dist
|
|||
|
||||
from realhf import SequenceSample
|
||||
from realhf.api.core.config import ModelName, ModelShardID
|
||||
from realhf.base import constants
|
||||
from realhf.base import constants, logging
|
||||
from realhf.base.topology import ProcessTopology, new_or_get_group
|
||||
from realhf.impl.model.comm.global_comm import filter_match_mwids
|
||||
from realhf.system.redistributor import RedistribStep
|
||||
|
@ -21,6 +21,8 @@ BCAST_GROUPS = {}
|
|||
GATHER_GROUPS = {}
|
||||
SCATTER_GROUPS = {}
|
||||
|
||||
logger = logging.getLogger("data_manager", "system")
|
||||
|
||||
|
||||
class DataManager:
|
||||
|
||||
|
@ -325,8 +327,21 @@ class DataManager:
|
|||
)
|
||||
|
||||
if dist.get_rank() == step.root:
|
||||
scatter_list = []
|
||||
for ids in step.ids:
|
||||
# Scatter destinations include all DP, TP, and PP ranks
|
||||
# and data is duplicated among TP/PP groups
|
||||
# We allocate new memory for DP ranks, but use the same pointer
|
||||
# for all TP and PP ranks to save memory.
|
||||
scatter_clusters = []
|
||||
for idx, ids in enumerate(step.ids):
|
||||
for _ids, idx_list in scatter_clusters:
|
||||
if set(ids) == set(_ids):
|
||||
idx_list.append(idx)
|
||||
break
|
||||
else:
|
||||
scatter_clusters.append((ids, [idx]))
|
||||
scatter_list = [None for _ in range(len(step.ids))]
|
||||
before_pad = []
|
||||
for ids, idx_list in scatter_clusters:
|
||||
for i in ids:
|
||||
self.storage[i].to_device(constants.current_device())
|
||||
samples = [self.storage[i] for i in ids]
|
||||
|
@ -337,11 +352,17 @@ class DataManager:
|
|||
for key in step.keys
|
||||
]
|
||||
)
|
||||
scatter_list.append(data)
|
||||
maxlen = max([x.shape[0] for x in scatter_list])
|
||||
scatter_list = [self._pad_data(x, maxlen) for x in scatter_list]
|
||||
if step.root not in step.dsts:
|
||||
before_pad.append(data)
|
||||
|
||||
maxlen = max([x.shape[0] for x in before_pad])
|
||||
after_pad = [self._pad_data(x, maxlen) for x in before_pad]
|
||||
for (ids, idx_list), data in zip(scatter_clusters, after_pad):
|
||||
for idx in idx_list:
|
||||
scatter_list[idx] = data
|
||||
|
||||
assert all([torch.is_tensor(t) for t in scatter_list])
|
||||
|
||||
if step.root not in step.dsts:
|
||||
idx = bisect.bisect(step.dsts, step.root)
|
||||
scatter_list.insert(idx, buf)
|
||||
else:
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
import asyncio
|
||||
import dataclasses
|
||||
import itertools
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import *
|
||||
|
||||
|
|
|
@ -17,11 +17,7 @@ import torch.distributed as dist
|
|||
from realhf.api.core.config import ModelName, ModelShardID
|
||||
from realhf.api.core.data_api import SequenceSample
|
||||
from realhf.base import constants, testing, topology
|
||||
from realhf.base.testing import (
|
||||
LocalMultiProcessTest,
|
||||
PipeDataModelParallelTopology,
|
||||
init_global_constants,
|
||||
)
|
||||
from realhf.base.testing import LocalMultiProcessTest, init_global_constants
|
||||
from realhf.system.data_manager import DataManager
|
||||
from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner
|
||||
|
||||
|
@ -143,7 +139,7 @@ def _test_data_transfer(
|
|||
):
|
||||
|
||||
from_model_name = ModelName("data_transfer_test", 0)
|
||||
from_topo = PipeDataModelParallelTopology(
|
||||
from_topo = topology.PipeDataModelParallelTopology(
|
||||
num_pp=from_pp_dp_mp[0],
|
||||
num_mp=from_pp_dp_mp[-1],
|
||||
num_dp=from_pp_dp_mp[1],
|
||||
|
@ -152,7 +148,7 @@ def _test_data_transfer(
|
|||
gradient_accumulation_fusion=True,
|
||||
)
|
||||
to_model_name = ModelName("data_transfer_test", 1)
|
||||
to_topo = PipeDataModelParallelTopology(
|
||||
to_topo = topology.PipeDataModelParallelTopology(
|
||||
num_pp=to_pp_dp_mp[0],
|
||||
num_mp=to_pp_dp_mp[-1],
|
||||
num_dp=to_pp_dp_mp[1],
|
||||
|
|
Loading…
Reference in New Issue