mirror of https://github.com/Jittor/Jittor
fix segfault
This commit is contained in:
parent
5e5c8de82f
commit
7aa61c8900
|
@ -28,9 +28,9 @@ void throw_mpi_error(int result,
|
|||
namespace jittor {
|
||||
|
||||
|
||||
int mpi_world_size;
|
||||
int mpi_world_rank;
|
||||
int mpi_local_rank;
|
||||
int mpi_world_size = 1;
|
||||
int mpi_world_rank = 0;
|
||||
int mpi_local_rank = 0;
|
||||
|
||||
int _mpi_world_size() {
|
||||
return mpi_world_size;
|
||||
|
@ -69,6 +69,7 @@ static void getHostName(char* hostname, int maxlen) {
|
|||
struct mpi_initer {
|
||||
|
||||
mpi_initer() {
|
||||
LOGvv << "MPI init...";
|
||||
MPI_CHECK(MPI_Init(NULL, NULL));
|
||||
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
|
||||
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
|
||||
|
@ -84,6 +85,9 @@ mpi_initer() {
|
|||
if (p == mpi_world_rank) break;
|
||||
if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
|
||||
}
|
||||
LOGv << "MPI init finished: local" << mpi_local_rank
|
||||
<< "global" << mpi_world_rank
|
||||
<< "size" << mpi_world_size;
|
||||
}
|
||||
|
||||
~mpi_initer() {
|
||||
|
|
|
@ -322,6 +322,8 @@ def manual_link(flags):
|
|||
ctypes.CDLL(libname, dlopen_flags)
|
||||
break
|
||||
|
||||
def inside_mpi():
|
||||
return "OMPI_COMM_WORLD_SIZE" in os.environ
|
||||
|
||||
def setup_mpi():
|
||||
global mpi_ops, mpi, use_mpi
|
||||
|
@ -330,7 +332,6 @@ def setup_mpi():
|
|||
mpi_ops = None
|
||||
mpi = None
|
||||
has_mpi = False
|
||||
if not use_mpi: return
|
||||
mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
|
||||
if mpicc_path == "":
|
||||
LOG.i("mpicc not found, distribution disabled.")
|
||||
|
@ -338,6 +339,8 @@ def setup_mpi():
|
|||
else:
|
||||
use_mpi = True
|
||||
has_mpi = True
|
||||
if not inside_mpi():
|
||||
use_mpi = False
|
||||
if not use_mpi:
|
||||
return
|
||||
|
||||
|
@ -345,8 +348,8 @@ def setup_mpi():
|
|||
mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
|
||||
mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
|
||||
mpi_flags = mpi_compile_flags + " " + mpi_link_flags
|
||||
LOG.i("mpi_flags: "+mpi_flags)
|
||||
manual_link(mpi_flags)
|
||||
LOG.v("mpi_flags: "+mpi_flags)
|
||||
# manual_link(mpi_flags)
|
||||
|
||||
# find all source files
|
||||
mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
|
||||
|
@ -359,8 +362,11 @@ def setup_mpi():
|
|||
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
|
||||
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
|
||||
|
||||
# libmpi cannot use deepbind, it need to
|
||||
# share the 'environ' symbol.
|
||||
mpi = compile_custom_ops(mpi_src_files,
|
||||
extra_flags=f" {mpi_flags} ", return_module=True)
|
||||
extra_flags=f" {mpi_flags} ", return_module=True,
|
||||
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
|
||||
mpi_ops = mpi.ops
|
||||
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
|
||||
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
|
||||
|
|
|
@ -554,7 +554,11 @@ def compile_custom_op(header, source, op_name, warp=True):
|
|||
m = compile_custom_ops([hname, ccname])
|
||||
return getattr(m, op_name)
|
||||
|
||||
def compile_custom_ops(filenames, extra_flags="", return_module=False):
|
||||
def compile_custom_ops(
|
||||
filenames,
|
||||
extra_flags="",
|
||||
return_module=False,
|
||||
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
|
||||
"""Compile custom ops
|
||||
filenames: path of op source files, filenames must be
|
||||
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
|
||||
|
@ -639,7 +643,7 @@ def compile_custom_ops(filenames, extra_flags="", return_module=False):
|
|||
lib_path = os.path.join(cache_path, "custom_ops")
|
||||
if lib_path not in os.sys.path:
|
||||
os.sys.path.append(lib_path)
|
||||
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
|
||||
with jit_utils.import_scope(dlopen_flags):
|
||||
exec(f"import {gen_name}")
|
||||
mod = locals()[gen_name]
|
||||
if return_module:
|
||||
|
|
|
@ -19,15 +19,13 @@ def main():
|
|||
with jt.flag_scope(use_cuda=1):
|
||||
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
|
||||
@unittest.skipIf(jt.compile_extern.has_mpi is None, "no mpi found")
|
||||
class TestMpi(unittest.TestCase):
|
||||
def test(self):
|
||||
mpi = jt.compile_extern.mpi
|
||||
if mpi.world_size() == 1:
|
||||
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||
if not jt.compile_extern.inside_mpi():
|
||||
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
|
||||
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
|
||||
print("run cmd", cmd)
|
||||
jt.compiler.run_cmd(cmd)
|
||||
assert os.system(cmd)==0, "run cmd failed: "+cmd
|
||||
else:
|
||||
main()
|
||||
|
||||
|
|
|
@ -54,17 +54,22 @@ struct EventQueue {
|
|||
static void worker_caller();
|
||||
|
||||
void run_sync(Func func) {
|
||||
// send work to worker and do something by self
|
||||
std::unique_lock<std::mutex> l(mtx);
|
||||
this->func = func;
|
||||
run_sync_done = false;
|
||||
// send func to worker
|
||||
worker.run(worker_caller);
|
||||
while (1) {
|
||||
// check self work or worker's status
|
||||
cv.wait(l);
|
||||
list<Func> ts = move(tasks);
|
||||
l.unlock();
|
||||
// do self works
|
||||
for (auto func : ts)
|
||||
func();
|
||||
l.lock();
|
||||
// worker is finished
|
||||
if (run_sync_done)
|
||||
return;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue