jt.mpi -> jt.in_mpi

Dun Liang 2020-06-11 17:06:52 +08:00
parent 3a6470f4a5
commit a445f9f017
15 changed files with 83 additions and 43 deletions

View File

@@ -28,6 +28,7 @@ extern int mpi_world_size;
extern int mpi_world_rank;
extern int mpi_local_rank;
extern bool inside_mpi;
extern bool mpi_enabled;
/**
Return number of MPI nodes.
@@ -47,6 +48,19 @@ Return local ID of this MPI node.
// @pyjt(local_rank)
int _mpi_local_rank();
/**
Set MPI state: enabled or disabled. If disabled, all MPI operators
have no effect.
*/
// @pyjt(set_state)
inline void _mpi_set_state(bool enable) { mpi_enabled = enable; }
/**
Get MPI state: enabled or disabled.
*/
// @pyjt(get_state)
inline int _mpi_get_state() { return mpi_enabled; }
struct ArrayArgs;
/**

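The `// @pyjt(...)` annotations expose these two functions as `set_state` and `get_state` on the Python-side `mpi` module; `single_process_scope` further down uses exactly this pair to suspend and restore MPI. A minimal usage sketch, assuming the process was launched under mpirun so `jt.mpi` is populated:

import jittor as jt

if jt.in_mpi:
    saved = jt.mpi.get_state()     # 1 while MPI ops are active
    jt.mpi.set_state(False)        # all MPI ops become no-ops below this point
    assert jt.mpi.world_size() == 1
    jt.mpi.set_state(saved)        # restore the previous state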
View File

@@ -25,6 +25,10 @@ static auto make_mpi_all_reduce = get_op_info("mpi_all_reduce")
.get_constructor<VarPtr, Var*, NanoString>();
MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
if (!mpi_enabled) {
forward(x);
return;
}
if (op == ns_mean) {
auto var = make_mpi_all_reduce(x, ns_add);
var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide);

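When `op == ns_mean` the constructor decomposes the op: an `add` all-reduce followed by an elementwise divide by `mpi_world_size`, so only the `add` path needs a real MPI kernel. The identity being relied on, simulated in plain numpy:

import numpy as np

# pretend world_size == 2 and these are the per-rank tensors
xs = [np.array([1., 2.]), np.array([3., 4.])]
summed = xs[0] + xs[1]   # what all_reduce("add") leaves on every rank
mean = summed / len(xs)  # the appended divide recovers the mean
assert np.allclose(mean, np.mean(xs, axis=0))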
View File

@@ -17,6 +17,10 @@ namespace jittor {
#ifndef JIT
MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
if (!mpi_enabled) {
forward(x);
return;
}
#ifdef HAS_CUDA
if (use_cuda) {
static auto nccl_broadcast = has_op("nccl_broadcast")

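The same guard appears here: with MPI disabled the op degenerates to `forward(x)`, an identity. When enabled, broadcast leaves every rank holding the root rank's tensor (dispatching to `nccl_broadcast` on CUDA when that op exists); `Module.mpi_param_broadcast` below relies on exactly this. A usage sketch via the `Var` method:

import jittor as jt

x = jt.random((4,))
y = x.mpi_broadcast(0)  # MPI enabled: every rank now holds rank 0's x
                        # MPI disabled: y is simply x, forwarded unchanged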
View File

@@ -25,6 +25,10 @@ static auto make_mpi_reduce = get_op_info("mpi_reduce")
.get_constructor<VarPtr, Var*, NanoString, int>();
MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(root) {
if (!mpi_enabled) {
forward(x);
return;
}
if (op == ns_mean) {
auto var = make_mpi_reduce(x, ns_add, root);
var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide);

View File

@@ -31,20 +31,22 @@ int mpi_world_size = 1;
int mpi_world_rank = 0;
int mpi_local_rank = 0;
bool inside_mpi = false;
bool mpi_enabled = false;
int _mpi_world_size() {
return mpi_world_size;
return mpi_enabled ? mpi_world_size : 1;
}
int _mpi_world_rank() {
return mpi_world_rank;
return mpi_enabled ? mpi_world_rank : 0;
}
int _mpi_local_rank() {
return mpi_local_rank;
return mpi_enabled ? mpi_local_rank : 0;
}
void _mpi_broadcast(ArrayArgs&& args, int root) {
if (!mpi_enabled) return;
int64 size = args.dtype.dsize();
for (auto j : args.shape)
size *= j;
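`_mpi_broadcast` ships a raw buffer, so it first needs the byte count: the dtype's size times the product of the shape dimensions, which is what the `size *= j` loop computes. The same arithmetic in Python, for reference:

import numpy as np

def nbytes(dtype, shape):
    size = np.dtype(dtype).itemsize  # plays the role of args.dtype.dsize()
    for j in shape:
        size *= j
    return size

assert nbytes(np.float32, (2, 3)) == 24  # 4 bytes * 6 elements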
@@ -52,23 +54,23 @@ void _mpi_broadcast(ArrayArgs&& args, int root) {
}
static uint64_t getHostHash(const char* string) {
    // Based on DJB2, result = result * 33 + char
    uint64_t result = 5381;
    for (int c = 0; string[c] != '\0'; c++){
        result = ((result << 5) + result) + string[c];
    }
    return result;
}
static void getHostName(char* hostname, int maxlen) {
    gethostname(hostname, maxlen);
    for (int i=0; i< maxlen; i++) {
        if (hostname[i] == '.') {
            hostname[i] = '\0';
            return;
        }
    }
}
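`getHostHash` is the classic DJB2 string hash (`(result << 5) + result` is `result * 33`), applied to the hostname after `getHostName` trims everything from the first dot. Ranks on the same machine therefore hash identically, which is what local-rank and device assignment need. A Python sketch that emulates the `uint64_t` wraparound:

def host_hash(name: str) -> int:
    h = 5381
    for c in name.encode():
        h = (h * 33 + c) & 0xFFFFFFFFFFFFFFFF  # keep 64-bit overflow semantics
    return h

host_hash("node1")  # getHostName would already have trimmed "node1.cluster" to "node1"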
struct mpi_initer {
@@ -76,6 +78,7 @@ struct mpi_initer {
mpi_initer() {
inside_mpi = !!getenv("OMPI_COMM_WORLD_SIZE");
if (!inside_mpi) return;
mpi_enabled = true;
LOGvv << "MPI init...";
MPI_CHECK(MPI_Init(NULL, NULL));
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));

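Detection is purely environmental: OpenMPI's mpirun exports `OMPI_COMM_WORLD_SIZE` into every process it spawns, so its presence is what sets `inside_mpi`, and only then does this commit flip `mpi_enabled` on. The equivalent check in Python:

import os

inside_mpi = os.environ.get("OMPI_COMM_WORLD_SIZE") is not None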
View File

@@ -16,7 +16,7 @@ with lock.lock_scope():
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops, mpi, mpi_ops
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi
if core.get_device_count() == 0:
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
if has_cuda:
@@ -125,15 +125,19 @@ class profile_scope(_call_no_record_scope):
profiler.stop()
self.report.extend(profiler.report())
class single_process_scope(_call_no_record_scope):
class single_process_scope:
""" single_process_scope
Code in this scope will only be executed by a single process.
All the MPI code inside this scope has no effect:
mpi.world_rank() and mpi.local_rank() will return 0, and mpi.world_size() will return 1.
example::
with jt.single_process_scope(rank=0):
......
with jt.single_process_scope(rank=0) as flag:
if flag:
......
@jt.single_process_scope(rank=0)
def xxx():
@@ -143,23 +147,31 @@ class single_process_scope(_call_no_record_scope):
self.rank = rank
def __enter__(self):
global mpi
from jittor.dataset import dataset
self.mpi_backup = mpi
mpi = dataset.mpi = None
global in_mpi
self.bk_in_mpi = in_mpi
if mpi:
self.bk_mpi_state = mpi.get_state()
if not in_mpi:
return True
ret = self.rank == mpi.world_rank()
in_mpi = compile_extern.in_mpi = False
mpi.set_state(False)
return ret
def __exit__(self, *exc):
global mpi
from jittor.dataset import dataset
mpi = dataset.mpi = self.mpi_backup
global in_mpi
in_mpi = compile_extern.in_mpi = self.bk_in_mpi
if mpi:
mpi.set_state(self.bk_mpi_state)
def __call__(self, func):
global mpi
def inner(*args, **kw):
if mpi and mpi.world_rank() != self.rank:
return
with self:
ret = func(*args, **kw)
ret = None
with self as flag:
if flag:
ret = func(*args, **kw)
return ret
return inner
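With this change the decorator path also goes through the context manager, so the MPI state is saved, disabled, and restored around `func` instead of the call merely being skipped on other ranks. A typical use, with `evaluate` as a hypothetical example:

import jittor as jt

@jt.single_process_scope(rank=0)
def evaluate(model, val_loader):
    # body runs on rank 0 only; inside it world_size() == 1,
    # so datasets are not sharded and MPI ops are no-ops
    for images, labels in val_loader:
        ...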
@@ -504,7 +516,7 @@ class Module:
p.start_grad()
def mpi_param_broadcast(self, root=0):
if mpi is None: return
if not in_mpi: return
for p in self.parameters():
p.assign(p.mpi_broadcast(root).detach())

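The usual pattern is to broadcast once right after constructing the model so every rank starts from rank 0's initialization; with this commit the call is a clean no-op outside mpirun. A sketch with a toy module:

import jittor as jt
from jittor import nn

class Net(jt.Module):
    def __init__(self):
        self.fc = nn.Linear(10, 2)
    def execute(self, x):
        return self.fc(x)

model = Net()
model.mpi_param_broadcast(root=0)  # silently does nothing when not jt.in_mpi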
View File

@@ -395,8 +395,7 @@ def setup_mpi():
setattr(core.Var, k, warper(mpi_ops.__dict__[k]))
setup_mpi()
if not inside_mpi():
mpi = None
in_mpi = inside_mpi()
setup_nccl()
setup_cutt()

View File

@@ -197,7 +197,7 @@ class Dataset(object):
# pad to world_size
# last batch
# [.] -> [012]
if mpi:
if jt.in_mpi:
world_size = mpi.world_size()
world_rank = mpi.world_rank()
index_list = np.int32(index_list)

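The surrounding code pads the epoch's index list up to a multiple of `world_size` (reusing trailing indices, per the `[.] -> [012]` comment) and then hands each rank its own slice, so every rank sees the same number of batches. A simplified stand-alone sketch, assuming a contiguous per-rank slice:

import numpy as np

def shard(index_list, world_size, world_rank):
    per_rank = (len(index_list) + world_size - 1) // world_size
    pad = per_rank * world_size - len(index_list)
    padded = np.concatenate([index_list, index_list[:pad]])  # pad to world_size
    return padded[world_rank * per_rank : (world_rank + 1) * per_rank]

print(shard(np.arange(10), 4, 3))  # rank 3's shard: [9 0 1]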
View File

@@ -185,7 +185,7 @@ class BatchNorm(Module):
if self.is_train:
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
if self.sync and jt.mpi:
if self.sync and jt.in_mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
@@ -219,7 +219,7 @@ class BatchNorm1d(Module):
xmean = jt.mean(x, dims=[0], keepdims=1)
x2mean = jt.mean(x*x, dims=[0], keepdims=1)
if self.sync and jt.mpi:
if self.sync and jt.in_mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
@@ -249,7 +249,7 @@ class InstanceNorm2d(Module):
def execute(self, x):
xmean = jt.mean(x, dims=[2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
if self.sync and jt.mpi:
if self.sync and jt.in_mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")

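The synchronization works on sufficient statistics rather than activations: each device contributes E[x] and E[x*x], both averaged across ranks via `mpi_all_reduce("mean")`, and the variance is then recovered as E[x*x] - E[x]^2. The identity in numpy:

import numpy as np

x = np.random.randn(8, 3)      # pretend this is one rank's batch
xmean, x2mean = x.mean(0), (x * x).mean(0)
var = x2mean - xmean * xmean   # the (biased) variance batch norm uses
assert np.allclose(var, x.var(0))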
View File

@@ -61,7 +61,7 @@ class Optimizer(object):
grads = jt.grad(loss, params_has_grad)
# sync grads and model if in mpi
if jt.mpi:
if jt.in_mpi:
for g in grads:
g.assign(g.mpi_all_reduce("mean"))
if self.n_step % self.param_sync_iter == 0:

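Under MPI every optimizer step therefore averages gradients across ranks, the standard data-parallel reduction, and every `param_sync_iter` steps the parameters themselves are re-synchronized to bound numerical drift. What the per-gradient call computes, in miniature:

import numpy as np

grads = [np.array([0.2, -0.4]), np.array([0.6, 0.0])]  # one gradient per rank
g = sum(grads) / len(grads)  # result of g.mpi_all_reduce("mean") on every rank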
View File

@@ -12,7 +12,7 @@ import jittor as jt
import numpy as np
mpi = jt.compile_extern.mpi
@unittest.skipIf(mpi is None, "no inside mpirun")
@unittest.skipIf(not jt.in_mpi, "no inside mpirun")
class TestMpi(unittest.TestCase):
def test_mpi_test_op(self):
assert jt.compile_extern.mpi_ops.mpi_test("").data == 123

View File

@@ -48,7 +48,7 @@ class FakeMpiBatchNorm(nn.Module):
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
@unittest.skipIf(mpi is None, "no inside mpirun")
@unittest.skipIf(not jt.in_mpi, "no inside mpirun")
class TestMpiBatchnorm(unittest.TestCase):
@classmethod
def setUpClass(self):

View File

@@ -16,7 +16,7 @@ mpi = jt.compile_extern.mpi
if mpi:
n = mpi.world_size()
@unittest.skipIf(mpi is None, "no inside mpirun")
@unittest.skipIf(not jt.in_mpi, "no inside mpirun")
class TestMpiOps(unittest.TestCase):
@classmethod
def setUpClass(self):

View File

@@ -113,7 +113,7 @@ class TestResnet(unittest.TestCase):
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
if jt.mpi:
if jt.in_mpi:
assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
else:
assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()

View File

@@ -30,7 +30,7 @@ def val2():
if i == 5:
break
@unittest.skipIf(mpi is None, "no inside mpirun")
@unittest.skipIf(not jt.in_mpi, "no inside mpirun")
class TestSingleProcessScope(unittest.TestCase):
def test_single_process_scope(self):
val1()