diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index a8ecb600..2c3fe408 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -780,6 +780,12 @@ if not os.path.isfile(py3_config_path) : assert os.path.isfile(py3_config_path) nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc') +if 'mpi_path' in os.environ: + mpi_path = os.environ['mpi_path'] +else: + mpi_path = '/usr/local/openmpi' +assert os.path.isfile(os.path.join(mpi_path,"include","mpi.h")) +assert os.path.isfile(os.path.join(mpi_path,"lib","libmpi.so")) gdb_path = try_find_exe('gdb') addr2line_path = try_find_exe('addr2line') has_pybt = check_pybt(gdb_path, python_path) @@ -919,6 +925,7 @@ flags.cc_path = cc_path flags.cc_type = cc_type flags.cc_flags = cc_flags + link_flags + kernel_opt_flags flags.nvcc_path = nvcc_path +flags.mpi_path = mpi_path flags.nvcc_flags = nvcc_flags flags.python_path = python_path flags.cache_path = cache_path diff --git a/python/jittor/test/test_mpi.py b/python/jittor/test/test_mpi.py new file mode 100644 index 00000000..b876d238 --- /dev/null +++ b/python/jittor/test/test_mpi.py @@ -0,0 +1,64 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os +import jittor as jt +import numpy as np +import copy + +def find_cache_path(): + from pathlib import Path + path = str(Path.home()) + dirs = [".cache", "jittor"] + for d in dirs: + path = os.path.join(path, d) + if not os.path.isdir(path): + os.mkdir(path) + assert os.path.isdir(path) + return path + +cache_path = find_cache_path() + +class TestMpi(unittest.TestCase): + def test(self): + # Modified from: https://mpitutorial.com/tutorials/mpi-hello-world/zh_cn/ + content=""" + #include + #include + + int main(int argc, char** argv) { + MPI_Init(NULL, NULL); + + int world_size; + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + + int world_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + + char processor_name[MPI_MAX_PROCESSOR_NAME]; + int name_len; + MPI_Get_processor_name(processor_name, &name_len); + + printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size); + + MPI_Finalize(); + } + """ + test_path=os.path.join(cache_path,"test_mpi.cc") + f=open(test_path,"w") + f.write(content) + f.close() + mpi_path=jt.flags.mpi_path + mpi_include = os.path.join(mpi_path, "include") + mpi_lib = os.path.join(mpi_path, "lib") + cmd = f"cd {cache_path} && g++ {test_path} -I {mpi_include} -L {mpi_lib} -lmpi -o test_mpi && mpirun -n 4 ./test_mpi" + self.assertEqual(os.system(cmd), 0) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/src/jit_compiler.cc b/src/jit_compiler.cc index 75e63ef8..14323536 100755 --- a/src/jit_compiler.cc +++ b/src/jit_compiler.cc @@ -20,6 +20,7 @@ DEFINE_FLAG(string, jittor_path, "", "Source path of jittor"); DEFINE_FLAG(string, cc_path, "", "Path of C++ compiler"); DEFINE_FLAG(string, cc_type, "", "Type of C++ compiler(clang, icc, g++)"); DEFINE_FLAG(string, cc_flags, "", "Flags of C++ compiler"); +DEFINE_FLAG(string, mpi_path, "", "Path of mpi dir"); DEFINE_FLAG(string, nvcc_path, "", "Path of CUDA C++ compiler"); DEFINE_FLAG(string, nvcc_flags, "", "Flags of CUDA C++ compiler"); DEFINE_FLAG(string, python_path, "", "Path of python interpreter");