mirror of https://github.com/Jittor/Jittor
polish world rank and world size
This commit is contained in:
parent
39165cccfb
commit
692cbddb8e
|
@ -95,6 +95,9 @@ def val(epoch):
|
|||
|
||||
下面是 jittor 的 mpi api reference.
|
||||
|
||||
* `jt.world_rank`: 获取当前进程总数量,如果没有用mpi,则为1。
|
||||
* `jt.rank`: 获取当前进程的编号,区间为`0 ~ jt.world_rank-1`, 如果没有用mpi,则为0。
|
||||
|
||||
```eval_rst
|
||||
.. automodule:: jittor_mpi_core
|
||||
:members:
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.89'
|
||||
__version__ = '1.2.3.90'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -23,7 +23,7 @@ with lock.lock_scope():
|
|||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
from . import compile_extern
|
||||
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank
|
||||
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
|
||||
if core.get_device_count() == 0:
|
||||
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
|
||||
if has_cuda:
|
||||
|
|
|
@ -504,6 +504,7 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
|||
setup_mpi()
|
||||
in_mpi = inside_mpi()
|
||||
rank = mpi.world_rank() if in_mpi else 0
|
||||
world_size = mpi.world_size() if in_mpi else 1
|
||||
setup_nccl()
|
||||
|
||||
setup_cutt()
|
||||
|
|
Loading…
Reference in New Issue