polish world rank and world size

This commit is contained in:
Dun Liang 2021-08-05 20:09:25 +08:00
parent 39165cccfb
commit 692cbddb8e
3 changed files with 6 additions and 2 deletions

View File

@ -95,6 +95,9 @@ def val(epoch):
下面是 jittor 的 mpi api reference.
* `jt.world_rank`: 获取当前进程总数量如果没有用mpi则为1。
* `jt.rank`: 获取当前进程的编号,区间为`0 jt.world_rank-1` 如果没有用mpi则为0。
```eval_rst
.. automodule:: jittor_mpi_core
:members:

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.89'
__version__ = '1.2.3.90'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -23,7 +23,7 @@ with lock.lock_scope():
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
if core.get_device_count() == 0:
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
if has_cuda:

View File

@ -504,6 +504,7 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
setup_mpi()
in_mpi = inside_mpi()
rank = mpi.world_rank() if in_mpi else 0
world_size = mpi.world_size() if in_mpi else 1
setup_nccl()
setup_cutt()