mirror of https://github.com/Jittor/Jittor
jt.mpi -> jt.in_mpi
commit a445f9f017
parent 3a6470f4a5
@@ -28,6 +28,7 @@ extern int mpi_world_size;
 extern int mpi_world_rank;
 extern int mpi_local_rank;
 extern bool inside_mpi;
+extern bool mpi_enabled;
 
 /**
 Return number of MPI nodes.
@@ -47,6 +48,19 @@ Return local ID of this MPI node.
 // @pyjt(local_rank)
 int _mpi_local_rank();
 
+/**
+Set MPI state, enable or disable, if disabled, all mpi operators
+have no affect.
+*/
+// @pyjt(set_state)
+inline void _mpi_set_state(bool enable) { mpi_enabled = enable; }
+
+/**
+Get MPI state, enable or disable.
+*/
+// @pyjt(get_state)
+inline int _mpi_get_state() { return mpi_enabled; }
+
 struct ArrayArgs;
 
 /**
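The two new inline functions are exported to Python through the @pyjt annotations, so the mpi module gains set_state/get_state; the single_process_scope changes further down rely on exactly these. A minimal sketch of toggling the state by hand (assuming Jittor was built with MPI support; this snippet is not part of the commit):

    import jittor as jt
    mpi = jt.compile_extern.mpi

    if mpi:
        old = mpi.get_state()
        mpi.set_state(False)   # from here on, MPI collectives become no-ops
        # ... code that must behave as if running single-process ...
        mpi.set_state(old)     # restore the previous state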
@@ -25,6 +25,10 @@ static auto make_mpi_all_reduce = get_op_info("mpi_all_reduce")
     .get_constructor<VarPtr, Var*, NanoString>();
 
 MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
+    if (!mpi_enabled) {
+        forward(x);
+        return;
+    }
     if (op == ns_mean) {
         auto var = make_mpi_all_reduce(x, ns_add);
         var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide);
@@ -17,6 +17,10 @@ namespace jittor {
 
 #ifndef JIT
 MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
+    if (!mpi_enabled) {
+        forward(x);
+        return;
+    }
     #ifdef HAS_CUDA
     if (use_cuda) {
         static auto nccl_broadcast = has_op("nccl_broadcast")
@@ -25,6 +25,10 @@ static auto make_mpi_reduce = get_op_info("mpi_reduce")
     .get_constructor<VarPtr, Var*, NanoString, int>();
 
 MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(root) {
+    if (!mpi_enabled) {
+        forward(x);
+        return;
+    }
     if (op == ns_mean) {
         auto var = make_mpi_reduce(x, ns_add, root);
         var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide);
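All three MPI op constructors gain the same guard: when mpi_enabled is false they simply forward the input and return, so the collectives degrade to identity ops. A hedged Python-level illustration (assumes an MPI-enabled build but a process that was not launched with mpirun):

    import jittor as jt
    import numpy as np

    x = jt.array([1.0, 2.0, 3.0])
    # jt.in_mpi is False here, so mpi_enabled stays false and the
    # all_reduce falls through to forward(x):
    y = x.mpi_all_reduce("mean")
    print(np.allclose(y.data, x.data))   # expected: True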
@@ -31,20 +31,22 @@ int mpi_world_size = 1;
 int mpi_world_rank = 0;
 int mpi_local_rank = 0;
 bool inside_mpi = false;
+bool mpi_enabled = false;
 
 int _mpi_world_size() {
-    return mpi_world_size;
+    return mpi_enabled ? mpi_world_size : 1;
 }
 
 int _mpi_world_rank() {
-    return mpi_world_rank;
+    return mpi_enabled ? mpi_world_rank : 0;
 }
 
 int _mpi_local_rank() {
-    return mpi_local_rank;
+    return mpi_enabled ? mpi_local_rank : 0;
 }
 
 void _mpi_broadcast(ArrayArgs&& args, int root) {
+    if (!mpi_enabled) return;
     int64 size = args.dtype.dsize();
     for (auto j : args.shape)
         size *= j;
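With the flag off, the rank and size queries report a single-process view even if the job was started under mpirun; a quick hedged check of that behavior:

    import jittor as jt
    mpi = jt.compile_extern.mpi

    if mpi:
        saved = mpi.get_state()
        mpi.set_state(False)
        assert mpi.world_size() == 1
        assert mpi.world_rank() == 0 and mpi.local_rank() == 0
        mpi.set_state(saved)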
@@ -52,23 +54,23 @@ void _mpi_broadcast(ArrayArgs&& args, int root) {
 }
 
 static uint64_t getHostHash(const char* string) {
-  // Based on DJB2, result = result * 33 + char
-  uint64_t result = 5381;
-  for (int c = 0; string[c] != '\0'; c++){
-    result = ((result << 5) + result) + string[c];
-  }
-  return result;
+    // Based on DJB2, result = result * 33 + char
+    uint64_t result = 5381;
+    for (int c = 0; string[c] != '\0'; c++){
+        result = ((result << 5) + result) + string[c];
+    }
+    return result;
 }
 
 
 static void getHostName(char* hostname, int maxlen) {
-  gethostname(hostname, maxlen);
-  for (int i=0; i< maxlen; i++) {
-    if (hostname[i] == '.') {
-      hostname[i] = '\0';
-      return;
+    gethostname(hostname, maxlen);
+    for (int i=0; i< maxlen; i++) {
+        if (hostname[i] == '.') {
+            hostname[i] = '\0';
+            return;
+        }
     }
-  }
 }
 
 struct mpi_initer {
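The two blocks above appear to only re-indent the host-hash helpers; the statements themselves are unchanged. For reference, the hash is DJB2 (result = result * 33 + char) over the host name; an equivalent standalone Python version (illustrative only, not part of Jittor):

    def host_hash(name: str) -> int:
        # DJB2: result = result * 33 + char, truncated to 64 bits like uint64_t
        result = 5381
        for ch in name:
            result = ((result << 5) + result + ord(ch)) & 0xFFFFFFFFFFFFFFFF
        return result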
@@ -76,6 +78,7 @@ struct mpi_initer {
 mpi_initer() {
     inside_mpi = !!getenv("OMPI_COMM_WORLD_SIZE");
     if (!inside_mpi) return;
+    mpi_enabled = true;
     LOGvv << "MPI init...";
     MPI_CHECK(MPI_Init(NULL, NULL));
     MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
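mpi_initer runs at static-initialization time: it detects an mpirun launch via the OMPI_COMM_WORLD_SIZE environment variable and only then flips mpi_enabled on and calls MPI_Init. A rough Python analogue of that detection (the variable name is OpenMPI-specific, as in the source):

    import os

    inside_mpi = os.environ.get("OMPI_COMM_WORLD_SIZE") is not None
    mpi_enabled = inside_mpi   # enabled by default only when launched via mpirun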
@@ -16,7 +16,7 @@ with lock.lock_scope():
     from jittor_core import *
     from jittor_core.ops import *
     from . import compile_extern
-    from .compile_extern import mkl_ops, mpi, mpi_ops
+    from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi
     if core.get_device_count() == 0:
         has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
     if has_cuda:
@@ -125,15 +125,19 @@ class profile_scope(_call_no_record_scope):
         profiler.stop()
         self.report.extend(profiler.report())
 
-class single_process_scope(_call_no_record_scope):
+class single_process_scope:
     """ single_process_scope
 
     Code in this scope will only be executed by single process.
 
     All the mpi code inside this scope will have not affect.
+    mpi.world_rank() and mpi.local_rank() will return 0, world_size() will return 1,
 
     example::
 
-        with jt.single_process_scope(rank=0):
-            ......
+        with jt.single_process_scope(rank=0) as flag:
+            if flag:
+                ......
 
+        @jt.single_process_scope(rank=0)
+        def xxx():
@@ -143,23 +147,31 @@ class single_process_scope(_call_no_record_scope):
         self.rank = rank
 
     def __enter__(self):
-        global mpi
-        from jittor.dataset import dataset
-        self.mpi_backup = mpi
-        mpi = dataset.mpi = None
+        global in_mpi
+        self.bk_in_mpi = in_mpi
+        if mpi:
+            self.bk_mpi_state = mpi.get_state()
+        if not in_mpi:
+            return True
+
+        ret = self.rank == mpi.world_rank()
+        in_mpi = compile_extern.in_mpi = False
+        mpi.set_state(False)
+        return ret
 
     def __exit__(self, *exc):
-        global mpi
-        from jittor.dataset import dataset
-        mpi = dataset.mpi = self.mpi_backup
+        global in_mpi
+        in_mpi = compile_extern.in_mpi = self.bk_in_mpi
+        if mpi:
+            mpi.set_state(self.bk_mpi_state)
 
     def __call__(self, func):
         global mpi
         def inner(*args, **kw):
-            if mpi and mpi.world_rank() != self.rank:
-                return
-            with self:
-                ret = func(*args, **kw)
+            ret = None
+            with self as flag:
+                if flag:
+                    ret = func(*args, **kw)
             return ret
         return inner
 
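The reworked scope no longer swaps the mpi module out for None; it saves in_mpi and the MPI state, disables both, and tells the caller, via the returned flag or the decorator, whether this process is the selected rank. A hedged usage sketch (the function bodies are placeholders):

    import jittor as jt

    @jt.single_process_scope(rank=0)
    def report(epoch, acc):
        # runs only on rank 0; on other ranks the wrapper returns None
        print(f"epoch {epoch}: acc={acc:.3f}")

    with jt.single_process_scope(rank=0) as flag:
        if flag:
            print("only rank 0 reaches this branch")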
@@ -504,7 +516,7 @@ class Module:
             p.start_grad()
 
     def mpi_param_broadcast(self, root=0):
-        if mpi is None: return
+        if not in_mpi: return
         for p in self.parameters():
             p.assign(p.mpi_broadcast(root).detach())
 
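mpi_param_broadcast now keys off in_mpi instead of checking whether mpi is None. The typical call site is right after model construction, so every worker starts from rank 0's weights; a hedged sketch:

    import jittor as jt
    from jittor import nn

    model = nn.Linear(10, 2)            # any Module works here
    model.mpi_param_broadcast(root=0)   # silently returns when jt.in_mpi is False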
@@ -395,8 +395,7 @@ def setup_mpi():
         setattr(core.Var, k, warper(mpi_ops.__dict__[k]))
 
 setup_mpi()
-if not inside_mpi():
-    mpi = None
+in_mpi = inside_mpi()
 setup_nccl()
 
 setup_cutt()
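This is the core behavioral change: previously mpi was forced to None outside mpirun, so call sites tested "if jt.mpi:"; now mpi presumably stays importable whenever MPI support is built, and the new module-level in_mpi records whether the process was actually launched under mpirun. That is why the call sites below switch from jt.mpi to jt.in_mpi. A small hedged check:

    import jittor as jt

    if jt.in_mpi:
        mpi = jt.compile_extern.mpi
        print("rank", mpi.world_rank(), "of", mpi.world_size())
    else:
        print("single-process run; MPI collectives are no-ops")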
@@ -197,7 +197,7 @@ class Dataset(object):
         # pad to world_size
         # last batch
         # [.] -> [012]
-        if mpi:
+        if jt.in_mpi:
             world_size = mpi.world_size()
             world_rank = mpi.world_rank()
             index_list = np.int32(index_list)
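The surrounding dataset code shards a padded index list across workers so each rank loads a disjoint slice. A standalone toy version of that idea (this is not Jittor's exact implementation; the names and padding policy are illustrative):

    import numpy as np

    def pad_and_shard(index_list, world_size, world_rank):
        # Pad so the length divides evenly, then take this rank's slice.
        n = len(index_list)
        padded = -(-n // world_size) * world_size        # ceiling division
        index_list = np.concatenate([index_list, index_list[:padded - n]])
        per_rank = padded // world_size
        return index_list[world_rank * per_rank:(world_rank + 1) * per_rank]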
@@ -185,7 +185,7 @@ class BatchNorm(Module):
         if self.is_train:
             xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
             x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-            if self.sync and jt.mpi:
+            if self.sync and jt.in_mpi:
                 xmean = xmean.mpi_all_reduce("mean")
                 x2mean = x2mean.mpi_all_reduce("mean")
 
@@ -219,7 +219,7 @@ class BatchNorm1d(Module):
             xmean = jt.mean(x, dims=[0], keepdims=1)
             x2mean = jt.mean(x*x, dims=[0], keepdims=1)
 
-            if self.sync and jt.mpi:
+            if self.sync and jt.in_mpi:
                 xmean = xmean.mpi_all_reduce("mean")
                 x2mean = x2mean.mpi_all_reduce("mean")
 
@@ -249,7 +249,7 @@ class InstanceNorm2d(Module):
     def execute(self, x):
         xmean = jt.mean(x, dims=[2,3], keepdims=1)
         x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
-        if self.sync and jt.mpi:
+        if self.sync and jt.in_mpi:
             xmean = xmean.mpi_all_reduce("mean")
             x2mean = x2mean.mpi_all_reduce("mean")
 
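All three normalization layers use the same pattern: compute per-worker first and second moments, then average them across workers so every rank normalizes with global statistics. Condensed into one helper, a sketch mirroring the code above:

    import jittor as jt

    def global_moments(x, dims):
        xmean = jt.mean(x, dims=dims, keepdims=1)
        x2mean = jt.mean(x * x, dims=dims, keepdims=1)
        if jt.in_mpi:                      # only sync when actually under mpirun
            xmean = xmean.mpi_all_reduce("mean")
            x2mean = x2mean.mpi_all_reduce("mean")
        return xmean, x2mean - xmean * xmean   # mean and biased variance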
@@ -61,7 +61,7 @@ class Optimizer(object):
         grads = jt.grad(loss, params_has_grad)
 
         # sync grads and model if in mpi
-        if jt.mpi:
+        if jt.in_mpi:
             for g in grads:
                 g.assign(g.mpi_all_reduce("mean"))
             if self.n_step % self.param_sync_iter == 0:
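The guarded block implements the data-parallel update: each worker computes its own gradients, the gradients are averaged with an all-reduce, and parameters are periodically re-broadcast. A hedged sketch of the same step as a standalone helper (the launch command is an assumption about an OpenMPI-style launcher):

    # Launch example (assumption):  mpirun -np 4 python3 train.py
    import jittor as jt

    def sync_grads(grads):
        # Average gradients across workers; a no-op when not under mpirun,
        # since jt.in_mpi is False there.
        if jt.in_mpi:
            for g in grads:
                g.assign(g.mpi_all_reduce("mean"))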
@@ -12,7 +12,7 @@ import jittor as jt
 import numpy as np
 mpi = jt.compile_extern.mpi
 
-@unittest.skipIf(mpi is None, "no inside mpirun")
+@unittest.skipIf(not jt.in_mpi, "no inside mpirun")
 class TestMpi(unittest.TestCase):
     def test_mpi_test_op(self):
         assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
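The test guards flip for the same reason: "mpi is None" no longer implies "not under mpirun". The new idiom, repeated in the remaining test files below, looks like this (class and method names here are placeholders):

    import unittest
    import jittor as jt

    @unittest.skipIf(not jt.in_mpi, "no inside mpirun")
    class TestMyCollective(unittest.TestCase):
        def test_world_size(self):
            assert jt.compile_extern.mpi.world_size() >= 1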
@@ -48,7 +48,7 @@ class FakeMpiBatchNorm(nn.Module):
         b = self.bias.broadcast(x, [0,2,3])
         return norm_x * w + b
 
-@unittest.skipIf(mpi is None, "no inside mpirun")
+@unittest.skipIf(not jt.in_mpi, "no inside mpirun")
 class TestMpiBatchnorm(unittest.TestCase):
     @classmethod
     def setUpClass(self):
@@ -16,7 +16,7 @@ mpi = jt.compile_extern.mpi
 if mpi:
     n = mpi.world_size()
 
-@unittest.skipIf(mpi is None, "no inside mpirun")
+@unittest.skipIf(not jt.in_mpi, "no inside mpirun")
 class TestMpiOps(unittest.TestCase):
     @classmethod
     def setUpClass(self):
@@ -113,7 +113,7 @@ class TestResnet(unittest.TestCase):
         # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
         # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
 
-        if jt.mpi:
+        if jt.in_mpi:
             assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
         else:
             assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
@@ -30,7 +30,7 @@ def val2():
         if i == 5:
             break
 
-@unittest.skipIf(mpi is None, "no inside mpirun")
+@unittest.skipIf(not jt.in_mpi, "no inside mpirun")
 class TestSingleProcessScope(unittest.TestCase):
     def test_single_process_scope(self):
         val1()