# mirror of https://github.com/inclusionAI/AReaL
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

import dataclasses
import itertools
import os
import socket
from collections import defaultdict
from typing import *

import torch.distributed

from realhf.api.core.config import ModelName, ModelShardID
from realhf.base import constants, gpu_utils, name_resolve, names, network, topology


@dataclasses.dataclass
class NCCLProcessGroupInfo:
    world_size: int
    global_rank: int
    local_gpu_id: int
    # 3D parallelism groups of each model.
    model_groups: Dict[str, torch.distributed.ProcessGroup]


def filter_match_mwids(
    model_name: ModelName,
    topo: topology.ProcessTopology,
    msid2mwid: Dict[ModelShardID, int],
    **conditions,
) -> List[int]:
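    """Return the sorted model worker indices (mwids) hosting shards of `model_name`.

    If topology filter conditions are given, only shards whose parallelism
    coordinates match `topo.filter_match(**conditions)` are included. The
    resulting indices must be unique.
    """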
    if len(conditions) == 0:
        mwids_this_model = [
            msid2mwid[ModelShardID.from_parallelism_rank(model_name, topo, j)]
            for j in range(topo.world_size())
        ]
    else:
        mwids_this_model = [
            msid2mwid[ModelShardID.from_parallelism_rank(model_name, topo, j)]
            for j in topo.filter_match(**conditions)
        ]
    mwids_this_model = sorted(mwids_this_model)
    assert len(set(mwids_this_model)) == len(mwids_this_model)
    return list(mwids_this_model)


def setup_global_comm(
    expr_name: str,
    trial_name: str,
    worker_index: int,
    model_topos: Optional[Dict[str, topology.ProcessTopology]] = None,
    msid2mwid: Optional[Dict[ModelShardID, int]] = None,
    backend: str = "nccl",
) -> NCCLProcessGroupInfo:
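    """Set up the global torch.distributed process group for this worker.

    Peers are discovered through `name_resolve`; worker 0 publishes a TCP
    rendezvous address and the other workers wait for it. Per-model
    parallelism groups are then created from `model_topos` and registered
    via `constants`.
    """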
    peers: List[int] = list(
        sorted(
            map(
                int,
                name_resolve.get_subtree(
                    names.distributed_peer(
                        expr_name,
                        trial_name,
                        gpu_utils.GLOBAL_PROCESS_GROUP_NAME,
                    )
                ),
            )
        )
    )
    assert len(peers) == len(set(peers)), f"Duplicated trainer worker index. {peers}"
    world_size = len(peers)
    global_rank = peers.index(worker_index)

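    # For each model, collect the global ranks of the model workers that host
    # its shards; these ranks form that model's parallelism group below.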
    mw_ranks = {}
    if model_topos is not None:
        assert msid2mwid is not None
        for model_name, topo in model_topos.items():
            mw_ranks[model_name] = filter_match_mwids(model_name, topo, msid2mwid)

    if (
        "GPU_DEVICES_ISOLATED" not in os.environ
        and "RAY" not in os.environ["REAL_MODE"]
        and constants.use_cuda()
    ):
        raise RuntimeError(
            "GPU devices not isolated in slurm or local mode. This should not happen."
        )

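    # With device isolation, each process sees exactly one GPU through
    # CUDA_VISIBLE_DEVICES; on CPU-only runs, fall back to the global rank.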
    if constants.use_cuda():
        assert len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1, os.environ[
            "CUDA_VISIBLE_DEVICES"
        ]
        local_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"])
    else:
        local_gpu_id = global_rank

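    # Rendezvous: worker 0 picks a free port and publishes tcp://host:port via
    # name_resolve; all other workers wait for the address to appear.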
    pg_master_name = names.distributed_master(
        expr_name, trial_name, gpu_utils.GLOBAL_PROCESS_GROUP_NAME
    )

    if worker_index == 0:
        host_ip = socket.gethostbyname(socket.gethostname())
        port = network.find_free_port(experiment_name=expr_name, trial_name=trial_name)
        pg_init_addr = f"tcp://{host_ip}:{port}"
        name_resolve.add(pg_master_name, pg_init_addr, keepalive_ttl=300)
    else:
        try:
            pg_init_addr = name_resolve.wait(pg_master_name, timeout=300)
        except TimeoutError:
            raise TimeoutError(
                f"global_rank={global_rank} worker_index={worker_index} wait for process group init timeout."
            )

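    # NCCL requires CUDA; fall back to the gloo backend on CPU-only runs before
    # initializing the global process group.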
    if not constants.use_cuda():
        backend = "gloo"
    torch_dist_kwargs = dict(
        world_size=world_size,
        rank=global_rank,
        init_method=pg_init_addr,
        backend=backend,
        timeout=constants.NCCL_DEFAULT_TIMEOUT,
    )
    if constants.use_cuda():
        # Initialize CUDA here with only a single visible device.
        torch.cuda.set_device(0)
        # This environment variable is used by DeepSpeed.
        os.environ["LOCAL_RANK"] = "0"

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            **torch_dist_kwargs, group_name=gpu_utils.GLOBAL_PROCESS_GROUP_NAME
        )

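    # Create (or reuse) one group per model with the chosen backend, plus a
    # gloo-backed CPU group, and register both as the model's parallelism groups.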
    model_groups = {}
    for model_name, ranks in mw_ranks.items():
        model_groups[model_name] = topology.new_or_get_group(ranks, backend=backend)
        constants.set_parallelism_group(model_name, model_groups[model_name], ranks)
        cpu_group = topology.new_or_get_group(ranks, backend="gloo")
        constants.set_cpu_parallelism_group(model_name, cpu_group)

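    # Group creation is collective in torch.distributed, so every rank builds a
    # single-rank group for each peer but keeps only its own as the self group.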
    self_group = None
    for i in range(world_size):
        group = topology.new_or_get_group([i], backend=backend)
        if i == global_rank:
            self_group = group
    constants.set_self_group(self_group)

    # logger.info(f"Setup process group finishes for worker_index={worker_index}")

    return NCCLProcessGroupInfo(
        world_size=world_size,
        global_rank=global_rank,
        local_gpu_id=local_gpu_id,
        model_groups=model_groups,
    )