Merge branch 'mpi_init' of https://github.com/Jittor/jittor into mpi_init

guowei yang 2020-04-06 11:16:28 +08:00
commit 5d838d1366
13 changed files with 430 additions and 83 deletions

extern/cuda/nccl/inc/nccl_warper.h (new file, vendored, 21 lines)

@@ -0,0 +1,21 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "mpi_warper.h"
#include <cuda_runtime.h>
#include <nccl.h>
#include <helper_cuda.h>
namespace jittor {
extern ncclComm_t comm;
extern ncclUniqueId id;
} // jittor

@@ -7,15 +7,8 @@
#include "nccl_test_op.h"
#include "misc/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>
#include <helper_cuda.h>
#include "nccl_warper.h"
#ifndef JIT
const char *_cudaGetErrorEnum(ncclResult_t error) {
return ncclGetErrorString(error);
}
#endif
namespace jittor {
@@ -33,16 +26,41 @@ void NcclTestOp::jit_prepare() {
#else // JIT
#ifdef JIT_cuda
static void test_with_mpi() {
int size = 32*1024*1024;
int myRank = mpi_world_rank;
int nRanks = mpi_world_size;
int localRank = mpi_local_rank;
float *sendbuff, *recvbuff;
cudaStream_t s;
checkCudaErrors(cudaMalloc(&sendbuff, size * sizeof(float)));
checkCudaErrors(cudaMalloc(&recvbuff, size * sizeof(float)));
checkCudaErrors(cudaStreamCreate(&s));
//communicating using NCCL
checkCudaErrors(ncclAllReduce((const void*)sendbuff, (void*)recvbuff, size, ncclFloat, ncclSum,
comm, s));
//completing NCCL operation by synchronizing on the CUDA stream
checkCudaErrors(cudaStreamSynchronize(s));
//free device buffers
checkCudaErrors(cudaFree(sendbuff));
checkCudaErrors(cudaFree(recvbuff));
checkCudaErrors(cudaStreamDestroy(s));
LOGi << "MPI rank" << myRank << "Success";
}
void NcclTestOp::jit_run() {
auto args = split(cmd, " ");
if (!cmd.size()) args.clear();
vector<char*> v(args.size());
for (uint i=0; i<args.size(); i++)
v[i] = &args[i][0];
output->ptr<T>()[0] = 123;
if (cmd == "test_with_mpi") {
test_with_mpi();
return;
}
//managing 4 devices
int nDev;
checkCudaErrors(cudaGetDeviceCount(&nDev));
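
For context, a minimal sketch of driving this op from the Python side, mirroring the test in test_mpi.py further below (it assumes the op is exposed as jt.compile_extern.nccl_ops.nccl_test):

import jittor as jt

# Only meaningful when NCCL was compiled in and CUDA is enabled.
if jt.compile_extern.nccl_ops:
    with jt.flag_scope(use_cuda=1):
        # "test_with_mpi" routes into test_with_mpi() above; the op writes 123 on success.
        assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123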

@@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in

extern/cuda/nccl/src/nccl_warper.cc (new file, vendored, 40 lines)

@@ -0,0 +1,40 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "nccl_warper.h"
const char *_cudaGetErrorEnum(ncclResult_t error) {
return ncclGetErrorString(error);
}
namespace jittor {
ncclComm_t comm;
ncclUniqueId id;
struct nccl_initer {
nccl_initer() {
if (mpi_world_rank == 0)
checkCudaErrors(ncclGetUniqueId(&id));
MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
checkCudaErrors(cudaSetDevice(mpi_local_rank));
checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
}
~nccl_initer() {
checkCudaErrors(ncclCommDestroy(comm));
}
};
static nccl_initer nccl_init;
} // jittor

extern/mpi/inc/mpi_warper.h (new file, vendored, 40 lines)

@@ -0,0 +1,40 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#define OMPI_SKIP_MPICXX
#include <mpi.h>
extern void throw_mpi_error(int result,
char const *const func, const char *const file, int const line);
static inline void mpi_check(int result,
char const *const func, const char *const file, int const line) {
if (result != MPI_SUCCESS) {
throw_mpi_error(result, func, file, line);
}
}
#define MPI_CHECK(val) mpi_check((val), #val, __FILE__, __LINE__)
namespace jittor {
extern int mpi_world_size;
extern int mpi_world_rank;
extern int mpi_local_rank;
// @pyjt(world_size)
int _mpi_world_size();
// @pyjt(world_rank)
int _mpi_world_rank();
// @pyjt(local_rank)
int _mpi_local_rank();
} // jittor
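
The // @pyjt(...) annotations export these helpers to Python. A small sketch of the resulting Python-side calls, assuming the compiled module is exposed as jt.compile_extern.mpi (as the test below does):

import jittor as jt

mpi = jt.compile_extern.mpi
if mpi is not None:
    # world_size/world_rank/local_rank map to the @pyjt-annotated C++ functions above.
    print(mpi.world_size(), mpi.world_rank(), mpi.local_rank())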

extern/mpi/ops/mpi_test_op.cc (new file, vendored, 42 lines)

@@ -0,0 +1,42 @@
// ***************************************************************
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"
#include "var.h"
#include "mpi_test_op.h"
#include "misc/str_utils.h"
namespace jittor {
#ifndef JIT
MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
output = create_output(1, ns_float32);
}
void MpiTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
}
#else // JIT
void MpiTestOp::jit_run() {
output->ptr<T>()[0] = 123;
int world_size = mpi_world_size;
int world_rank = mpi_world_rank;
char processor_name[MPI_MAX_PROCESSOR_NAME];
int name_len;
MPI_CHECK(MPI_Get_processor_name(processor_name, &name_len));
printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size);
}
#endif // JIT
} // jittor
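
A hedged usage sketch for this op: each MPI rank builds it, gets 123 back, and prints its greeting. The launch command mirrors the one test_mpi.py uses below:

# e.g. launched as: mpirun -np 2 python3 -m jittor.test.test_mpi
import jittor as jt

# The op writes 123 into its single float32 output and prints rank/processor info.
assert jt.compile_extern.mpi_ops.mpi_test("").data == 123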

extern/mpi/ops/mpi_test_op.h (new file, vendored, 23 lines)

@@ -0,0 +1,23 @@
// ***************************************************************
// Copyright (c) 2020
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct MpiTestOp : Op {
Var* output;
string cmd;
MpiTestOp(string cmd);
const char* name() const override { return "mpi_test"; }
DECLARE_jit_run;
};
} // jittor

extern/mpi/src/mpi_warper.cc (new file, vendored, 97 lines)

@@ -0,0 +1,97 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <unistd.h>
#include <stdint.h>
#include <stdio.h>
#include "mpi_warper.h"
#include "common.h"
char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
void throw_mpi_error(int result,
char const *const func, const char *const file, int const line) {
int resultlen;
MPI_Error_string(result, jt_mpi_err_buffer, &resultlen);
fprintf(stderr, "MPI error at %s:%d code=%d(%s) \"%s\" \n",
file, line,
static_cast<unsigned int>(result), jt_mpi_err_buffer, func);
throw std::runtime_error("MPI error");
}
namespace jittor {
int mpi_world_size;
int mpi_world_rank;
int mpi_local_rank;
int _mpi_world_size() {
return mpi_world_size;
}
int _mpi_world_rank() {
return mpi_world_rank;
}
int _mpi_local_rank() {
return mpi_local_rank;
}
static uint64_t getHostHash(const char* string) {
// Based on DJB2, result = result * 33 + char
uint64_t result = 5381;
for (int c = 0; string[c] != '\0'; c++){
result = ((result << 5) + result) + string[c];
}
return result;
}
static void getHostName(char* hostname, int maxlen) {
gethostname(hostname, maxlen);
for (int i=0; i< maxlen; i++) {
if (hostname[i] == '.') {
hostname[i] = '\0';
return;
}
}
}
struct mpi_initer {
mpi_initer() {
MPI_CHECK(MPI_Init(NULL, NULL));
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
//calculating localRank based on hostname which is used in selecting a GPU
uint64_t hostHashs[mpi_world_size];
char hostname[1024];
getHostName(hostname, 1024);
hostHashs[mpi_world_rank] = getHostHash(hostname);
MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
mpi_local_rank = 0;
for (int p=0; p<mpi_world_size; p++) {
if (p == mpi_world_rank) break;
if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
}
}
~mpi_initer() {
MPI_CHECK(MPI_Finalize());
}
};
static mpi_initer mpi_init;
} // jittor
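
The local-rank logic above hashes each hostname, allgathers the hashes, and counts how many earlier ranks share this rank's hash. A small Python sketch of that counting step, with made-up hash values purely for illustration:

# Suppose 4 ranks run on two hosts: ranks 0,1 on "node-a" and ranks 2,3 on "node-b".
host_hashes = [hash("node-a"), hash("node-a"), hash("node-b"), hash("node-b")]

def local_rank(rank, hashes):
    # Count earlier ranks whose host hash matches this rank's hash.
    return sum(1 for p in range(rank) if hashes[p] == hashes[rank])

assert [local_rank(r, host_hashes) for r in range(4)] == [0, 1, 0, 1]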

@@ -265,7 +265,7 @@ def setup_nccl():
global nccl_ops, use_nccl
use_nccl = os.environ.get("use_nccl", "1")=="1"
nccl_ops = None
if not has_cuda:
if not has_cuda or mpi is None:
use_nccl = False
return
if not use_nccl: return
@@ -293,14 +293,82 @@ def setup_nccl():
# We do not link manually; linking happens in custom ops
ctypes.CDLL(nccl_lib_name, dlopen_flags)
nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
nccl_ops = compile_custom_ops(nccl_op_files,
extra_flags=f" -I'{nccl_include_path}'")
nccl_src_dir = os.path.join(jittor_path, "extern", "cuda", "nccl")
nccl_src_files = []
for r, _, f in os.walk(nccl_src_dir):
for fname in f:
nccl_src_files.append(os.path.join(r, fname))
nccl_ops = compile_custom_ops(nccl_src_files,
extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
def manual_link(flags):
lib_dirs = []
libs = []
for f in flags.split():
if f.startswith("-l"):
libs.append(f[2:])
elif f.startswith("-L"):
lib_dirs.append(f[2:])
LOG.v("manual_link:", flags)
LOG.v("lib_dirs:", lib_dirs)
LOG.v("libs:", libs)
for lib in libs:
for d in lib_dirs:
libname = os.path.join(d, f"lib{lib}.so")
if os.path.isfile(libname):
LOG.v("link:", libname)
ctypes.CDLL(libname, dlopen_flags)
break
def setup_mpi():
global mpi_ops, mpi, use_mpi
global mpicc_path, has_mpi
use_mpi = os.environ.get("use_mpi", "1")=="1"
mpi_ops = None
mpi = None
has_mpi = False
if not use_mpi: return
mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
if mpicc_path == "":
LOG.i("mpicc not found, distribution disabled.")
use_mpi = False
else:
use_mpi = True
has_mpi = True
if not use_mpi:
return
global mpi_compile_flags, mpi_link_flags, mpi_flags
mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
mpi_flags = mpi_compile_flags + " " + mpi_link_flags
LOG.i("mpi_flags: "+mpi_flags)
manual_link(mpi_flags)
# find all source files
mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
mpi_src_files = []
for r, _, f in os.walk(mpi_src_dir):
for fname in f:
mpi_src_files.append(os.path.join(r, fname))
# extend mpi compile flags with the local include dir (reused later by the nccl build)
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
mpi = compile_custom_ops(mpi_src_files,
extra_flags=f" {mpi_flags} ", return_module=True)
mpi_ops = mpi.ops
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
setup_mpi()
setup_nccl()
setup_cutt()
setup_mkl()
setup_nccl()
setup_cuda_extern()
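
manual_link() above only parses -L/-l pairs out of the link line and dlopens the first matching lib<name>.so it finds. A hedged illustration (the path is hypothetical; the real line comes from "mpicc --showme:link" and varies by install):

manual_link("-L/usr/local/openmpi/lib -lmpi")
# -> scans /usr/local/openmpi/lib for libmpi.so and, if present, loads it via
#    ctypes.CDLL(libname, dlopen_flags) so MPI symbols are visible to the custom-op libraries.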

@@ -554,22 +554,31 @@ def compile_custom_op(header, source, op_name, warp=True):
m = compile_custom_ops([hname, ccname])
return getattr(m, op_name)
def compile_custom_ops(filenames, extra_flags=""):
def compile_custom_ops(filenames, extra_flags="", return_module=False):
"""Compile custom ops
filenames: path of op source files, filenames must be
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
type name of op must be XxxXxxOp.
extra_flags: extra compile flags
return_module: return module rather than ops (default: False)
return: compiled ops
"""
srcs = {}
headers = {}
builds = []
includes = []
pyjt_includes = []
for name in filenames:
name = os.path.realpath(name)
if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
builds.append(name)
if name.endswith(".h"):
dirname = os.path.dirname(name)
if dirname.endswith("inc"):
includes.append(dirname)
with open(name, "r") as f:
if "@pyjt" in f.read():
pyjt_includes.append(name)
bname = os.path.basename(name)
bname = os.path.splitext(bname)[0]
if bname.endswith("_op"):
@@ -597,14 +606,33 @@ def compile_custom_ops(filenames, extra_flags=""):
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
with open(gen_head_fname, "w") as f:
f.write(gen_src)
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
# gen src initialize first
builds.insert(0, gen_src_fname)
def insert_anchor(gen_src, anchor_str, insert_str):
# insert insert_str after anchor_str into gen_src
return gen_src.replace(anchor_str, anchor_str+insert_str, 1)
for name in pyjt_includes:
LOG.i("handle pyjt_include", name)
bname = name.split("/")[-1].split(".")[0]
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc")
pyjt_compiler.compile_single(name, gen_src_fname)
builds.insert(1, gen_src_fname)
gen_src = insert_anchor(gen_src,
"namespace jittor {",
f"extern void pyjt_def_{bname}(PyObject* m);")
gen_src = insert_anchor(gen_src,
"init_module(PyModuleDef* mdef, PyObject* m) {",
f"jittor::pyjt_def_{bname}(m);")
with open(gen_head_fname, "w") as f:
f.write(gen_src)
LOG.vvv(f"Build custum ops lib:{gen_lib}")
LOG.vvvv(f"Build sources:{builds}")
compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)
compile(cc_path, extra_flags+cc_flags+opt_flags+includes, builds, gen_lib)
# add python path and import
LOG.vvv(f"Import custum ops lib:{gen_lib}")
@@ -613,7 +641,10 @@
os.sys.path.append(lib_path)
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
exec(f"import {gen_name}")
return (locals()[gen_name]).ops
mod = locals()[gen_name]
if return_module:
return mod
return mod.ops
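
insert_anchor() above is a plain first-occurrence string splice; a quick sketch of what it produces for a hypothetical @pyjt header named foo.h (the helper is restated standalone here because it is nested inside compile_custom_ops):

def insert_anchor(gen_src, anchor_str, insert_str):
    # same one-shot splice as the helper above
    return gen_src.replace(anchor_str, anchor_str + insert_str, 1)

src = "namespace jittor {\n// generated op bindings\n"
patched = insert_anchor(src, "namespace jittor {",
                        "extern void pyjt_def_foo(PyObject* m);")
# patched == 'namespace jittor {extern void pyjt_def_foo(PyObject* m);\n// generated op bindings\n'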
def get_full_path_of_executable(name):
@@ -781,12 +812,6 @@ if not os.path.isfile(py3_config_path) :
assert os.path.isfile(py3_config_path)
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
if 'mpi_path' in os.environ:
mpi_path = os.environ['mpi_path']
else:
mpi_path = '/usr/local/openmpi'
assert os.path.isfile(os.path.join(mpi_path,"include","mpi.h"))
assert os.path.isfile(os.path.join(mpi_path,"lib","libmpi.so"))
gdb_path = try_find_exe('gdb')
addr2line_path = try_find_exe('addr2line')
has_pybt = check_pybt(gdb_path, python_path)
@@ -930,7 +955,6 @@ flags.cc_path = cc_path
flags.cc_type = cc_type
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
flags.nvcc_path = nvcc_path
flags.mpi_path = mpi_path
flags.nvcc_flags = nvcc_flags
flags.python_path = python_path
flags.cache_path = cache_path

@@ -7,58 +7,29 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os
import os, sys
import jittor as jt
import numpy as np
import copy
def find_cache_path():
from pathlib import Path
path = str(Path.home())
dirs = [".cache", "jittor"]
for d in dirs:
path = os.path.join(path, d)
if not os.path.isdir(path):
os.mkdir(path)
assert os.path.isdir(path)
return path
cache_path = find_cache_path()
def main():
print("test mpi_test")
assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
if jt.compile_extern.nccl_ops:
print("test test_with_mpi")
with jt.flag_scope(use_cuda=1):
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
class TestMpi(unittest.TestCase):
def test(self):
# Modified from: https://mpitutorial.com/tutorials/mpi-hello-world/zh_cn/
content="""
#include <mpi.h>
#include <stdio.h>
int main(int argc, char** argv) {
MPI_Init(NULL, NULL);
int world_size;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
int world_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
char processor_name[MPI_MAX_PROCESSOR_NAME];
int name_len;
MPI_Get_processor_name(processor_name, &name_len);
printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size);
MPI_Finalize();
}
"""
test_path=os.path.join(cache_path,"test_mpi.cc")
f=open(test_path,"w")
f.write(content)
f.close()
mpi_path=jt.flags.mpi_path
mpi_include = os.path.join(mpi_path, "include")
mpi_lib = os.path.join(mpi_path, "lib")
cmd = f"cd {cache_path} && g++ {test_path} -I {mpi_include} -L {mpi_lib} -lmpi -o test_mpi && mpirun -n 4 ./test_mpi"
self.assertEqual(os.system(cmd), 0)
mpi = jt.compile_extern.mpi
if mpi.world_size() == 1:
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
print("run cmd", cmd)
jt.compiler.run_cmd(cmd)
else:
main()
if __name__ == "__main__":
unittest.main()

@@ -178,7 +178,10 @@ def find_cache_path():
return path
def get_version(output):
version = run_cmd(output+" --version")
if output.endswith("mpicc"):
version = run_cmd(output+" --showme:version")
else:
version = run_cmd(output+" --version")
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
if len(v) == 0:
v = re.findall("[0-9]+\\.[0-9]+", version)
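
For mpicc, the version string now comes from "--showme:version" rather than "--version", presumably because the latter reports the wrapped compiler instead of Open MPI itself. A hedged sketch of the extraction; the sample output line is only illustrative and real output varies by Open MPI build:

import re

version = "mpicc: Open MPI 2.1.1 (Language: C)"   # hypothetical --showme:version output
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
assert v == ["2.1.1"]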

@@ -20,7 +20,6 @@ DEFINE_FLAG(string, jittor_path, "", "Source path of jittor");
DEFINE_FLAG(string, cc_path, "", "Path of C++ compiler");
DEFINE_FLAG(string, cc_type, "", "Type of C++ compiler(clang, icc, g++)");
DEFINE_FLAG(string, cc_flags, "", "Flags of C++ compiler");
DEFINE_FLAG(string, mpi_path, "", "Path of mpi dir");
DEFINE_FLAG(string, nvcc_path, "", "Path of CUDA C++ compiler");
DEFINE_FLAG(string, nvcc_flags, "", "Flags of CUDA C++ compiler");
DEFINE_FLAG(string, python_path, "", "Path of python interpreter");