fix fetcher segfault

Author: Dun Liang
Date: 2020-04-22 19:38:01 +08:00
parent 3ecce8eb9e
commit c9052e090d
12 changed files with 46 additions and 80 deletions

View File

@@ -15,7 +15,7 @@ struct NcclBroadcastOp : Op {
     Var* x, * y;
     int root;
-    NcclBroadcastOp(Var* x, int root);
+    NcclBroadcastOp(Var* x, int root=0);
     void infer_shape() override;
     const char* name() const override { return "nccl_broadcast"; }

View File

@@ -15,7 +15,7 @@ struct NcclReduceOp : Op {
     Var* x, * y;
     int root;
-    NcclReduceOp(Var* x, int root);
+    NcclReduceOp(Var* x, int root=0);
     void infer_shape() override;
     const char* name() const override { return "nccl_reduce"; }

View File

@@ -45,7 +45,6 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
     }
     #endif
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 void MpiAllReduceOp::infer_shape() {
@@ -80,10 +79,6 @@ void MpiAllReduceOp::jit_run() {
     index_t num = y->num;
     MPI_Allreduce(xp, yp, num, T_MPI, OP_MPI, MPI_COMM_WORLD);
 }
-#else
-void MpiAllReduceOp::jit_run() {
-    // cuda device code
-}
 #endif // JIT_cpu
 #endif // JIT

View File

@@ -30,7 +30,6 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
     }
     #endif
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 void MpiBroadcastOp::infer_shape() {
@@ -61,10 +60,6 @@ void MpiBroadcastOp::jit_run() {
     auto* __restrict__ yp = y->ptr<Tx>();
     MPI_Bcast(yp, y->num, T_MPI, root, MPI_COMM_WORLD);
 }
-#else
-void MpiBroadcastOp::jit_run() {
-    // cuda device code
-}
 #endif // JIT_cpu
 #endif // JIT

View File

@@ -15,7 +15,7 @@ struct MpiBroadcastOp : Op {
     Var* x, * y;
     int root;
-    MpiBroadcastOp(Var* x, int root);
+    MpiBroadcastOp(Var* x, int root=0);
     void infer_shape() override;
     const char* name() const override { return "mpi_broadcast"; }

View File

@@ -45,7 +45,6 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
     }
     #endif
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 void MpiReduceOp::infer_shape() {
@@ -82,10 +81,6 @@ void MpiReduceOp::jit_run() {
     if (root != mpi_world_rank)
         for (index_t i=0; i<num; i++) yp[i] = 0;
 }
-#else
-void MpiReduceOp::jit_run() {
-    // cuda device code
-}
 #endif // JIT_cpu
 #endif // JIT
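Dropping `ASSERT(x->dtype().is_float())` from the three MPI op constructors above removes the float-only restriction at op-construction time. A hedged sketch of what this appears to permit, assuming the ops' dtype-to-`T_MPI` mapping also covers integer types (that mapping is not shown in this diff):

```python
import numpy as np
import jittor as jt

# Only meaningful when run under mpirun with Jittor's MPI extension compiled in.
if jt.compile_extern.mpi:
    # An int32 Var: before this commit the op constructor asserted a float dtype.
    counts = jt.array(np.array([1, 2, 3], dtype="int32"))
    total = jt.compile_extern.mpi_ops.mpi_all_reduce(counts)  # element-wise sum across ranks
    print(total.data)
```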

View File

@@ -16,7 +16,7 @@ with lock.lock_scope():
     from jittor_core import *
     from jittor_core.ops import *
     from . import compile_extern
-    from .compile_extern import mkl_ops, inside_mpi, mpi_ops
+    from .compile_extern import mkl_ops, mpi, mpi_ops
 import contextlib
 import numpy as np
@@ -614,14 +614,10 @@ class Module:
             if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
                 p.start_grad()
-    def mpi_sync(self):
-        if not inside_mpi():
-            return
-        ps = self.parameters()
-        for p in ps:
-            temp = mpi_ops.mpi_broadcast(p, 0)
-            p.assign(temp.detach())
-            p.detach_inplace()
+    def mpi_param_broadcast(self, root=0):
+        if mpi is None: return
+        for p in self.parameters():
+            p.assign(p.mpi_broadcast(root).detach())
 def make_module(func, exec_n_args=1):
     class MakeModule(Module):
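`Module.mpi_sync` is thus renamed and simplified to `Module.mpi_param_broadcast(root=0)`: it is a no-op when `mpi` is unavailable, and otherwise each parameter is overwritten by the detached copy broadcast from `root`. A short usage sketch (the `Net` class is illustrative, not part of this commit):

```python
import jittor as jt
from jittor import nn, Module

class Net(Module):
    def __init__(self):
        self.linear1 = nn.Linear(10, 10)
    def execute(self, x):
        return self.linear1(x)

net = Net()
# Copy rank 0's parameters to every other rank; does nothing when jt.mpi is None.
net.mpi_param_broadcast()
```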

View File

@@ -373,6 +373,15 @@ def setup_mpi():
     mpi_ops = mpi.ops
     LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
     LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
+    def warper(func):
+        def inner(self, *args, **kw):
+            return func(self, *args, **kw)
+        inner.__doc__ = func.__doc__
+        return inner
+    for k in mpi_ops.__dict__:
+        if not k.startswith("mpi_"): continue
+        if k == "mpi_test": continue
+        setattr(core.Var, k, warper(mpi_ops.__dict__[k]))
 setup_mpi()
 setup_nccl()
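The nine lines added to `setup_mpi` attach every `mpi_*` op except `mpi_test` to `core.Var` (the thin `warper` closure just preserves each op's docstring), so the MPI collectives become Var methods instead of `jt.compile_extern.mpi_ops` calls. A minimal sketch mirroring the updated tests:

```python
import jittor as jt

# Runs meaningfully only under mpirun; jt.mpi is None otherwise and the block is skipped.
if jt.mpi:
    x = jt.random([5, 5])
    y1 = x.mpi_all_reduce()        # element-wise sum across all ranks
    y2 = x.mpi_all_reduce("mean")  # element-wise mean across all ranks
    y3 = x.mpi_broadcast(0)        # copy rank 0's x to every rank
    y4 = x.mpi_reduce(root=0)      # sum onto rank 0 (zeros on the other ranks)
```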

View File

@@ -42,37 +42,6 @@ jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
 def get_init_var_rand(shape, dtype):
     return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))
-@jt.var_scope('batch_norm')
-def batch_norm(x, is_train, eps=1e-5, momentum=0.1, sync=True):
-    assert not (jt.compile_extern.mpi_ops is None)
-    w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
-    b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
-    running_mean = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
-    running_var = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
-    w = w.broadcast(x, [0,2,3])
-    b = b.broadcast(x, [0,2,3])
-    if is_train:
-        if self.sync and not (jt.compile_extern.mpi_ops is None):
-            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
-            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
-            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
-        else:
-            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-        xvar = x2mean-xmean*xmean
-        norm_x = (x-xmean)/jt.sqrt(xvar+eps)
-        running_mean += (xmean.sum([0,2,3])-running_mean)*momentum
-        running_var += (xvar.sum([0,2,3])-running_var)*momentum
-    else:
-        running_mean = running_mean.broadcast(x, [0,2,3])
-        running_var = running_var.broadcast(x, [0,2,3])
-        norm_x = (x-running_mean)/jt.sqrt(running_var+eps)
-    return norm_x * w + b
 @jt.var_scope('conv')
 def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1, init_method=None):
     Kw = kernel_size
@@ -147,6 +116,9 @@ class SGD(object):
         self.parameters = []
         self.values = []
         for p in parameters:
+            # broadcast parameter from 0 node when init
+            if jt.mpi:
+                p.assign(p.mpi_broadcast().detach())
             if p.is_stop_grad():
                 self.no_grad_parameters.append(p)
                 continue
@@ -156,9 +128,9 @@ class SGD(object):
     def step(self, loss):
         ps = self.parameters
         gs = jt.grad(loss, ps)
-        if jt.compile_extern.inside_mpi():
+        if jt.mpi:
             for g in gs:
-                g.assign(jt.compile_extern.mpi_ops.mpi_all_reduce(g))
+                g.assign(g.mpi_all_reduce("mean"))
         for p, g, v in zip(ps, gs, self.values):
             dp = p * self.weight_decay + g
             v.assign(self.momentum * v + dp * (1 - self.dampening))
@@ -191,6 +163,8 @@ class Adam(object):
         self.values = []
         self.m = []
         for p in parameters:
+            if jt.mpi:
+                p.assign(p.mpi_broadcast().detach())
             if p.is_stop_grad():
                 self.no_grad_parameters.append(p)
                 continue
@@ -201,9 +175,9 @@ class Adam(object):
     def step(self, loss):
         ps = self.parameters
         gs = jt.grad(loss, ps)
-        if jt.compile_extern.inside_mpi():
+        if jt.mpi:
             for g in gs:
-                g.assign(jt.compile_extern.mpi_ops.mpi_all_reduce(g))
+                g.assign(g.mpi_all_reduce("mean"))
         self.adam_step += 1
         n, (b0, b1) = float(self.adam_step), self.betas
         for p, g, v, m in zip(ps, gs, self.values, self.m):
@@ -270,13 +244,12 @@ class BatchNorm(Module):
         self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
     def execute(self, x):
-        mpi = jt.compile_extern.mpi
         if self.is_train:
             xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
             x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-            if self.sync and jt.compile_extern.mpi_ops is not None:
-                xmean = jt.compile_extern.mpi_ops.mpi_all_reduce(xmean)/jt.compile_extern.mpi.world_size()
-                x2mean = jt.compile_extern.mpi_ops.mpi_all_reduce(x2mean)/jt.compile_extern.mpi.world_size()
+            if self.sync and jt.mpi:
+                xmean = xmean.mpi_all_reduce("mean")
+                x2mean = x2mean.mpi_all_reduce("mean")
             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
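Net effect on the Python API: when `jt.mpi` is set, `SGD` and `Adam` broadcast parameters from rank 0 at construction and average gradients via `mpi_all_reduce("mean")` inside `step`, while the sync path of `BatchNorm` shrinks to two Var-method calls, so an ordinary training script becomes data-parallel under `mpirun`. A condensed sketch of the resulting loop (the model and the per-rank data iterable are placeholders):

```python
import jittor as jt
from jittor import nn

model = SomeModel()                          # placeholder Module
optimizer = nn.SGD(model.parameters(), 0.1)  # broadcasts params from rank 0 when jt.mpi is set

for x, y in data_loader:                     # placeholder iterable; each rank sees its own shard
    loss = ((model(x) - y) ** 2).mean()
    optimizer.step(loss)                     # gradients all-reduced with "mean" under MPI
```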

View File

@@ -23,14 +23,14 @@ class TestMpiOps(unittest.TestCase):
     def test_all_reduce(self):
         x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
+        y = x.mpi_all_reduce()
         assert np.allclose(y.data, (x*3).data)
         g = jt.grad(y,x)
         assert np.allclose(g.data, np.ones([5,5])*3)
     def test_all_reduce_mean(self):
         x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_all_reduce(x, "mean")
+        y = x.mpi_all_reduce("mean")
         assert np.allclose(y.data, x.data)
         g = jt.grad(y,x)
         assert np.allclose(g.data, np.ones([5,5]))
@@ -41,7 +41,7 @@ class TestMpiOps(unittest.TestCase):
             x = data
         else:
             x = jt.zeros([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
+        y = x.mpi_broadcast(0)
         assert np.allclose(y.data, data.data)
         g = jt.grad(y,x)
         if mpi.world_rank() == 0:
@@ -51,7 +51,7 @@ class TestMpiOps(unittest.TestCase):
     def test_reduce(self):
         x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_reduce(x, root=0)
+        y = x.mpi_reduce(root=0)
         y.sync()
         if mpi.world_rank() == 0:
             assert np.allclose(y.data, (x*3).data)

View File

@@ -16,7 +16,7 @@ from jittor import nn, Module
 import copy
 from jittor.test.test_log import find_log_with_re
 from jittor.test.test_mpi import run_mpi_test
-from jittor.compile_extern import mpi, mpi_ops, nccl_ops
+from jittor.compile_extern import mpi, nccl_ops
 n = 2
 @unittest.skipIf(nccl_ops is None, "nccl not found")
@@ -32,7 +32,7 @@ class TestNcclOps(unittest.TestCase):
             log_v=1, log_vprefix="op.cc=100,exe=1000"
         ) as raw_log:
             x = jt.random([5, 5])
-            y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
+            y = x.mpi_all_reduce()
             assert np.allclose(y.data, (x*n).data)
             g = jt.grad(y,x)
             assert np.allclose(g.data, np.ones([5,5])*n)
@@ -50,7 +50,7 @@ class TestNcclOps(unittest.TestCase):
             x = data
         else:
             x = jt.zeros([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
+        y = x.mpi_broadcast(0)
         assert np.allclose(y.data, data.data)
         g = jt.grad(y.sum(),x)
         g_ = g.data
@@ -65,7 +65,7 @@ class TestNcclOps(unittest.TestCase):
             log_v=1, log_vprefix="op.cc=100,exe=1000"
         ) as raw_log:
             x = jt.random([5, 5])
-            y = jt.compile_extern.mpi_ops.mpi_reduce(x, root=0)
+            y = x.mpi_reduce(root=0)
             y_ = y.data
             x_ = (x*n).data
             if mpi.world_rank() == 0:
@@ -96,7 +96,7 @@ class TestNcclOps(unittest.TestCase):
         net.linear1.weight += 1
         net.linear2.weight += 1
         net.linear1.bias += 1
-        net.mpi_sync()
+        net.mpi_param_broadcast()
         assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data)
         assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
         assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)
@@ -122,16 +122,16 @@ class TestNcclOps(unittest.TestCase):
         num = 2000
         model = Model2(1)
-        model.mpi_sync()
-        optimizer = nn.SGD(model.parameters(), 0.05)
+        model.mpi_param_broadcast()
+        optimizer = nn.SGD(model.parameters(), 0.1)
         dataset = list(enumerate(get_data(num)))
         for i in range(mpi.world_rank(), num, n):
             id, (x, y) = dataset[i]
             pred_y = model(x)
-            loss = (pred_y - y)*(pred_y - y)
+            loss = (pred_y - y)**2
            loss_mean = loss.mean()
            optimizer.step(loss_mean)
-        assert loss_mean.data < 0.0025
+        assert loss_mean.data < 0.0025, loss_mean.data
         jt.clean()
 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")

View File

@@ -51,7 +51,7 @@ Init() {
         checkCudaErrors(cudaStreamDestroy(stream));
         checkCudaErrors(cudaEventDestroy(event));
     }
-} init;
+};
 }
 using namespace fetcher_local;
@@ -59,6 +59,9 @@ using namespace fetcher_local;
 #endif
 void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
+    #ifdef HAS_CUDA
+    static Init init;
+    #endif
     sync(vh);
     vector<Allocation> allocations(vh.size());
     vector<ArrayArgs> arrays(vh.size());
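The fetcher change is the fix the commit title refers to: the CUDA stream/event holder `Init` is no longer instantiated as a namespace-level `init` object constructed at library load and destroyed at program exit; instead `fetch()` creates a function-local `static Init init;` (guarded by `HAS_CUDA`) on first use. Presumably this is what removes the segfault: the stream and event are only created once a fetch actually runs, and the construct-on-first-use ordering of local statics keeps their lifetime from straddling CUDA runtime setup and teardown the way a global static can.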