mirror of https://github.com/Jittor/Jittor
fix
This commit is contained in:
parent
fb726ad22a
commit
ba0afe5099
|
@ -1,6 +1,7 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
|
|
|
@ -15,6 +15,7 @@ from jittor import nn
|
|||
from jittor import nn, Module
|
||||
import copy
|
||||
n = 2
|
||||
mpi = jt.compile_extern.mpi
|
||||
|
||||
def test_all_reduce():
|
||||
print("test all_reduce")
|
||||
|
@ -80,17 +81,31 @@ def main():
|
|||
test_broadcast()
|
||||
test_reduce()
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
|
||||
class TestNcclOps(unittest.TestCase):
|
||||
# @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
# class TestNcclOps(unittest.TestCase):
|
||||
# def test(self):
|
||||
# mpi = jt.compile_extern.mpi
|
||||
# if mpi.world_size() == 1 and n != 1:
|
||||
# mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||
# cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops"
|
||||
# print("run cmd", cmd)
|
||||
# jt.compiler.run_cmd(cmd)
|
||||
# else:
|
||||
# main()
|
||||
|
||||
@unittest.skipIf(mpi is None, "no inside mpirun")
|
||||
class TestMpi(unittest.TestCase):
|
||||
def test(self):
|
||||
mpi = jt.compile_extern.mpi
|
||||
if mpi.world_size() == 1 and n != 1:
|
||||
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||
cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops"
|
||||
print("run cmd", cmd)
|
||||
jt.compiler.run_cmd(cmd)
|
||||
else:
|
||||
main()
|
||||
main()
|
||||
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
class TestNcclOps(unittest.TestCase):
|
||||
def test_entry(self):
|
||||
if not jt.compile_extern.inside_mpi():
|
||||
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
|
||||
cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops -v"
|
||||
print("run cmd:", cmd)
|
||||
assert os.system(cmd)==0, "run cmd failed: "+cmd
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue