This commit is contained in:
cxjyxx_me 2020-04-16 09:24:10 +08:00
parent fb726ad22a
commit ba0afe5099
2 changed files with 26 additions and 10 deletions

View File

@ -1,6 +1,7 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Guoye Yang <498731903@qq.com>
# Wenyang Zhou <576825820@qq.com>
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.

View File

@ -15,6 +15,7 @@ from jittor import nn
from jittor import nn, Module
import copy
n = 2
mpi = jt.compile_extern.mpi
def test_all_reduce():
print("test all_reduce")
@ -80,17 +81,31 @@ def main():
test_broadcast()
test_reduce()
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
class TestNcclOps(unittest.TestCase):
# @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
# class TestNcclOps(unittest.TestCase):
# def test(self):
# mpi = jt.compile_extern.mpi
# if mpi.world_size() == 1 and n != 1:
# mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
# cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops"
# print("run cmd", cmd)
# jt.compiler.run_cmd(cmd)
# else:
# main()
@unittest.skipIf(mpi is None, "no inside mpirun")
class TestMpi(unittest.TestCase):
def test(self):
mpi = jt.compile_extern.mpi
if mpi.world_size() == 1 and n != 1:
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops"
print("run cmd", cmd)
jt.compiler.run_cmd(cmd)
else:
main()
main()
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestNcclOps(unittest.TestCase):
def test_entry(self):
if not jt.compile_extern.inside_mpi():
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops -v"
print("run cmd:", cmd)
assert os.system(cmd)==0, "run cmd failed: "+cmd
if __name__ == "__main__":
unittest.main()