forked from OSSInnovation/mindspore
!13857 Check whether the communication unit has been initialized
From: @yao_yf; Reviewed-by: @kisnwang, @stsuteng; Signed-off-by: @stsuteng
This commit is contained in:
commit
b1c86b6a22
|
@ -77,6 +77,13 @@ class Backend:
|
|||
raise ValueError("Invalid backend: '{}'".format(name))
|
||||
return value
|
||||
|
||||
DEFAULT_BACKEND = Backend("hccl")
|
||||
|
||||
class GlobalComm:
|
||||
"""World communication information."""
|
||||
BACKEND = DEFAULT_BACKEND
|
||||
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
||||
INITED = False
|
||||
|
||||
def is_hccl_available():
|
||||
"""
|
||||
|
@ -114,6 +121,8 @@ def check_parameter_available(func):
|
|||
def wrapper(*args, **kargs):
|
||||
if _is_role_pserver() or _is_role_sched():
|
||||
return func(*args, **kargs)
|
||||
if not GlobalComm.INITED:
|
||||
raise RuntimeError("Distributed Communication has not been inited")
|
||||
group = None
|
||||
if "group" in kargs.keys():
|
||||
group = kargs.get("group")
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
|
|||
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
|
||||
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
|
||||
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
|
||||
_get_local_rank_helper, _get_local_size_helper
|
||||
_get_local_rank_helper, _get_local_size_helper, GlobalComm
|
||||
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
|
||||
|
||||
|
||||
|
@ -28,8 +28,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
|
|||
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
|
||||
|
||||
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
||||
DEFAULT_BACKEND = Backend("hccl")
|
||||
|
||||
|
||||
def _get_group(group):
|
||||
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
|
||||
|
@ -38,11 +36,6 @@ def _get_group(group):
|
|||
return group
|
||||
|
||||
|
||||
class GlobalComm:
|
||||
"""World communication information."""
|
||||
BACKEND = DEFAULT_BACKEND
|
||||
WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
|
||||
|
||||
|
||||
def init(backend_name=None):
|
||||
"""
|
||||
|
@ -78,10 +71,12 @@ def init(backend_name=None):
|
|||
init_hccl()
|
||||
GlobalComm.BACKEND = Backend("hccl")
|
||||
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
||||
GlobalComm.INITED = True
|
||||
elif backend_name == "nccl":
|
||||
init_gpu_collective()
|
||||
GlobalComm.BACKEND = Backend("nccl")
|
||||
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
|
||||
GlobalComm.INITED = True
|
||||
else:
|
||||
raise RuntimeError("Backend name {} is not supported.".format(backend_name))
|
||||
|
||||
|
|
Loading…
Reference in New Issue