mirror of https://github.com/Jittor/Jittor
polish world rank and world size
This commit is contained in:
parent
39165cccfb
commit
692cbddb8e
|
@ -95,6 +95,9 @@ def val(epoch):
|
||||||
|
|
||||||
下面是 jittor 的 mpi api reference.
|
下面是 jittor 的 mpi api reference.
|
||||||
|
|
||||||
|
* `jt.world_rank`: 获取当前进程总数量,如果没有用mpi,则为1。
|
||||||
|
* `jt.rank`: 获取当前进程的编号,区间为`0 ~ jt.world_rank-1`, 如果没有用mpi,则为0。
|
||||||
|
|
||||||
```eval_rst
|
```eval_rst
|
||||||
.. automodule:: jittor_mpi_core
|
.. automodule:: jittor_mpi_core
|
||||||
:members:
|
:members:
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.2.3.89'
|
__version__ = '1.2.3.90'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
@ -23,7 +23,7 @@ with lock.lock_scope():
|
||||||
from jittor_core import *
|
from jittor_core import *
|
||||||
from jittor_core.ops import *
|
from jittor_core.ops import *
|
||||||
from . import compile_extern
|
from . import compile_extern
|
||||||
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank
|
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
|
||||||
if core.get_device_count() == 0:
|
if core.get_device_count() == 0:
|
||||||
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
|
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
|
||||||
if has_cuda:
|
if has_cuda:
|
||||||
|
|
|
@ -504,6 +504,7 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
||||||
setup_mpi()
|
setup_mpi()
|
||||||
in_mpi = inside_mpi()
|
in_mpi = inside_mpi()
|
||||||
rank = mpi.world_rank() if in_mpi else 0
|
rank = mpi.world_rank() if in_mpi else 0
|
||||||
|
world_size = mpi.world_size() if in_mpi else 1
|
||||||
setup_nccl()
|
setup_nccl()
|
||||||
|
|
||||||
setup_cutt()
|
setup_cutt()
|
||||||
|
|
Loading…
Reference in New Issue