mirror of https://github.com/Jittor/Jittor
Merge branch 'mpi_init' of https://github.com/Jittor/jittor
This commit is contained in:
commit
45b8375e80
|
@ -1,5 +1,6 @@
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
// Copyright (c) 2020
|
// Copyright (c) 2020 Jittor.
|
||||||
|
// Authors:
|
||||||
// Dun Liang <randonlang@gmail.com>.
|
// Dun Liang <randonlang@gmail.com>.
|
||||||
// All Rights Reserved.
|
// All Rights Reserved.
|
||||||
// This file is subject to the terms and conditions defined in
|
// This file is subject to the terms and conditions defined in
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2020 Jittor.
|
||||||
|
// Authors:
|
||||||
|
// Dun Liang <randonlang@gmail.com>.
|
||||||
|
// All Rights Reserved.
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#define OMPI_SKIP_MPICXX
|
||||||
|
#include <mpi.h>
|
||||||
|
|
||||||
|
extern void throw_mpi_error(int result,
|
||||||
|
char const *const func, const char *const file, int const line);
|
||||||
|
|
||||||
|
static inline void mpi_check(int result,
|
||||||
|
char const *const func, const char *const file, int const line) {
|
||||||
|
if (result != MPI_SUCCESS) {
|
||||||
|
throw_mpi_error(result, func, file, line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define MPI_CHECK(val) mpi_check((val), #val, __FILE__, __LINE__)
|
||||||
|
|
||||||
|
namespace jittor {

// Process-wide MPI state; the definitions live in the matching source file.
extern int mpi_world_size;
extern int mpi_world_rank;

// The "@pyjt(...)" marker comments below are kept verbatim: the build script
// looks for "@pyjt" in headers to decide which ones go through the pyjt
// compiler (presumably exporting the function under the given name).

// @pyjt(world_size)
int _mpi_world_size();

// @pyjt(world_rank)
int _mpi_world_rank();

} // jittor
|
|
@ -0,0 +1,42 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#include "mpi_warper.h"
|
||||||
|
|
||||||
|
#include "var.h"
|
||||||
|
#include "mpi_test_op.h"
|
||||||
|
#include "misc/str_utils.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
#ifndef JIT
|
||||||
|
MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
|
||||||
|
output = create_output(1, ns_float32);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MpiTestOp::jit_prepare() {
|
||||||
|
add_jit_define("T", ns_float32);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // JIT
|
||||||
|
|
||||||
|
void MpiTestOp::jit_run() {
|
||||||
|
output->ptr<T>()[0] = 123;
|
||||||
|
|
||||||
|
int world_size = mpi_world_size;
|
||||||
|
|
||||||
|
int world_rank = mpi_world_rank;
|
||||||
|
|
||||||
|
char processor_name[MPI_MAX_PROCESSOR_NAME];
|
||||||
|
int name_len;
|
||||||
|
MPI_CHECK(MPI_Get_processor_name(processor_name, &name_len));
|
||||||
|
|
||||||
|
printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // JIT
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -0,0 +1,23 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2020
|
||||||
|
// Dun Liang <randonlang@gmail.com>.
|
||||||
|
// All Rights Reserved.
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#pragma once
|
||||||
|
#include "op.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
struct MpiTestOp : Op {
|
||||||
|
Var* output;
|
||||||
|
string cmd;
|
||||||
|
|
||||||
|
MpiTestOp(string cmd);
|
||||||
|
|
||||||
|
const char* name() const override { return "mpi_test"; }
|
||||||
|
DECLARE_jit_run;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -0,0 +1,55 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2020 Jittor.
|
||||||
|
// Authors:
|
||||||
|
// Dun Liang <randonlang@gmail.com>.
|
||||||
|
// All Rights Reserved.
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#include "mpi_warper.h"
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
|
||||||
|
|
||||||
|
void throw_mpi_error(int result,
|
||||||
|
char const *const func, const char *const file, int const line) {
|
||||||
|
int resultlen;
|
||||||
|
MPI_Error_string(result, jt_mpi_err_buffer, &resultlen);
|
||||||
|
fprintf(stderr, "MPI error at %s:%d code=%d(%s) \"%s\" \n",
|
||||||
|
file, line,
|
||||||
|
static_cast<unsigned int>(result), jt_mpi_err_buffer, func);
|
||||||
|
throw std::runtime_error("MPI error");
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
|
||||||
|
int mpi_world_size;
|
||||||
|
int mpi_world_rank;
|
||||||
|
|
||||||
|
int _mpi_world_size() {
|
||||||
|
return mpi_world_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
int _mpi_world_rank() {
|
||||||
|
return mpi_world_rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
struct mpi_initer {
|
||||||
|
|
||||||
|
mpi_initer() {
|
||||||
|
MPI_CHECK(MPI_Init(NULL, NULL));
|
||||||
|
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
|
||||||
|
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
|
||||||
|
}
|
||||||
|
|
||||||
|
~mpi_initer() {
|
||||||
|
MPI_CHECK(MPI_Finalize());
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
static mpi_initer mpi_init;
|
||||||
|
|
||||||
|
} // jittor
|
|
@ -296,11 +296,74 @@ def setup_nccl():
|
||||||
nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
|
nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
|
||||||
nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
|
nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
|
||||||
nccl_ops = compile_custom_ops(nccl_op_files,
|
nccl_ops = compile_custom_ops(nccl_op_files,
|
||||||
extra_flags=f" -I'{nccl_include_path}'")
|
extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
|
||||||
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
||||||
|
|
||||||
|
def manual_link(flags):
    """Load the shared libraries named by linker flags via ``ctypes``.

    ``flags`` is split on whitespace. Tokens starting with ``-l`` name
    libraries and tokens starting with ``-L`` name search directories; all
    other tokens are ignored. For each library, the first directory that
    contains ``lib<name>.so`` is used and that file is opened with
    ``ctypes.CDLL`` using the module-level ``dlopen_flags``. Libraries not
    found in any directory are silently skipped.
    """
    tokens = flags.split()
    libs = [tok[2:] for tok in tokens if tok.startswith("-l")]
    lib_dirs = [tok[2:] for tok in tokens if tok.startswith("-L")]
    LOG.v("manual_link:", flags)
    LOG.v("lib_dirs:", lib_dirs)
    LOG.v("libs:", libs)
    for lib in libs:
        candidates = (os.path.join(d, f"lib{lib}.so") for d in lib_dirs)
        # First existing candidate wins, matching the directory order given.
        libname = next((p for p in candidates if os.path.isfile(p)), None)
        if libname is None:
            continue
        LOG.v("link:", libname)
        ctypes.CDLL(libname, dlopen_flags)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_mpi():
    """Detect MPI through ``mpicc`` and build the ops under ``extern/mpi``.

    Module globals written:
      use_mpi            -- True only if the ``use_mpi`` environment variable
                            is "1" (the default) and ``mpicc`` was found.
      has_mpi            -- True if ``mpicc`` was found.
      mpicc_path         -- result of the ``mpicc`` lookup (only assigned
                            when the environment variable enables MPI).
      mpi_compile_flags, mpi_link_flags, mpi_flags
                         -- flags reported by ``mpicc --showme``; always
                            defined, empty strings when MPI is unavailable.
      mpi, mpi_ops       -- compiled module and its ``ops`` attribute, or
                            None when MPI is unavailable.
    """
    global mpi_ops, mpi, use_mpi
    global mpicc_path, has_mpi
    global mpi_compile_flags, mpi_link_flags, mpi_flags
    use_mpi = os.environ.get("use_mpi", "1")=="1"
    mpi_ops = None
    mpi = None
    has_mpi = False
    # setup_nccl interpolates mpi_compile_flags unconditionally, so the flag
    # globals must exist even when MPI is disabled or not installed.
    mpi_compile_flags = ""
    mpi_link_flags = ""
    mpi_flags = ""
    if not use_mpi:
        return
    mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
    if mpicc_path == "":
        LOG.i("mpicc not found, distribution disabled.")
        use_mpi = False
        return
    has_mpi = True

    mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
    mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
    mpi_flags = mpi_compile_flags + " " + mpi_link_flags
    LOG.i("mpi_flags: "+mpi_flags)
    manual_link(mpi_flags)

    # find all source files
    mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
    mpi_src_files = [
        os.path.join(root, fname)
        for root, _, fnames in os.walk(mpi_src_dir)
        for fname in fnames
    ]

    # mpi compile flags add for nccl
    # (appended after mpi_flags was built, so mpi_flags itself is unchanged)
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "

    mpi = compile_custom_ops(mpi_src_files,
        extra_flags=f" {mpi_flags} ", return_module=True)
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
|
||||||
|
|
||||||
|
setup_mpi()
|
||||||
|
setup_nccl()
|
||||||
|
|
||||||
setup_cutt()
|
setup_cutt()
|
||||||
setup_mkl()
|
setup_mkl()
|
||||||
setup_nccl()
|
|
||||||
|
|
||||||
setup_cuda_extern()
|
setup_cuda_extern()
|
||||||
|
|
|
@ -554,22 +554,31 @@ def compile_custom_op(header, source, op_name, warp=True):
|
||||||
m = compile_custom_ops([hname, ccname])
|
m = compile_custom_ops([hname, ccname])
|
||||||
return getattr(m, op_name)
|
return getattr(m, op_name)
|
||||||
|
|
||||||
def compile_custom_ops(filenames, extra_flags=""):
|
def compile_custom_ops(filenames, extra_flags="", return_module=False):
|
||||||
"""Compile custom ops
|
"""Compile custom ops
|
||||||
filenames: path of op source files, filenames must be
|
filenames: path of op source files, filenames must be
|
||||||
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
|
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
|
||||||
type name of op must be XxxXxxOp.
|
type name of op must be XxxXxxOp.
|
||||||
extra_flags: extra compile flags
|
extra_flags: extra compile flags
|
||||||
|
return_module: return module rather than ops(default: False)
|
||||||
return: compiled ops
|
return: compiled ops
|
||||||
"""
|
"""
|
||||||
srcs = {}
|
srcs = {}
|
||||||
headers = {}
|
headers = {}
|
||||||
builds = []
|
builds = []
|
||||||
includes = []
|
includes = []
|
||||||
|
pyjt_includes = []
|
||||||
for name in filenames:
|
for name in filenames:
|
||||||
name = os.path.realpath(name)
|
name = os.path.realpath(name)
|
||||||
if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
|
if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
|
||||||
builds.append(name)
|
builds.append(name)
|
||||||
|
if name.endswith(".h"):
|
||||||
|
dirname = os.path.dirname(name)
|
||||||
|
if dirname.endswith("inc"):
|
||||||
|
includes.append(dirname)
|
||||||
|
with open(name, "r") as f:
|
||||||
|
if "@pyjt" in f.read():
|
||||||
|
pyjt_includes.append(name)
|
||||||
bname = os.path.basename(name)
|
bname = os.path.basename(name)
|
||||||
bname = os.path.splitext(bname)[0]
|
bname = os.path.splitext(bname)[0]
|
||||||
if bname.endswith("_op"):
|
if bname.endswith("_op"):
|
||||||
|
@ -597,14 +606,33 @@ def compile_custom_ops(filenames, extra_flags=""):
|
||||||
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
|
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
|
||||||
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
|
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
|
||||||
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
|
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
|
||||||
with open(gen_head_fname, "w") as f:
|
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
|
||||||
f.write(gen_src)
|
|
||||||
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
|
|
||||||
# gen src initialize first
|
# gen src initialize first
|
||||||
builds.insert(0, gen_src_fname)
|
builds.insert(0, gen_src_fname)
|
||||||
|
|
||||||
|
def insert_anchor(gen_src, anchor_str, insert_str):
|
||||||
|
# insert insert_str after anchor_str into gen_src
|
||||||
|
return gen_src.replace(anchor_str, anchor_str+insert_str, 1)
|
||||||
|
|
||||||
|
for name in pyjt_includes:
|
||||||
|
LOG.i("handle pyjt_include", name)
|
||||||
|
bname = name.split("/")[-1].split(".")[0]
|
||||||
|
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc")
|
||||||
|
pyjt_compiler.compile_single(name, gen_src_fname)
|
||||||
|
builds.insert(1, gen_src_fname)
|
||||||
|
gen_src = insert_anchor(gen_src,
|
||||||
|
"namespace jittor {",
|
||||||
|
f"extern void pyjt_def_{bname}(PyObject* m);")
|
||||||
|
gen_src = insert_anchor(gen_src,
|
||||||
|
"init_module(PyModuleDef* mdef, PyObject* m) {",
|
||||||
|
f"jittor::pyjt_def_{bname}(m);")
|
||||||
|
|
||||||
|
with open(gen_head_fname, "w") as f:
|
||||||
|
f.write(gen_src)
|
||||||
|
|
||||||
LOG.vvv(f"Build custum ops lib:{gen_lib}")
|
LOG.vvv(f"Build custum ops lib:{gen_lib}")
|
||||||
LOG.vvvv(f"Build sources:{builds}")
|
LOG.vvvv(f"Build sources:{builds}")
|
||||||
compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)
|
compile(cc_path, extra_flags+cc_flags+opt_flags+includes, builds, gen_lib)
|
||||||
|
|
||||||
# add python path and import
|
# add python path and import
|
||||||
LOG.vvv(f"Import custum ops lib:{gen_lib}")
|
LOG.vvv(f"Import custum ops lib:{gen_lib}")
|
||||||
|
@ -613,7 +641,10 @@ def compile_custom_ops(filenames, extra_flags=""):
|
||||||
os.sys.path.append(lib_path)
|
os.sys.path.append(lib_path)
|
||||||
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
|
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
|
||||||
exec(f"import {gen_name}")
|
exec(f"import {gen_name}")
|
||||||
return (locals()[gen_name]).ops
|
mod = locals()[gen_name]
|
||||||
|
if return_module:
|
||||||
|
return mod
|
||||||
|
return mod.ops
|
||||||
|
|
||||||
|
|
||||||
def get_full_path_of_executable(name):
|
def get_full_path_of_executable(name):
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
# ***************************************************************
|
||||||
|
# Copyright (c) 2020 Jittor. Authors:
|
||||||
|
# Guowei Yang <471184555@qq.com>
|
||||||
|
# Dun Liang <randonlang@gmail.com>.
|
||||||
|
# All Rights Reserved.
|
||||||
|
# This file is subject to the terms and conditions defined in
|
||||||
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
# ***************************************************************
|
||||||
|
import unittest
|
||||||
|
import os, sys
|
||||||
|
import jittor as jt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def main():
    """Run the ``mpi_test`` op and verify that its output equals 123.

    Raises:
        AssertionError: if the op's output differs from 123.
    """
    result = jt.compile_extern.mpi_ops.mpi_test("").data
    # The comparison used to be evaluated and thrown away, so a wrong value
    # could never fail the test.
    assert np.all(result == 123), f"mpi_test returned {result}, expected 123"
|
||||||
|
|
||||||
|
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
class TestMpi(unittest.TestCase):
    """MPI smoke test; skipped when the MPI ops are not available."""

    def test(self):
        """Run ``main`` under MPI, relaunching through ``mpirun`` if needed.

        With a world size other than 1 the check runs directly. With a world
        size of 1 this module is started again with ``mpirun -np 2``.
        """
        if jt.compile_extern.mpi.world_size() != 1:
            main()
            return
        mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
        cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
        print("run cmd", cmd)
        jt.compiler.run_cmd(cmd)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
|
@ -178,7 +178,10 @@ def find_cache_path():
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def get_version(output):
|
def get_version(output):
|
||||||
version = run_cmd(output+" --version")
|
if output.endswith("mpicc"):
|
||||||
|
version = run_cmd(output+" --showme:version")
|
||||||
|
else:
|
||||||
|
version = run_cmd(output+" --version")
|
||||||
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
||||||
if len(v) == 0:
|
if len(v) == 0:
|
||||||
v = re.findall("[0-9]+\\.[0-9]+", version)
|
v = re.findall("[0-9]+\\.[0-9]+", version)
|
||||||
|
|
Loading…
Reference in New Issue