mirror of https://github.com/Jittor/Jittor
Merge branch 'mpi_init' of https://github.com/Jittor/jittor into mpi_init
This commit is contained in:
commit
5d838d1366
|
@ -0,0 +1,21 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "mpi_warper.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <nccl.h>
|
||||
#include <helper_cuda.h>
|
||||
|
||||
namespace jittor {
|
||||
|
||||
extern ncclComm_t comm;
|
||||
extern ncclUniqueId id;
|
||||
|
||||
} // jittor
|
|
@ -7,15 +7,8 @@
|
|||
#include "nccl_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "nccl_warper.h"
|
||||
|
||||
#ifndef JIT
|
||||
const char *_cudaGetErrorEnum(ncclResult_t error) {
|
||||
return ncclGetErrorString(error);
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -33,16 +26,41 @@ void NcclTestOp::jit_prepare() {
|
|||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
static void test_with_mpi() {
|
||||
int size = 32*1024*1024;
|
||||
int myRank = mpi_world_rank;
|
||||
int nRanks = mpi_world_size;
|
||||
int localRank = mpi_local_rank;
|
||||
|
||||
float *sendbuff, *recvbuff;
|
||||
cudaStream_t s;
|
||||
checkCudaErrors(cudaMalloc(&sendbuff, size * sizeof(float)));
|
||||
checkCudaErrors(cudaMalloc(&recvbuff, size * sizeof(float)));
|
||||
checkCudaErrors(cudaStreamCreate(&s));
|
||||
|
||||
//communicating using NCCL
|
||||
checkCudaErrors(ncclAllReduce((const void*)sendbuff, (void*)recvbuff, size, ncclFloat, ncclSum,
|
||||
comm, s));
|
||||
|
||||
//completing NCCL operation by synchronizing on the CUDA stream
|
||||
checkCudaErrors(cudaStreamSynchronize(s));
|
||||
|
||||
//free device buffers
|
||||
checkCudaErrors(cudaFree(sendbuff));
|
||||
checkCudaErrors(cudaFree(recvbuff));
|
||||
checkCudaErrors(cudaStreamDestroy(s));
|
||||
|
||||
LOGi << "MPI rank" << myRank << "Success";
|
||||
}
|
||||
|
||||
void NcclTestOp::jit_run() {
|
||||
auto args = split(cmd, " ");
|
||||
if (!cmd.size()) args.clear();
|
||||
vector<char*> v(args.size());
|
||||
for (uint i=0; i<args.size(); i++)
|
||||
v[i] = &args[i][0];
|
||||
output->ptr<T>()[0] = 123;
|
||||
if (cmd == "test_with_mpi") {
|
||||
test_with_mpi();
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//managing 4 devices
|
||||
int nDev;
|
||||
checkCudaErrors(cudaGetDeviceCount(&nDev));
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "nccl_warper.h"
|
||||
|
||||
|
||||
const char *_cudaGetErrorEnum(ncclResult_t error) {
|
||||
return ncclGetErrorString(error);
|
||||
}
|
||||
|
||||
namespace jittor {
|
||||
|
||||
ncclComm_t comm;
|
||||
ncclUniqueId id;
|
||||
|
||||
|
||||
struct nccl_initer {
|
||||
|
||||
nccl_initer() {
|
||||
if (mpi_world_rank == 0)
|
||||
checkCudaErrors(ncclGetUniqueId(&id));
|
||||
MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
|
||||
checkCudaErrors(cudaSetDevice(mpi_local_rank));
|
||||
checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
|
||||
}
|
||||
|
||||
~nccl_initer() {
|
||||
checkCudaErrors(ncclCommDestroy(comm));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
static nccl_initer nccl_init;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,40 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#define OMPI_SKIP_MPICXX
|
||||
#include <mpi.h>
|
||||
|
||||
extern void throw_mpi_error(int result,
|
||||
char const *const func, const char *const file, int const line);
|
||||
|
||||
static inline void mpi_check(int result,
|
||||
char const *const func, const char *const file, int const line) {
|
||||
if (result != MPI_SUCCESS) {
|
||||
throw_mpi_error(result, func, file, line);
|
||||
}
|
||||
}
|
||||
|
||||
#define MPI_CHECK(val) mpi_check((val), #val, __FILE__, __LINE__)
|
||||
|
||||
namespace jittor {
|
||||
|
||||
extern int mpi_world_size;
|
||||
extern int mpi_world_rank;
|
||||
extern int mpi_local_rank;
|
||||
|
||||
// @pyjt(world_size)
|
||||
int _mpi_world_size();
|
||||
|
||||
// @pyjt(world_rank)
|
||||
int _mpi_world_rank();
|
||||
|
||||
// @pyjt(local_rank)
|
||||
int _mpi_local_rank();
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,42 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mpi_warper.h"
|
||||
|
||||
#include "var.h"
|
||||
#include "mpi_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
|
||||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void MpiTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
||||
void MpiTestOp::jit_run() {
|
||||
output->ptr<T>()[0] = 123;
|
||||
|
||||
int world_size = mpi_world_size;
|
||||
|
||||
int world_rank = mpi_world_rank;
|
||||
|
||||
char processor_name[MPI_MAX_PROCESSOR_NAME];
|
||||
int name_len;
|
||||
MPI_CHECK(MPI_Get_processor_name(processor_name, &name_len));
|
||||
|
||||
printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size);
|
||||
|
||||
}
|
||||
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,23 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct MpiTestOp : Op {
|
||||
Var* output;
|
||||
string cmd;
|
||||
|
||||
MpiTestOp(string cmd);
|
||||
|
||||
const char* name() const override { return "mpi_test"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,97 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <unistd.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mpi_warper.h"
|
||||
#include "common.h"
|
||||
|
||||
char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
|
||||
|
||||
void throw_mpi_error(int result,
|
||||
char const *const func, const char *const file, int const line) {
|
||||
int resultlen;
|
||||
MPI_Error_string(result, jt_mpi_err_buffer, &resultlen);
|
||||
fprintf(stderr, "MPI error at %s:%d code=%d(%s) \"%s\" \n",
|
||||
file, line,
|
||||
static_cast<unsigned int>(result), jt_mpi_err_buffer, func);
|
||||
throw std::runtime_error("MPI error");
|
||||
}
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
||||
int mpi_world_size;
|
||||
int mpi_world_rank;
|
||||
int mpi_local_rank;
|
||||
|
||||
int _mpi_world_size() {
|
||||
return mpi_world_size;
|
||||
}
|
||||
|
||||
int _mpi_world_rank() {
|
||||
return mpi_world_rank;
|
||||
}
|
||||
|
||||
int _mpi_local_rank() {
|
||||
return mpi_local_rank;
|
||||
}
|
||||
|
||||
|
||||
|
||||
static uint64_t getHostHash(const char* string) {
|
||||
// Based on DJB2, result = result * 33 + char
|
||||
uint64_t result = 5381;
|
||||
for (int c = 0; string[c] != '\0'; c++){
|
||||
result = ((result << 5) + result) + string[c];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
static void getHostName(char* hostname, int maxlen) {
|
||||
gethostname(hostname, maxlen);
|
||||
for (int i=0; i< maxlen; i++) {
|
||||
if (hostname[i] == '.') {
|
||||
hostname[i] = '\0';
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct mpi_initer {
|
||||
|
||||
mpi_initer() {
|
||||
MPI_CHECK(MPI_Init(NULL, NULL));
|
||||
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
|
||||
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
|
||||
|
||||
//calculating localRank based on hostname which is used in selecting a GPU
|
||||
uint64_t hostHashs[mpi_world_rank];
|
||||
char hostname[1024];
|
||||
getHostName(hostname, 1024);
|
||||
hostHashs[mpi_world_rank] = getHostHash(hostname);
|
||||
MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
|
||||
mpi_local_rank = 0;
|
||||
for (int p=0; p<mpi_world_size; p++) {
|
||||
if (p == mpi_world_rank) break;
|
||||
if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
|
||||
}
|
||||
}
|
||||
|
||||
~mpi_initer() {
|
||||
MPI_CHECK(MPI_Finalize());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
static mpi_initer mpi_init;
|
||||
|
||||
} // jittor
|
|
@ -265,7 +265,7 @@ def setup_nccl():
|
|||
global nccl_ops, use_nccl
|
||||
use_nccl = os.environ.get("use_nccl", "1")=="1"
|
||||
nccl_ops = None
|
||||
if not has_cuda:
|
||||
if not has_cuda or mpi is None:
|
||||
use_nccl = False
|
||||
return
|
||||
if not use_nccl: return
|
||||
|
@ -293,14 +293,82 @@ def setup_nccl():
|
|||
# We do not link manualy, link in custom ops
|
||||
ctypes.CDLL(nccl_lib_name, dlopen_flags)
|
||||
|
||||
nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
|
||||
nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
|
||||
nccl_ops = compile_custom_ops(nccl_op_files,
|
||||
extra_flags=f" -I'{nccl_include_path}'")
|
||||
nccl_src_dir = os.path.join(jittor_path, "extern", "cuda", "nccl")
|
||||
nccl_src_files = []
|
||||
for r, _, f in os.walk(nccl_src_dir):
|
||||
for fname in f:
|
||||
nccl_src_files.append(os.path.join(r, fname))
|
||||
|
||||
nccl_ops = compile_custom_ops(nccl_src_files,
|
||||
extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
|
||||
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
||||
|
||||
def manual_link(flags):
|
||||
lib_dirs = []
|
||||
libs = []
|
||||
for f in flags.split():
|
||||
if f.startswith("-l"):
|
||||
libs.append(f[2:])
|
||||
elif f.startswith("-L"):
|
||||
lib_dirs.append(f[2:])
|
||||
LOG.v("manual_link:", flags)
|
||||
LOG.v("lib_dirs:", lib_dirs)
|
||||
LOG.v("libs:", libs)
|
||||
for lib in libs:
|
||||
for d in lib_dirs:
|
||||
libname = os.path.join(d, f"lib{lib}.so")
|
||||
if os.path.isfile(libname):
|
||||
LOG.v("link:", libname)
|
||||
ctypes.CDLL(libname, dlopen_flags)
|
||||
break
|
||||
|
||||
|
||||
def setup_mpi():
|
||||
global mpi_ops, mpi, use_mpi
|
||||
global mpicc_path, has_mpi
|
||||
use_mpi = os.environ.get("use_mpi", "1")=="1"
|
||||
mpi_ops = None
|
||||
mpi = None
|
||||
has_mpi = False
|
||||
if not use_mpi: return
|
||||
mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
|
||||
if mpicc_path == "":
|
||||
LOG.i("mpicc not found, distribution disabled.")
|
||||
use_mpi = False
|
||||
else:
|
||||
use_mpi = True
|
||||
has_mpi = True
|
||||
if not use_mpi:
|
||||
return
|
||||
|
||||
global mpi_compile_flags, mpi_link_flags, mpi_flags
|
||||
mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
|
||||
mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
|
||||
mpi_flags = mpi_compile_flags + " " + mpi_link_flags
|
||||
LOG.i("mpi_flags: "+mpi_flags)
|
||||
manual_link(mpi_flags)
|
||||
|
||||
# find all source files
|
||||
mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
|
||||
mpi_src_files = []
|
||||
for r, _, f in os.walk(mpi_src_dir):
|
||||
for fname in f:
|
||||
mpi_src_files.append(os.path.join(r, fname))
|
||||
|
||||
# mpi compile flags add for nccl
|
||||
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
|
||||
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
|
||||
|
||||
mpi = compile_custom_ops(mpi_src_files,
|
||||
extra_flags=f" {mpi_flags} ", return_module=True)
|
||||
mpi_ops = mpi.ops
|
||||
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
|
||||
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
|
||||
|
||||
setup_mpi()
|
||||
setup_nccl()
|
||||
|
||||
setup_cutt()
|
||||
setup_mkl()
|
||||
setup_nccl()
|
||||
|
||||
setup_cuda_extern()
|
||||
|
|
|
@ -554,22 +554,31 @@ def compile_custom_op(header, source, op_name, warp=True):
|
|||
m = compile_custom_ops([hname, ccname])
|
||||
return getattr(m, op_name)
|
||||
|
||||
def compile_custom_ops(filenames, extra_flags=""):
|
||||
def compile_custom_ops(filenames, extra_flags="", return_module=False):
|
||||
"""Compile custom ops
|
||||
filenames: path of op source files, filenames must be
|
||||
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
|
||||
type name of op must be XxxXxxOp.
|
||||
extra_flags: extra compile flags
|
||||
return_module: return module rather than ops(default: False)
|
||||
return: compiled ops
|
||||
"""
|
||||
srcs = {}
|
||||
headers = {}
|
||||
builds = []
|
||||
includes = []
|
||||
pyjt_includes = []
|
||||
for name in filenames:
|
||||
name = os.path.realpath(name)
|
||||
if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
|
||||
builds.append(name)
|
||||
if name.endswith(".h"):
|
||||
dirname = os.path.dirname(name)
|
||||
if dirname.endswith("inc"):
|
||||
includes.append(dirname)
|
||||
with open(name, "r") as f:
|
||||
if "@pyjt" in f.read():
|
||||
pyjt_includes.append(name)
|
||||
bname = os.path.basename(name)
|
||||
bname = os.path.splitext(bname)[0]
|
||||
if bname.endswith("_op"):
|
||||
|
@ -597,14 +606,33 @@ def compile_custom_ops(filenames, extra_flags=""):
|
|||
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
|
||||
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
|
||||
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
|
||||
with open(gen_head_fname, "w") as f:
|
||||
f.write(gen_src)
|
||||
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
|
||||
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
|
||||
# gen src initialize first
|
||||
builds.insert(0, gen_src_fname)
|
||||
|
||||
def insert_anchor(gen_src, anchor_str, insert_str):
|
||||
# insert insert_str after anchor_str into gen_src
|
||||
return gen_src.replace(anchor_str, anchor_str+insert_str, 1)
|
||||
|
||||
for name in pyjt_includes:
|
||||
LOG.i("handle pyjt_include", name)
|
||||
bname = name.split("/")[-1].split(".")[0]
|
||||
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc")
|
||||
pyjt_compiler.compile_single(name, gen_src_fname)
|
||||
builds.insert(1, gen_src_fname)
|
||||
gen_src = insert_anchor(gen_src,
|
||||
"namespace jittor {",
|
||||
f"extern void pyjt_def_{bname}(PyObject* m);")
|
||||
gen_src = insert_anchor(gen_src,
|
||||
"init_module(PyModuleDef* mdef, PyObject* m) {",
|
||||
f"jittor::pyjt_def_{bname}(m);")
|
||||
|
||||
with open(gen_head_fname, "w") as f:
|
||||
f.write(gen_src)
|
||||
|
||||
LOG.vvv(f"Build custum ops lib:{gen_lib}")
|
||||
LOG.vvvv(f"Build sources:{builds}")
|
||||
compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)
|
||||
compile(cc_path, extra_flags+cc_flags+opt_flags+includes, builds, gen_lib)
|
||||
|
||||
# add python path and import
|
||||
LOG.vvv(f"Import custum ops lib:{gen_lib}")
|
||||
|
@ -613,7 +641,10 @@ def compile_custom_ops(filenames, extra_flags=""):
|
|||
os.sys.path.append(lib_path)
|
||||
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
|
||||
exec(f"import {gen_name}")
|
||||
return (locals()[gen_name]).ops
|
||||
mod = locals()[gen_name]
|
||||
if return_module:
|
||||
return mod
|
||||
return mod.ops
|
||||
|
||||
|
||||
def get_full_path_of_executable(name):
|
||||
|
@ -781,12 +812,6 @@ if not os.path.isfile(py3_config_path) :
|
|||
|
||||
assert os.path.isfile(py3_config_path)
|
||||
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
|
||||
if 'mpi_path' in os.environ:
|
||||
mpi_path = os.environ['mpi_path']
|
||||
else:
|
||||
mpi_path = '/usr/local/openmpi'
|
||||
assert os.path.isfile(os.path.join(mpi_path,"include","mpi.h"))
|
||||
assert os.path.isfile(os.path.join(mpi_path,"lib","libmpi.so"))
|
||||
gdb_path = try_find_exe('gdb')
|
||||
addr2line_path = try_find_exe('addr2line')
|
||||
has_pybt = check_pybt(gdb_path, python_path)
|
||||
|
@ -930,7 +955,6 @@ flags.cc_path = cc_path
|
|||
flags.cc_type = cc_type
|
||||
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
|
||||
flags.nvcc_path = nvcc_path
|
||||
flags.mpi_path = mpi_path
|
||||
flags.nvcc_flags = nvcc_flags
|
||||
flags.python_path = python_path
|
||||
flags.cache_path = cache_path
|
||||
|
|
|
@ -7,58 +7,29 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import os
|
||||
import os, sys
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
def find_cache_path():
|
||||
from pathlib import Path
|
||||
path = str(Path.home())
|
||||
dirs = [".cache", "jittor"]
|
||||
for d in dirs:
|
||||
path = os.path.join(path, d)
|
||||
if not os.path.isdir(path):
|
||||
os.mkdir(path)
|
||||
assert os.path.isdir(path)
|
||||
return path
|
||||
|
||||
cache_path = find_cache_path()
|
||||
def main():
|
||||
print("test mpi_test")
|
||||
assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
|
||||
if jt.compile_extern.nccl_ops:
|
||||
print("test test_with_mpi")
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
|
||||
class TestMpi(unittest.TestCase):
|
||||
def test(self):
|
||||
# Modified from: https://mpitutorial.com/tutorials/mpi-hello-world/zh_cn/
|
||||
content="""
|
||||
#include <mpi.h>
|
||||
#include <stdio.h>
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
MPI_Init(NULL, NULL);
|
||||
|
||||
int world_size;
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||
|
||||
int world_rank;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
|
||||
|
||||
char processor_name[MPI_MAX_PROCESSOR_NAME];
|
||||
int name_len;
|
||||
MPI_Get_processor_name(processor_name, &name_len);
|
||||
|
||||
printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size);
|
||||
|
||||
MPI_Finalize();
|
||||
}
|
||||
"""
|
||||
test_path=os.path.join(cache_path,"test_mpi.cc")
|
||||
f=open(test_path,"w")
|
||||
f.write(content)
|
||||
f.close()
|
||||
mpi_path=jt.flags.mpi_path
|
||||
mpi_include = os.path.join(mpi_path, "include")
|
||||
mpi_lib = os.path.join(mpi_path, "lib")
|
||||
cmd = f"cd {cache_path} && g++ {test_path} -I {mpi_include} -L {mpi_lib} -lmpi -o test_mpi && mpirun -n 4 ./test_mpi"
|
||||
self.assertEqual(os.system(cmd), 0)
|
||||
mpi = jt.compile_extern.mpi
|
||||
if mpi.world_size() == 1:
|
||||
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
|
||||
print("run cmd", cmd)
|
||||
jt.compiler.run_cmd(cmd)
|
||||
else:
|
||||
main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -178,7 +178,10 @@ def find_cache_path():
|
|||
return path
|
||||
|
||||
def get_version(output):
|
||||
version = run_cmd(output+" --version")
|
||||
if output.endswith("mpicc"):
|
||||
version = run_cmd(output+" --showme:version")
|
||||
else:
|
||||
version = run_cmd(output+" --version")
|
||||
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
||||
if len(v) == 0:
|
||||
v = re.findall("[0-9]+\\.[0-9]+", version)
|
||||
|
|
|
@ -20,7 +20,6 @@ DEFINE_FLAG(string, jittor_path, "", "Source path of jittor");
|
|||
DEFINE_FLAG(string, cc_path, "", "Path of C++ compiler");
|
||||
DEFINE_FLAG(string, cc_type, "", "Type of C++ compiler(clang, icc, g++)");
|
||||
DEFINE_FLAG(string, cc_flags, "", "Flags of C++ compiler");
|
||||
DEFINE_FLAG(string, mpi_path, "", "Path of mpi dir");
|
||||
DEFINE_FLAG(string, nvcc_path, "", "Path of CUDA C++ compiler");
|
||||
DEFINE_FLAG(string, nvcc_flags, "", "Flags of CUDA C++ compiler");
|
||||
DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
|
||||
|
|
Loading…
Reference in New Issue