diff --git a/extern/cuda/nccl/ops/nccl_test_op.h b/extern/cuda/nccl/ops/nccl_test_op.h index cb4b18bd..9f882aba 100644 --- a/extern/cuda/nccl/ops/nccl_test_op.h +++ b/extern/cuda/nccl/ops/nccl_test_op.h @@ -1,5 +1,6 @@ // *************************************************************** -// Copyright (c) 2020 +// Copyright (c) 2020 Jittor. +// Authors: // Dun Liang . // All Rights Reserved. // This file is subject to the terms and conditions defined in diff --git a/extern/mpi/inc/mpi_warper.h b/extern/mpi/inc/mpi_warper.h new file mode 100644 index 00000000..936cdd5b --- /dev/null +++ b/extern/mpi/inc/mpi_warper.h @@ -0,0 +1,35 @@ +// *************************************************************** +// Copyright (c) 2020 Jittor. +// Authors: +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#define OMPI_SKIP_MPICXX +#include <mpi.h> + +extern void throw_mpi_error(int result, + char const *const func, const char *const file, int const line); + +static inline void mpi_check(int result, + char const *const func, const char *const file, int const line) { + if (result != MPI_SUCCESS) { + throw_mpi_error(result, func, file, line); + } +} + +#define MPI_CHECK(val) mpi_check((val), #val, __FILE__, __LINE__) + +namespace jittor { + +extern int mpi_world_size; +extern int mpi_world_rank; + +// @pyjt(world_size) +int _mpi_world_size(); + +// @pyjt(world_rank) +int _mpi_world_rank(); + +} // jittor diff --git a/extern/mpi/ops/mpi_test_op.cc b/extern/mpi/ops/mpi_test_op.cc new file mode 100644 index 00000000..fe97d14b --- /dev/null +++ b/extern/mpi/ops/mpi_test_op.cc @@ -0,0 +1,42 @@ +// *************************************************************** +// Copyright (c) 2019 Dun Liang . All Rights Reserved. 
+// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mpi_warper.h" + +#include "var.h" +#include "mpi_test_op.h" +#include "misc/str_utils.h" + +namespace jittor { + +#ifndef JIT +MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) { + output = create_output(1, ns_float32); +} + +void MpiTestOp::jit_prepare() { + add_jit_define("T", ns_float32); +} + +#else // JIT + +void MpiTestOp::jit_run() { + output->ptr<T>()[0] = 123; + + int world_size = mpi_world_size; + + int world_rank = mpi_world_rank; + + char processor_name[MPI_MAX_PROCESSOR_NAME]; + int name_len; + MPI_CHECK(MPI_Get_processor_name(processor_name, &name_len)); + + printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size); + +} + +#endif // JIT + +} // jittor diff --git a/extern/mpi/ops/mpi_test_op.h b/extern/mpi/ops/mpi_test_op.h new file mode 100644 index 00000000..e366d84b --- /dev/null +++ b/extern/mpi/ops/mpi_test_op.h @@ -0,0 +1,23 @@ +// *************************************************************** +// Copyright (c) 2020 +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MpiTestOp : Op { + Var* output; + string cmd; + + MpiTestOp(string cmd); + + const char* name() const override { return "mpi_test"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/extern/mpi/src/mpi_warper.cc b/extern/mpi/src/mpi_warper.cc new file mode 100644 index 00000000..7dd3f00b --- /dev/null +++ b/extern/mpi/src/mpi_warper.cc @@ -0,0 +1,55 @@ +// *************************************************************** +// Copyright (c) 2020 Jittor. 
+// Authors: +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mpi_warper.h" +#include "common.h" + +char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING]; + +void throw_mpi_error(int result, + char const *const func, const char *const file, int const line) { + int resultlen; + MPI_Error_string(result, jt_mpi_err_buffer, &resultlen); + fprintf(stderr, "MPI error at %s:%d code=%d(%s) \"%s\" \n", + file, line, + static_cast<int>(result), jt_mpi_err_buffer, func); + throw std::runtime_error("MPI error"); +} + +namespace jittor { + + +int mpi_world_size; +int mpi_world_rank; + +int _mpi_world_size() { + return mpi_world_size; +} + +int _mpi_world_rank() { + return mpi_world_rank; +} + + +struct mpi_initer { + +mpi_initer() { + MPI_CHECK(MPI_Init(NULL, NULL)); + MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size)); + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank)); +} + +~mpi_initer() { + if (MPI_Finalize() != MPI_SUCCESS) fprintf(stderr, "MPI_Finalize failed\n"); // report only: throwing from a destructor would call std::terminate +} + +}; + +static mpi_initer mpi_init; + +} // jittor \ No newline at end of file diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 188d13dd..888260af 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -296,11 +296,74 @@ def setup_nccl(): nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops") nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)] nccl_ops = compile_custom_ops(nccl_op_files, - extra_flags=f" -I'{nccl_include_path}'") + extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ") LOG.vv("Get nccl_ops: "+str(dir(nccl_ops))) +def manual_link(flags): + lib_dirs = [] + libs = [] + for f in flags.split(): + if f.startswith("-l"): + libs.append(f[2:]) + elif f.startswith("-L"): + lib_dirs.append(f[2:]) + LOG.v("manual_link:", flags) 
+ LOG.v("lib_dirs:", lib_dirs) + LOG.v("libs:", libs) + for lib in libs: + for d in lib_dirs: + libname = os.path.join(d, f"lib{lib}.so") + if os.path.isfile(libname): + LOG.v("link:", libname) + ctypes.CDLL(libname, dlopen_flags) + break + + +def setup_mpi(): + global mpi_ops, mpi, use_mpi + global mpicc_path, has_mpi, mpi_compile_flags, mpi_link_flags, mpi_flags + use_mpi = os.environ.get("use_mpi", "1")=="1" + mpi_ops = None + mpi = None + has_mpi = False + mpi_compile_flags = mpi_link_flags = mpi_flags = "" # stay defined (empty) when MPI is disabled; setup_nccl() formats mpi_compile_flags into its flags + if not use_mpi: return + mpicc_path = env_or_try_find('mpicc_path', 'mpicc') + if mpicc_path == "": + LOG.i("mpicc not found, distribution disabled.") + use_mpi = False + else: + use_mpi = True + has_mpi = True + if not use_mpi: + return + + mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile") + mpi_link_flags = run_cmd(mpicc_path+" --showme:link") + mpi_flags = mpi_compile_flags + " " + mpi_link_flags + LOG.i("mpi_flags: "+mpi_flags) + manual_link(mpi_flags) + + # collect every file under extern/mpi + mpi_src_dir = os.path.join(jittor_path, "extern", "mpi") + mpi_src_files = [] + for r, _, f in os.walk(mpi_src_dir): + for fname in f: + mpi_src_files.append(os.path.join(r, fname)) + + # add the MPI "inc" dir to mpi_compile_flags, which setup_nccl() reuses + mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' " + + mpi = compile_custom_ops(mpi_src_files, + extra_flags=f" {mpi_flags} ", return_module=True) + mpi_ops = mpi.ops + LOG.vv("Get mpi: "+str(mpi.__dict__.keys())) + LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys())) + +setup_mpi() +setup_nccl() + setup_cutt() setup_mkl() -setup_nccl() setup_cuda_extern() diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index 481ed444..773102fc 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -554,22 +554,31 @@ def compile_custom_op(header, source, op_name, warp=True): m = compile_custom_ops([hname, ccname]) return getattr(m, op_name) -def compile_custom_ops(filenames, extra_flags=""): +def compile_custom_ops(filenames, extra_flags="", 
return_module=False): """Compile custom ops filenames: path of op source files, filenames must be pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the type name of op must be XxxXxxOp. extra_flags: extra compile flags + return_module: return the compiled module instead of its ops (default: False) return: compiled ops """ srcs = {} headers = {} builds = [] includes = [] + pyjt_includes = [] for name in filenames: name = os.path.realpath(name) if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"): builds.append(name) + if name.endswith(".h"): + dirname = os.path.dirname(name) + if dirname.endswith("inc"): + includes.append(dirname) + with open(name, "r") as f: + if "@pyjt" in f.read(): + pyjt_includes.append(name) bname = os.path.basename(name) bname = os.path.splitext(bname)[0] if bname.endswith("_op"): @@ -597,14 +606,33 @@ def compile_custom_ops(filenames, extra_flags=""): gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc") gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h") gen_lib = os.path.join("custom_ops", gen_name+extension_suffix) - with open(gen_head_fname, "w") as f: - f.write(gen_src) - pyjt_compiler.compile_single(gen_head_fname, gen_src_fname) + pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src) # gen src initialize first builds.insert(0, gen_src_fname) + + def insert_anchor(gen_src, anchor_str, insert_str): + # insert insert_str right after the first occurrence of anchor_str in gen_src + return gen_src.replace(anchor_str, anchor_str+insert_str, 1) + + for name in pyjt_includes: + LOG.i("handle pyjt_include", name) + bname = name.split("/")[-1].split(".")[0] + gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc") + pyjt_compiler.compile_single(name, gen_src_fname) + builds.insert(1, gen_src_fname) + gen_src = insert_anchor(gen_src, + "namespace jittor {", + f"extern void pyjt_def_{bname}(PyObject* m);") + gen_src = insert_anchor(gen_src, + "init_module(PyModuleDef* mdef, PyObject* m) {", 
f"jittor::pyjt_def_{bname}(m);") + + with open(gen_head_fname, "w") as f: + f.write(gen_src) + LOG.vvv(f"Build custum ops lib:{gen_lib}") LOG.vvvv(f"Build sources:{builds}") - compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib) + compile(cc_path, extra_flags+cc_flags+opt_flags+includes, builds, gen_lib) # add python path and import LOG.vvv(f"Import custum ops lib:{gen_lib}") @@ -613,7 +641,10 @@ def compile_custom_ops(filenames, extra_flags=""): os.sys.path.append(lib_path) with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND): exec(f"import {gen_name}") - return (locals()[gen_name]).ops + mod = locals()[gen_name] + if return_module: + return mod + return mod.ops def get_full_path_of_executable(name): diff --git a/python/jittor/test/test_mpi.py b/python/jittor/test/test_mpi.py new file mode 100644 index 00000000..f82bcc35 --- /dev/null +++ b/python/jittor/test/test_mpi.py @@ -0,0 +1,30 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. 
+# *************************************************************** +import unittest +import os, sys +import jittor as jt +import numpy as np + +def main(): + assert (jt.compile_extern.mpi_ops.mpi_test("").data == 123).all() + +@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found") +class TestMpi(unittest.TestCase): + def test(self): + mpi = jt.compile_extern.mpi + if mpi.world_size() == 1: + mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun') + cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi" + print("run cmd", cmd) + jt.compiler.run_cmd(cmd) + else: + main() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor_utils/__init__.py b/python/jittor_utils/__init__.py index 5c0e015d..ab909880 100644 --- a/python/jittor_utils/__init__.py +++ b/python/jittor_utils/__init__.py @@ -178,7 +178,10 @@ def find_cache_path(): return path def get_version(output): - version = run_cmd(output+" --version") + if output.endswith("mpicc"): + version = run_cmd(output+" --showme:version") + else: + version = run_cmd(output+" --version") v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version) if len(v) == 0: v = re.findall("[0-9]+\\.[0-9]+", version)