mirror of https://github.com/inclusionAI/AReaL
This commit is contained in:
parent
3c8f36fa49
commit
fdd85e2e55
|
@ -0,0 +1,73 @@
|
|||
import torch
|
||||
|
||||
|
||||
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
    """Create a new process group without registering it as the default group.

    Unlike ``torch.distributed.init_process_group``, this helper can be called
    multiple times in one process to build several independent "main" groups
    (e.g. a weight-sync group between trainer and inference peers).

    Args:
        backend: Backend name ("nccl", "gloo", ...); ``Backend("undefined")``
            is used when omitted, letting PyTorch choose.
        init_method: Rendezvous URL; mutually exclusive with ``store``.
            Falls back to ``"env://"`` when neither is provided.
        timeout: Collective timeout; defaults to torch's ``default_pg_timeout``.
        world_size: Group size; must be > 0 when ``store`` is given.
        rank: This process's rank; must be >= 0 when ``store`` is given.
        store: Pre-constructed key/value store used instead of rendezvous.
        group_name: Name used both to namespace store keys and to label the
            group. NOTE(review): the default ``None`` would be rejected by
            ``PrefixStore`` on the rendezvous path — callers appear expected
            to always pass a string; confirm against call sites.
        pg_options: Backend-specific options (renamed to ``backend_options``
            in torch >= 2.6; handled below).

    Returns:
        The newly created ``ProcessGroup``.
    """
    # Function-scope imports: distributed_c10d internals are private and
    # version-sensitive, so resolve them lazily at call time.
    import re

    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        # A user-supplied store carries no rank/size info, so both are required.
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend) if backend else Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # BUG FIX: the original gate compared version *strings*
    # (str(torch.__version__) >= "2.6"), which misorders two-digit minors —
    # "2.10" < "2.6" lexicographically — so torch >= 2.10 would receive the
    # removed ``pg_options`` kwarg. Compare numeric (major, minor) instead;
    # the regex also strips local suffixes like "+cu121".
    version_match = re.match(r"(\d+)\.(\d+)", str(torch.__version__))
    torch_major_minor = (
        (int(version_match.group(1)), int(version_match.group(2)))
        if version_match
        else (0, 0)
    )
    pg_options_param_name = (
        "backend_options" if torch_major_minor >= (2, 6) else "pg_options"
    )

    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    # Record an identity rank mapping so torch's bookkeeping (e.g.
    # get_process_group_ranks) recognizes the new group.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg
|
Loading…
Reference in New Issue