fix segfault

This commit is contained in:
Dun Liang 2020-04-08 21:52:43 +08:00
parent 5e5c8de82f
commit 7aa61c8900
5 changed files with 32 additions and 15 deletions

View File

@ -28,9 +28,9 @@ void throw_mpi_error(int result,
namespace jittor {
int mpi_world_size;
int mpi_world_rank;
int mpi_local_rank;
int mpi_world_size = 1;
int mpi_world_rank = 0;
int mpi_local_rank = 0;
int _mpi_world_size() {
return mpi_world_size;
@ -69,6 +69,7 @@ static void getHostName(char* hostname, int maxlen) {
struct mpi_initer {
mpi_initer() {
LOGvv << "MPI init...";
MPI_CHECK(MPI_Init(NULL, NULL));
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
@ -84,6 +85,9 @@ mpi_initer() {
if (p == mpi_world_rank) break;
if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
}
LOGv << "MPI init finished: local" << mpi_local_rank
<< "global" << mpi_world_rank
<< "size" << mpi_world_size;
}
~mpi_initer() {

View File

@ -322,6 +322,8 @@ def manual_link(flags):
ctypes.CDLL(libname, dlopen_flags)
break
def inside_mpi():
# Report whether this process appears to have been started by an MPI
# launcher: returns True exactly when the OMPI_COMM_WORLD_SIZE environment
# variable is present (its value is not inspected), False otherwise.
# NOTE(review): OMPI_COMM_WORLD_SIZE is presumably exported by Open MPI's
# mpirun only; other MPI implementations would not be detected -- confirm.
return "OMPI_COMM_WORLD_SIZE" in os.environ
def setup_mpi():
global mpi_ops, mpi, use_mpi
@ -330,7 +332,6 @@ def setup_mpi():
mpi_ops = None
mpi = None
has_mpi = False
if not use_mpi: return
mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
if mpicc_path == "":
LOG.i("mpicc not found, distribution disabled.")
@ -338,6 +339,8 @@ def setup_mpi():
else:
use_mpi = True
has_mpi = True
if not inside_mpi():
use_mpi = False
if not use_mpi:
return
@ -345,8 +348,8 @@ def setup_mpi():
mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
mpi_flags = mpi_compile_flags + " " + mpi_link_flags
LOG.i("mpi_flags: "+mpi_flags)
manual_link(mpi_flags)
LOG.v("mpi_flags: "+mpi_flags)
# manual_link(mpi_flags)
# find all source files
mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
@ -359,8 +362,11 @@ def setup_mpi():
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
# libmpi cannot be loaded with RTLD_DEEPBIND: it needs to
# share the 'environ' symbol.
mpi = compile_custom_ops(mpi_src_files,
extra_flags=f" {mpi_flags} ", return_module=True)
extra_flags=f" {mpi_flags} ", return_module=True,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
mpi_ops = mpi.ops
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))

View File

@ -554,7 +554,11 @@ def compile_custom_op(header, source, op_name, warp=True):
m = compile_custom_ops([hname, ccname])
return getattr(m, op_name)
def compile_custom_ops(filenames, extra_flags="", return_module=False):
def compile_custom_ops(
filenames,
extra_flags="",
return_module=False,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
"""Compile custom ops
filenames: path of op source files, filenames must be
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
@ -639,7 +643,7 @@ def compile_custom_ops(filenames, extra_flags="", return_module=False):
lib_path = os.path.join(cache_path, "custom_ops")
if lib_path not in os.sys.path:
os.sys.path.append(lib_path)
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
with jit_utils.import_scope(dlopen_flags):
exec(f"import {gen_name}")
mod = locals()[gen_name]
if return_module:

View File

@ -19,15 +19,13 @@ def main():
with jt.flag_scope(use_cuda=1):
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
@unittest.skipIf(jt.compile_extern.has_mpi is None, "no mpi found")
class TestMpi(unittest.TestCase):
def test(self):
mpi = jt.compile_extern.mpi
if mpi.world_size() == 1:
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
if not jt.compile_extern.inside_mpi():
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
print("run cmd", cmd)
jt.compiler.run_cmd(cmd)
assert os.system(cmd)==0, "run cmd failed: "+cmd
else:
main()

View File

@ -54,17 +54,22 @@ struct EventQueue {
static void worker_caller();
void run_sync(Func func) {
// send work to worker and do something by self
std::unique_lock<std::mutex> l(mtx);
this->func = func;
run_sync_done = false;
// send func to worker
worker.run(worker_caller);
while (1) {
// check self work or worker's status
cv.wait(l);
list<Func> ts = move(tasks);
l.unlock();
// do self works
for (auto func : ts)
func();
l.lock();
// worker is finished
if (run_sync_done)
return;
}