This commit is contained in:
Dun Liang 2020-04-03 13:59:21 +08:00
commit 45b8375e80
9 changed files with 293 additions and 10 deletions

View File

@ -1,5 +1,6 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2020 // Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>. // Dun Liang <randonlang@gmail.com>.
// All Rights Reserved. // All Rights Reserved.
// This file is subject to the terms and conditions defined in // This file is subject to the terms and conditions defined in

35
extern/mpi/inc/mpi_warper.h vendored Normal file
View File

@ -0,0 +1,35 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#define OMPI_SKIP_MPICXX
#include <mpi.h>
extern void throw_mpi_error(int result,
char const *const func, const char *const file, int const line);
// Pass an MPI return code through: MPI_SUCCESS is a no-op, anything else is
// handed to throw_mpi_error together with the call-site description
// (func = text of the checked expression, file/line = where it was written).
static inline void mpi_check(int result,
    char const *const func, const char *const file, int const line) {
    if (result == MPI_SUCCESS) return;
    throw_mpi_error(result, func, file, line);
}
#define MPI_CHECK(val) mpi_check((val), #val, __FILE__, __LINE__)
namespace jittor {
// Number of processes in MPI_COMM_WORLD; defined in mpi_warper.cc, where it is filled in by MPI_Comm_size.
extern int mpi_world_size;
// Rank of this process in MPI_COMM_WORLD; defined in mpi_warper.cc, where it is filled in by MPI_Comm_rank.
extern int mpi_world_rank;
// @pyjt(world_size)
int _mpi_world_size();
// @pyjt(world_rank)
int _mpi_world_rank();
} // jittor

42
extern/mpi/ops/mpi_test_op.cc vendored Normal file
View File

@ -0,0 +1,42 @@
// ***************************************************************
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"
#include "var.h"
#include "mpi_test_op.h"
#include "misc/str_utils.h"
namespace jittor {
#ifndef JIT
// Construct the test op: store the command string in the `cmd` member and
// create the op's single output through create_output(1, ns_float32).
MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
output = create_output(1, ns_float32);
}
// Register the JIT define "T" as ns_float32; jit_run uses T as the element
// type when it writes into the output buffer.
void MpiTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
}
#else // JIT
// JIT-compiled body of the op: stores 123 in the first output element, then
// prints this process's processor name, rank and world size to stdout.
void MpiTestOp::jit_run() {
output->ptr<T>()[0] = 123;
// Snapshot of the globals declared in mpi_warper.h.
int world_size = mpi_world_size;
int world_rank = mpi_world_rank;
char processor_name[MPI_MAX_PROCESSOR_NAME];
int name_len;
// MPI_CHECK forwards a non-MPI_SUCCESS return code to throw_mpi_error.
MPI_CHECK(MPI_Get_processor_name(processor_name, &name_len));
// NOTE(review): the literal ends in a doubled backslash followed by 'n'; in plain C++ that prints
// a backslash and an 'n' rather than a newline. Confirm whether the JIT source pipeline unescapes it.
printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size);
}
#endif // JIT
} // jittor

23
extern/mpi/ops/mpi_test_op.h vendored Normal file
View File

@ -0,0 +1,23 @@
// ***************************************************************
// Copyright (c) 2020
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
// Small op used to exercise the MPI extension; name() reports it as "mpi_test".
struct MpiTestOp : Op {
// Output variable, assigned in the constructor.
Var* output;
// Command string received at construction and stored unchanged.
string cmd;
MpiTestOp(string cmd);
const char* name() const override { return "mpi_test"; }
// NOTE(review): presumably declares the JIT hooks (jit_prepare / jit_run) implemented in
// mpi_test_op.cc; the macro is defined outside this file, so confirm against its definition.
DECLARE_jit_run;
};
} // jittor

55
extern/mpi/src/mpi_warper.cc vendored Normal file
View File

@ -0,0 +1,55 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"
#include "common.h"
char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
void throw_mpi_error(int result,
char const *const func, const char *const file, int const line) {
int resultlen;
MPI_Error_string(result, jt_mpi_err_buffer, &resultlen);
fprintf(stderr, "MPI error at %s:%d code=%d(%s) \"%s\" \n",
file, line,
static_cast<unsigned int>(result), jt_mpi_err_buffer, func);
throw std::runtime_error("MPI error");
}
namespace jittor {
// World size and rank of the current process. Both start at zero and are
// filled in by mpi_initer's constructor through MPI_Comm_size / MPI_Comm_rank.
int mpi_world_size = 0;
int mpi_world_rank = 0;

// Accessor for mpi_world_size (declared in mpi_warper.h).
int _mpi_world_size() { return mpi_world_size; }

// Accessor for mpi_world_rank (declared in mpi_warper.h).
int _mpi_world_rank() { return mpi_world_rank; }
struct mpi_initer {
mpi_initer() {
MPI_CHECK(MPI_Init(NULL, NULL));
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
}
~mpi_initer() {
MPI_CHECK(MPI_Finalize());
}
};
static mpi_initer mpi_init;
} // jittor

View File

@ -296,11 +296,74 @@ def setup_nccl():
nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops") nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)] nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
nccl_ops = compile_custom_ops(nccl_op_files, nccl_ops = compile_custom_ops(nccl_op_files,
extra_flags=f" -I'{nccl_include_path}'") extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops))) LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
def manual_link(flags):
    """Load the shared libraries named by a string of linker flags.

    flags: whitespace-separated compiler/linker flags. Tokens starting with
        "-l" name libraries, tokens starting with "-L" name search
        directories; every other token is ignored.

    For each library, the first directory that contains ``lib<name>.so`` is
    used and the file is opened with ctypes.CDLL(..., dlopen_flags).
    Libraries not found in any listed directory are silently skipped.
    """
    tokens = flags.split()
    libs = [tok[2:] for tok in tokens if tok.startswith("-l")]
    lib_dirs = [tok[2:] for tok in tokens if tok.startswith("-L")]
    LOG.v("manual_link:", flags)
    LOG.v("lib_dirs:", lib_dirs)
    LOG.v("libs:", libs)
    for lib in libs:
        so_name = f"lib{lib}.so"
        for d in lib_dirs:
            libname = os.path.join(d, so_name)
            if not os.path.isfile(libname):
                continue
            LOG.v("link:", libname)
            ctypes.CDLL(libname, dlopen_flags)
            # only the first match per library is loaded
            break
def setup_mpi():
    """Detect MPI through mpicc and, when available, build the MPI extension.

    Module globals set on every path:
        use_mpi, has_mpi, mpi, mpi_ops,
        mpi_compile_flags, mpi_link_flags, mpi_flags
    MPI is skipped when the environment variable ``use_mpi`` is not "1" or
    when mpicc cannot be located; in that case ``mpi`` and ``mpi_ops`` stay
    None and the three flag strings stay empty.
    """
    global mpi_ops, mpi, use_mpi
    global mpicc_path, has_mpi
    global mpi_compile_flags, mpi_link_flags, mpi_flags
    use_mpi = os.environ.get("use_mpi", "1")=="1"
    mpi_ops = None
    mpi = None
    has_mpi = False
    # Define the flag globals before any early return: setup_nccl() puts
    # mpi_compile_flags into its compile flags, so leaving it undefined when
    # MPI is disabled or missing would raise NameError there.
    mpi_compile_flags = ""
    mpi_link_flags = ""
    mpi_flags = ""
    if not use_mpi:
        return
    mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
    if mpicc_path == "":
        LOG.i("mpicc not found, distribution disabled.")
        use_mpi = False
        return
    has_mpi = True

    # Ask the mpicc wrapper for the underlying compile and link flags.
    mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
    mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
    mpi_flags = mpi_compile_flags + " " + mpi_link_flags
    LOG.i("mpi_flags: "+mpi_flags)
    # Load the MPI shared libraries before the extension is built and imported.
    manual_link(mpi_flags)

    # find all source files
    mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
    mpi_src_files = []
    for r, _, f in os.walk(mpi_src_dir):
        for fname in f:
            mpi_src_files.append(os.path.join(r, fname))

    # Add the MPI wrapper include dir to the compile flags; setup_nccl() reuses them.
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "

    mpi = compile_custom_ops(mpi_src_files,
        extra_flags=f" {mpi_flags} ", return_module=True)
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
setup_mpi()
setup_nccl()
setup_cutt() setup_cutt()
setup_mkl() setup_mkl()
setup_nccl()
setup_cuda_extern() setup_cuda_extern()

View File

@ -554,22 +554,31 @@ def compile_custom_op(header, source, op_name, warp=True):
m = compile_custom_ops([hname, ccname]) m = compile_custom_ops([hname, ccname])
return getattr(m, op_name) return getattr(m, op_name)
def compile_custom_ops(filenames, extra_flags=""): def compile_custom_ops(filenames, extra_flags="", return_module=False):
"""Compile custom ops """Compile custom ops
filenames: path of op source files, filenames must be filenames: path of op source files, filenames must be
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
type name of op must be XxxXxxOp. type name of op must be XxxXxxOp.
extra_flags: extra compile flags extra_flags: extra compile flags
return_module: return module rather than ops(default: False)
return: compiled ops return: compiled ops
""" """
srcs = {} srcs = {}
headers = {} headers = {}
builds = [] builds = []
includes = [] includes = []
pyjt_includes = []
for name in filenames: for name in filenames:
name = os.path.realpath(name) name = os.path.realpath(name)
if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"): if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
builds.append(name) builds.append(name)
if name.endswith(".h"):
dirname = os.path.dirname(name)
if dirname.endswith("inc"):
includes.append(dirname)
with open(name, "r") as f:
if "@pyjt" in f.read():
pyjt_includes.append(name)
bname = os.path.basename(name) bname = os.path.basename(name)
bname = os.path.splitext(bname)[0] bname = os.path.splitext(bname)[0]
if bname.endswith("_op"): if bname.endswith("_op"):
@ -597,14 +606,33 @@ def compile_custom_ops(filenames, extra_flags=""):
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc") gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h") gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix) gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
with open(gen_head_fname, "w") as f: pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
f.write(gen_src)
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
# gen src initialize first # gen src initialize first
builds.insert(0, gen_src_fname) builds.insert(0, gen_src_fname)
def insert_anchor(gen_src, anchor_str, insert_str):
# insert insert_str after anchor_str into gen_src
return gen_src.replace(anchor_str, anchor_str+insert_str, 1)
for name in pyjt_includes:
LOG.i("handle pyjt_include", name)
bname = name.split("/")[-1].split(".")[0]
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc")
pyjt_compiler.compile_single(name, gen_src_fname)
builds.insert(1, gen_src_fname)
gen_src = insert_anchor(gen_src,
"namespace jittor {",
f"extern void pyjt_def_{bname}(PyObject* m);")
gen_src = insert_anchor(gen_src,
"init_module(PyModuleDef* mdef, PyObject* m) {",
f"jittor::pyjt_def_{bname}(m);")
with open(gen_head_fname, "w") as f:
f.write(gen_src)
LOG.vvv(f"Build custum ops lib:{gen_lib}") LOG.vvv(f"Build custum ops lib:{gen_lib}")
LOG.vvvv(f"Build sources:{builds}") LOG.vvvv(f"Build sources:{builds}")
compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib) compile(cc_path, extra_flags+cc_flags+opt_flags+includes, builds, gen_lib)
# add python path and import # add python path and import
LOG.vvv(f"Import custum ops lib:{gen_lib}") LOG.vvv(f"Import custum ops lib:{gen_lib}")
@ -613,7 +641,10 @@ def compile_custom_ops(filenames, extra_flags=""):
os.sys.path.append(lib_path) os.sys.path.append(lib_path)
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND): with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
exec(f"import {gen_name}") exec(f"import {gen_name}")
return (locals()[gen_name]).ops mod = locals()[gen_name]
if return_module:
return mod
return mod.ops
def get_full_path_of_executable(name): def get_full_path_of_executable(name):

View File

@ -0,0 +1,30 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
import numpy as np
def main():
    """Run the mpi_test op once and check that its output equals 123."""
    out = jt.compile_extern.mpi_ops.mpi_test("").data
    # The comparison used to be evaluated and thrown away, so a wrong value
    # could never fail the test; assert it explicitly.
    assert np.all(out == 123), out
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
class TestMpi(unittest.TestCase):
    """MPI smoke test; skipped when the MPI ops were not built."""

    def test(self):
        # With more than one process we are already running under mpirun:
        # execute the check directly.
        if jt.compile_extern.mpi.world_size() != 1:
            main()
            return
        # Single process: launch this module again under mpirun with 2 ranks.
        mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
        cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
        print("run cmd", cmd)
        jt.compiler.run_cmd(cmd)
if __name__ == "__main__":
unittest.main()

View File

@ -178,7 +178,10 @@ def find_cache_path():
return path return path
def get_version(output): def get_version(output):
version = run_cmd(output+" --version") if output.endswith("mpicc"):
version = run_cmd(output+" --showme:version")
else:
version = run_cmd(output+" --version")
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version) v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
if len(v) == 0: if len(v) == 0:
v = re.findall("[0-9]+\\.[0-9]+", version) v = re.findall("[0-9]+\\.[0-9]+", version)