mirror of https://github.com/Jittor/Jittor
Merge branch 'mpi_init' of https://github.com/Jittor/jittor
This commit is contained in:
commit
45b8375e80
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#define OMPI_SKIP_MPICXX
|
||||
#include <mpi.h>
|
||||
|
||||
extern void throw_mpi_error(int result,
|
||||
char const *const func, const char *const file, int const line);
|
||||
|
||||
static inline void mpi_check(int result,
|
||||
char const *const func, const char *const file, int const line) {
|
||||
if (result != MPI_SUCCESS) {
|
||||
throw_mpi_error(result, func, file, line);
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate an MPI call and pass its return code to mpi_check along with the
// stringified expression, the source file and the line number of the call.
#define MPI_CHECK(val) mpi_check((val), #val, __FILE__, __LINE__)
|
||||
|
||||
namespace jittor {
|
||||
|
||||
extern int mpi_world_size;
|
||||
extern int mpi_world_rank;
|
||||
|
||||
// @pyjt(world_size)
|
||||
int _mpi_world_size();
|
||||
|
||||
// @pyjt(world_rank)
|
||||
int _mpi_world_rank();
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,42 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mpi_warper.h"
|
||||
|
||||
#include "var.h"
|
||||
#include "mpi_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
|
||||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void MpiTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
||||
void MpiTestOp::jit_run() {
|
||||
output->ptr<T>()[0] = 123;
|
||||
|
||||
int world_size = mpi_world_size;
|
||||
|
||||
int world_rank = mpi_world_rank;
|
||||
|
||||
char processor_name[MPI_MAX_PROCESSOR_NAME];
|
||||
int name_len;
|
||||
MPI_CHECK(MPI_Get_processor_name(processor_name, &name_len));
|
||||
|
||||
printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size);
|
||||
|
||||
}
|
||||
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,23 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct MpiTestOp : Op {
|
||||
Var* output;
|
||||
string cmd;
|
||||
|
||||
MpiTestOp(string cmd);
|
||||
|
||||
const char* name() const override { return "mpi_test"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,55 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mpi_warper.h"
|
||||
#include "common.h"
|
||||
|
||||
char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
|
||||
|
||||
void throw_mpi_error(int result,
|
||||
char const *const func, const char *const file, int const line) {
|
||||
int resultlen;
|
||||
MPI_Error_string(result, jt_mpi_err_buffer, &resultlen);
|
||||
fprintf(stderr, "MPI error at %s:%d code=%d(%s) \"%s\" \n",
|
||||
file, line,
|
||||
static_cast<unsigned int>(result), jt_mpi_err_buffer, func);
|
||||
throw std::runtime_error("MPI error");
|
||||
}
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
||||
int mpi_world_size;
|
||||
int mpi_world_rank;
|
||||
|
||||
int _mpi_world_size() {
|
||||
return mpi_world_size;
|
||||
}
|
||||
|
||||
int _mpi_world_rank() {
|
||||
return mpi_world_rank;
|
||||
}
|
||||
|
||||
|
||||
struct mpi_initer {
|
||||
|
||||
mpi_initer() {
|
||||
MPI_CHECK(MPI_Init(NULL, NULL));
|
||||
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
|
||||
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
|
||||
}
|
||||
|
||||
~mpi_initer() {
|
||||
MPI_CHECK(MPI_Finalize());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
static mpi_initer mpi_init;
|
||||
|
||||
} // jittor
|
|
@ -296,11 +296,74 @@ def setup_nccl():
|
|||
nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
|
||||
nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
|
||||
nccl_ops = compile_custom_ops(nccl_op_files,
|
||||
extra_flags=f" -I'{nccl_include_path}'")
|
||||
extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
|
||||
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
||||
|
||||
def manual_link(flags):
    """Load the shared libraries named by linker flags.

    flags: whitespace-separated flag string. Tokens starting with ``-l``
        name libraries, tokens starting with ``-L`` name search
        directories; every other token is ignored.

    For each library, the first ``lib<name>.so`` file found in the search
    directories (in the order given) is loaded with ``ctypes.CDLL`` using
    the module-level ``dlopen_flags``. Libraries that are not found in any
    directory are skipped.
    """
    tokens = flags.split()
    libs = [tok[2:] for tok in tokens if tok.startswith("-l")]
    lib_dirs = [tok[2:] for tok in tokens if tok.startswith("-L")]
    LOG.v("manual_link:", flags)
    LOG.v("lib_dirs:", lib_dirs)
    LOG.v("libs:", libs)
    for lib in libs:
        so_name = f"lib{lib}.so"
        candidates = (os.path.join(d, so_name) for d in lib_dirs)
        # first existing candidate wins; later directories are not examined
        found = next((p for p in candidates if os.path.isfile(p)), None)
        if found is None:
            continue
        LOG.v("link:", found)
        ctypes.CDLL(found, dlopen_flags)
|
||||
|
||||
|
||||
def setup_mpi():
    """Detect MPI, then compile and import the bundled MPI extension.

    Sets these module globals:
      use_mpi  -- True when the ``use_mpi`` environment variable is "1"
                  (the default) and ``mpicc`` was located.
      has_mpi  -- True when ``mpicc`` was located.
      mpi, mpi_ops -- the compiled module and its ``ops`` attribute, or
                  None when MPI is disabled or unavailable.
      mpi_compile_flags, mpi_link_flags, mpi_flags -- flags reported by
                  ``mpicc``; empty strings when MPI is disabled or
                  unavailable.
      mpicc_path -- result of looking up ``mpicc`` (only assigned when the
                  ``use_mpi`` switch is on).
    """
    global mpi_ops, mpi, use_mpi
    global mpicc_path, has_mpi
    global mpi_compile_flags, mpi_link_flags, mpi_flags
    use_mpi = os.environ.get("use_mpi", "1")=="1"
    mpi_ops = None
    mpi = None
    has_mpi = False
    # setup_nccl interpolates mpi_compile_flags unconditionally, so the flag
    # globals must exist even when this function returns early.
    mpi_compile_flags = mpi_link_flags = mpi_flags = ""
    if not use_mpi:
        return
    mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
    if mpicc_path == "":
        LOG.i("mpicc not found, distribution disabled.")
        use_mpi = False
        return
    has_mpi = True

    mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
    mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
    mpi_flags = mpi_compile_flags + " " + mpi_link_flags
    LOG.i("mpi_flags: "+mpi_flags)
    manual_link(mpi_flags)

    # find all source files
    mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
    mpi_src_files = [
        os.path.join(root, fname)
        for root, _, fnames in os.walk(mpi_src_dir)
        for fname in fnames
    ]

    # mpi compile flags add for nccl
    # (appended after mpi_flags was composed, so mpi_flags itself does not
    # carry this include directory)
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "

    mpi = compile_custom_ops(mpi_src_files,
        extra_flags=f" {mpi_flags} ", return_module=True)
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
|
||||
|
||||
setup_mpi()
|
||||
setup_nccl()
|
||||
|
||||
setup_cutt()
|
||||
setup_mkl()
|
||||
setup_nccl()
|
||||
|
||||
setup_cuda_extern()
|
||||
|
|
|
@ -554,22 +554,31 @@ def compile_custom_op(header, source, op_name, warp=True):
|
|||
m = compile_custom_ops([hname, ccname])
|
||||
return getattr(m, op_name)
|
||||
|
||||
def compile_custom_ops(filenames, extra_flags=""):
|
||||
def compile_custom_ops(filenames, extra_flags="", return_module=False):
|
||||
"""Compile custom ops
|
||||
filenames: path of op source files, filenames must be
|
||||
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
|
||||
type name of op must be XxxXxxOp.
|
||||
extra_flags: extra compile flags
|
||||
return_module: return module rather than ops(default: False)
|
||||
return: compiled ops
|
||||
"""
|
||||
srcs = {}
|
||||
headers = {}
|
||||
builds = []
|
||||
includes = []
|
||||
pyjt_includes = []
|
||||
for name in filenames:
|
||||
name = os.path.realpath(name)
|
||||
if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
|
||||
builds.append(name)
|
||||
if name.endswith(".h"):
|
||||
dirname = os.path.dirname(name)
|
||||
if dirname.endswith("inc"):
|
||||
includes.append(dirname)
|
||||
with open(name, "r") as f:
|
||||
if "@pyjt" in f.read():
|
||||
pyjt_includes.append(name)
|
||||
bname = os.path.basename(name)
|
||||
bname = os.path.splitext(bname)[0]
|
||||
if bname.endswith("_op"):
|
||||
|
@ -597,14 +606,33 @@ def compile_custom_ops(filenames, extra_flags=""):
|
|||
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
|
||||
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
|
||||
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
|
||||
with open(gen_head_fname, "w") as f:
|
||||
f.write(gen_src)
|
||||
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
|
||||
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
|
||||
# gen src initialize first
|
||||
builds.insert(0, gen_src_fname)
|
||||
|
||||
def insert_anchor(gen_src, anchor_str, insert_str):
|
||||
# insert insert_str after anchor_str into gen_src
|
||||
return gen_src.replace(anchor_str, anchor_str+insert_str, 1)
|
||||
|
||||
for name in pyjt_includes:
|
||||
LOG.i("handle pyjt_include", name)
|
||||
bname = name.split("/")[-1].split(".")[0]
|
||||
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc")
|
||||
pyjt_compiler.compile_single(name, gen_src_fname)
|
||||
builds.insert(1, gen_src_fname)
|
||||
gen_src = insert_anchor(gen_src,
|
||||
"namespace jittor {",
|
||||
f"extern void pyjt_def_{bname}(PyObject* m);")
|
||||
gen_src = insert_anchor(gen_src,
|
||||
"init_module(PyModuleDef* mdef, PyObject* m) {",
|
||||
f"jittor::pyjt_def_{bname}(m);")
|
||||
|
||||
with open(gen_head_fname, "w") as f:
|
||||
f.write(gen_src)
|
||||
|
||||
LOG.vvv(f"Build custum ops lib:{gen_lib}")
|
||||
LOG.vvvv(f"Build sources:{builds}")
|
||||
compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)
|
||||
compile(cc_path, extra_flags+cc_flags+opt_flags+includes, builds, gen_lib)
|
||||
|
||||
# add python path and import
|
||||
LOG.vvv(f"Import custum ops lib:{gen_lib}")
|
||||
|
@ -613,7 +641,10 @@ def compile_custom_ops(filenames, extra_flags=""):
|
|||
os.sys.path.append(lib_path)
|
||||
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
|
||||
exec(f"import {gen_name}")
|
||||
return (locals()[gen_name]).ops
|
||||
mod = locals()[gen_name]
|
||||
if return_module:
|
||||
return mod
|
||||
return mod.ops
|
||||
|
||||
|
||||
def get_full_path_of_executable(name):
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import os, sys
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
def main():
    """Run the ``mpi_test`` op and check that it yields the value 123.

    Raises AssertionError when the op output differs; previously the
    comparison result was computed and thrown away, so a wrong value
    could never fail the test.
    """
    result = jt.compile_extern.mpi_ops.mpi_test("").data
    assert np.all(result == 123), f"mpi_test returned {result}, expected 123"
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
class TestMpi(unittest.TestCase):
    """MPI smoke test; skipped when the MPI ops were not built."""

    def test(self):
        """Run main() directly under MPI, otherwise relaunch via mpirun.

        When the reported world size is not 1 the check runs in-process.
        With a world size of 1 this module is started again as two
        processes through ``mpirun -np 2``.
        """
        if jt.compile_extern.mpi.world_size() != 1:
            main()
            return
        launcher = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
        cmd = f"{launcher} -np 2 {sys.executable} -m jittor.test.test_mpi"
        print("run cmd", cmd)
        jt.compiler.run_cmd(cmd)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -178,7 +178,10 @@ def find_cache_path():
|
|||
return path
|
||||
|
||||
def get_version(output):
|
||||
version = run_cmd(output+" --version")
|
||||
if output.endswith("mpicc"):
|
||||
version = run_cmd(output+" --showme:version")
|
||||
else:
|
||||
version = run_cmd(output+" --version")
|
||||
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
||||
if len(v) == 0:
|
||||
v = re.findall("[0-9]+\\.[0-9]+", version)
|
||||
|
|
Loading…
Reference in New Issue