mirror of https://github.com/Jittor/Jittor
mpi init
This commit is contained in:
parent
2153b65856
commit
c4ff963af9
|
@ -780,6 +780,12 @@ if not os.path.isfile(py3_config_path) :
|
|||
|
||||
assert os.path.isfile(py3_config_path)
|
||||
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
|
||||
if 'mpi_path' in os.environ:
|
||||
mpi_path = os.environ['mpi_path']
|
||||
else:
|
||||
mpi_path = '/usr/local/openmpi'
|
||||
assert os.path.isfile(os.path.join(mpi_path,"include","mpi.h"))
|
||||
assert os.path.isfile(os.path.join(mpi_path,"lib","libmpi.so"))
|
||||
gdb_path = try_find_exe('gdb')
|
||||
addr2line_path = try_find_exe('addr2line')
|
||||
has_pybt = check_pybt(gdb_path, python_path)
|
||||
|
@ -919,6 +925,7 @@ flags.cc_path = cc_path
|
|||
flags.cc_type = cc_type
|
||||
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
|
||||
flags.nvcc_path = nvcc_path
|
||||
flags.mpi_path = mpi_path
|
||||
flags.nvcc_flags = nvcc_flags
|
||||
flags.python_path = python_path
|
||||
flags.cache_path = cache_path
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import os
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
def find_cache_path():
|
||||
from pathlib import Path
|
||||
path = str(Path.home())
|
||||
dirs = [".cache", "jittor"]
|
||||
for d in dirs:
|
||||
path = os.path.join(path, d)
|
||||
if not os.path.isdir(path):
|
||||
os.mkdir(path)
|
||||
assert os.path.isdir(path)
|
||||
return path
|
||||
|
||||
cache_path = find_cache_path()
|
||||
|
||||
class TestMpi(unittest.TestCase):
|
||||
def test(self):
|
||||
# Modified from: https://mpitutorial.com/tutorials/mpi-hello-world/zh_cn/
|
||||
content="""
|
||||
#include <mpi.h>
|
||||
#include <stdio.h>
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
MPI_Init(NULL, NULL);
|
||||
|
||||
int world_size;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||
|
||||
int world_rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
|
||||
|
||||
char processor_name[MPI_MAX_PROCESSOR_NAME];
|
||||
int name_len;
|
||||
MPI_Get_processor_name(processor_name, &name_len);
|
||||
|
||||
printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size);
|
||||
|
||||
MPI_Finalize();
|
||||
}
|
||||
"""
|
||||
test_path=os.path.join(cache_path,"test_mpi.cc")
|
||||
f=open(test_path,"w")
|
||||
f.write(content)
|
||||
f.close()
|
||||
mpi_path=jt.flags.mpi_path
|
||||
mpi_include = os.path.join(mpi_path, "include")
|
||||
mpi_lib = os.path.join(mpi_path, "lib")
|
||||
cmd = f"cd {cache_path} && g++ {test_path} -I {mpi_include} -L {mpi_lib} -lmpi -o test_mpi && mpirun -n 4 ./test_mpi"
|
||||
self.assertEqual(os.system(cmd), 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -20,6 +20,7 @@ DEFINE_FLAG(string, jittor_path, "", "Source path of jittor");
|
|||
DEFINE_FLAG(string, cc_path, "", "Path of C++ compiler");
|
||||
DEFINE_FLAG(string, cc_type, "", "Type of C++ compiler(clang, icc, g++)");
|
||||
DEFINE_FLAG(string, cc_flags, "", "Flags of C++ compiler");
|
||||
DEFINE_FLAG(string, mpi_path, "", "Path of mpi dir");
|
||||
DEFINE_FLAG(string, nvcc_path, "", "Path of CUDA C++ compiler");
|
||||
DEFINE_FLAG(string, nvcc_flags, "", "Flags of CUDA C++ compiler");
|
||||
DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
|
||||
|
|
Loading…
Reference in New Issue