# mirror of https://github.com/inclusionAI/AReaL
import torch


# Copied from PyTorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
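    """Create a new process group without touching the default ("WORLD") group.

    Unlike ``torch.distributed.new_group``, this does not require that a default
    group already exists, which makes it possible to create multiple independent
    "main" groups (see the module-level comment). The arguments mirror
    ``torch.distributed.init_process_group``: provide either ``init_method`` or an
    explicit ``store`` (together with ``world_size`` and ``rank``); if neither is
    given, ``init_method`` falls back to ``"env://"``. Returns the newly created
    ``ProcessGroup``.
    """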
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed to backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # Pick the keyword name based on the installed PyTorch version. Compare the
    # numeric (major, minor) components rather than version strings so that,
    # e.g., "2.10" is not treated as older than "2.6".
    torch_major, torch_minor = (int(v) for v in torch.__version__.split(".")[:2])
    pg_options_param_name = (
        "backend_options" if (torch_major, torch_minor) >= (2, 6) else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],  # empty rank list -> the helper takes its default-group code path
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    # Record the group's global-rank -> group-rank mapping (identity here) so that
    # other torch.distributed helpers can translate ranks for this group.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg
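

# Example usage (an illustrative sketch, not part of the original module): create a
# standalone single-process "gloo" group backed by an explicit TCPStore and query it.
# The address, port, and group name below are arbitrary placeholders.
if __name__ == "__main__":
    from datetime import timedelta

    demo_store = torch.distributed.TCPStore(
        "127.0.0.1", 29601, world_size=1, is_master=True, wait_for_workers=False
    )
    demo_pg = init_custom_process_group(
        backend="gloo",
        store=demo_store,
        world_size=1,
        rank=0,
        group_name="demo_group",
        timeout=timedelta(seconds=60),
    )
    print(f"group rank={demo_pg.rank()} size={demo_pg.size()}")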