mirror of https://github.com/Jittor/Jittor
fix fetcher segfault
This commit is contained in:
parent 3ecce8eb9e
commit c9052e090d
@@ -15,7 +15,7 @@ struct NcclBroadcastOp : Op {
     Var* x, * y;
     int root;
 
-    NcclBroadcastOp(Var* x, int root);
+    NcclBroadcastOp(Var* x, int root=0);
     void infer_shape() override;
 
     const char* name() const override { return "nccl_broadcast"; }

@@ -15,7 +15,7 @@ struct NcclReduceOp : Op {
     Var* x, * y;
     int root;
 
-    NcclReduceOp(Var* x, int root);
+    NcclReduceOp(Var* x, int root=0);
    void infer_shape() override;
 
     const char* name() const override { return "nccl_reduce"; }

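Note: both NCCL op headers now give root a default of 0, mirroring the MPI broadcast header further down, so callers broadcasting from rank 0 can omit the argument. A minimal usage sketch at the Python level (assuming an MPI-enabled build launched under an MPI runner, and the Var-method binding added later in this diff):

    import jittor as jt
    x = jt.random([5, 5])
    y1 = x.mpi_broadcast()     # root defaults to rank 0
    y2 = x.mpi_broadcast(0)    # equivalent, explicit root
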
@@ -45,7 +45,6 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
     }
     #endif
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 
 void MpiAllReduceOp::infer_shape() {

@@ -80,10 +79,6 @@ void MpiAllReduceOp::jit_run() {
     index_t num = y->num;
     MPI_Allreduce(xp, yp, num, T_MPI, OP_MPI, MPI_COMM_WORLD);
 }
-#else
-void MpiAllReduceOp::jit_run() {
-    // cuda device code
-}
 #endif // JIT_cpu
 #endif // JIT
 

@@ -30,7 +30,6 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
     }
     #endif
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 
 void MpiBroadcastOp::infer_shape() {

@@ -61,10 +60,6 @@ void MpiBroadcastOp::jit_run() {
     auto* __restrict__ yp = y->ptr<Tx>();
     MPI_Bcast(yp, y->num, T_MPI, root, MPI_COMM_WORLD);
 }
-#else
-void MpiBroadcastOp::jit_run() {
-    // cuda device code
-}
 #endif // JIT_cpu
 #endif // JIT
 

@@ -15,7 +15,7 @@ struct MpiBroadcastOp : Op {
     Var* x, * y;
     int root;
 
-    MpiBroadcastOp(Var* x, int root);
+    MpiBroadcastOp(Var* x, int root=0);
     void infer_shape() override;
 
     const char* name() const override { return "mpi_broadcast"; }

@@ -45,7 +45,6 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
     }
     #endif
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 
 void MpiReduceOp::infer_shape() {

@@ -82,10 +81,6 @@ void MpiReduceOp::jit_run() {
     if (root != mpi_world_rank)
         for (index_t i=0; i<num; i++) yp[i] = 0;
 }
-#else
-void MpiReduceOp::jit_run() {
-    // cuda device code
-}
 #endif // JIT_cpu
 #endif // JIT
 

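Note: taken together, the three MPI op changes above do two things. The float-dtype assertion leaves each constructor, so non-float inputs are no longer rejected up front, and the placeholder #else branch of each jit_run (the "// cuda device code" stub) is deleted, leaving the CPU body under JIT_cpu as the only JIT instantiation. Judging from the HAS_CUDA block visible in each constructor, the CUDA path appears to be dispatched to the NCCL ops at the top of this diff rather than JIT-compiled here; that routing is an inference from context, not shown in these hunks.
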
@@ -16,7 +16,7 @@ with lock.lock_scope():
     from jittor_core import *
     from jittor_core.ops import *
     from . import compile_extern
-    from .compile_extern import mkl_ops, inside_mpi, mpi_ops
+    from .compile_extern import mkl_ops, mpi, mpi_ops
 
 import contextlib
 import numpy as np

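Note: with mpi exposed at package level, the availability check becomes a plain truthiness test, which the optimizer hunks below use as "if jt.mpi:". A small sketch (jt.mpi is None in a non-MPI build or run, so it doubles as a flag):

    import jittor as jt
    if jt.mpi:
        print("rank", jt.mpi.world_rank(), "of", jt.mpi.world_size())
    else:
        print("running without MPI")
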
@@ -614,14 +614,10 @@ class Module:
         if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
             p.start_grad()
 
-    def mpi_sync(self):
-        if not inside_mpi():
-            return
-        ps = self.parameters()
-        for p in ps:
-            temp = mpi_ops.mpi_broadcast(p, 0)
-            p.assign(temp.detach())
-            p.detach_inplace()
+    def mpi_param_broadcast(self, root=0):
+        if mpi is None: return
+        for p in self.parameters():
+            p.assign(p.mpi_broadcast(root).detach())
 
 def make_module(func, exec_n_args=1):
     class MakeModule(Module):

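Note: mpi_param_broadcast replaces mpi_sync. It is a no-op outside MPI, takes a configurable root, and detaches the broadcast result so the copy does not enter the autograd graph. A self-contained usage sketch (the toy module is hypothetical):

    import jittor as jt
    from jittor import nn, Module

    class Net(Module):                     # hypothetical toy module
        def __init__(self):
            self.linear1 = nn.Linear(2, 2)
        def execute(self, x):
            return self.linear1(x)

    net = Net()
    net.mpi_param_broadcast()              # rank 0's parameters win on every rank
    net.mpi_param_broadcast(root=1)        # or broadcast from a different root
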
@@ -373,6 +373,15 @@ def setup_mpi():
     mpi_ops = mpi.ops
     LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
     LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
+    def warper(func):
+        def inner(self, *args, **kw):
+            return func(self, *args, **kw)
+        inner.__doc__ = func.__doc__
+        return inner
+    for k in mpi_ops.__dict__:
+        if not k.startswith("mpi_"): continue
+        if k == "mpi_test": continue
+        setattr(core.Var, k, warper(mpi_ops.__dict__[k]))
 
setup_mpi()
setup_nccl()

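Note: this loop turns every mpi_* function exported by the ops module into a method on core.Var, which is what lets later hunks write x.mpi_all_reduce() instead of jt.compile_extern.mpi_ops.mpi_all_reduce(x). The extra warper/inner indirection is likely needed because functions from a compiled extension module are not descriptors, so assigning them to a class directly would not bind self; a plain Python wrapper does, and copying __doc__ keeps the docstring visible. A self-contained sketch of the same binding pattern:

    class Var:
        def __init__(self, v): self.v = v

    def mpi_double(self):
        "toy stand-in for an extension-module mpi_* op"
        return Var(self.v * 2)

    def warper(func):
        def inner(self, *args, **kw):
            return func(self, *args, **kw)
        inner.__doc__ = func.__doc__
        return inner

    setattr(Var, "mpi_double", warper(mpi_double))
    assert Var(3).mpi_double().v == 6
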
@@ -42,37 +42,6 @@ jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
 def get_init_var_rand(shape, dtype):
     return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))
 
-@jt.var_scope('batch_norm')
-def batch_norm(x, is_train, eps=1e-5, momentum=0.1, sync=True):
-    assert not (jt.compile_extern.mpi_ops is None)
-    w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
-    b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
-    running_mean = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
-    running_var = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
-
-    w = w.broadcast(x, [0,2,3])
-    b = b.broadcast(x, [0,2,3])
-    if is_train:
-        if self.sync and not (jt.compile_extern.mpi_ops is None):
-            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
-            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
-            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
-        else:
-            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-        xvar = x2mean-xmean*xmean
-        norm_x = (x-xmean)/jt.sqrt(xvar+eps)
-
-        running_mean += (xmean.sum([0,2,3])-running_mean)*momentum
-        running_var += (xvar.sum([0,2,3])-running_var)*momentum
-    else:
-        running_mean = running_mean.broadcast(x, [0,2,3])
-        running_var = running_var.broadcast(x, [0,2,3])
-        norm_x = (x-running_mean)/jt.sqrt(running_var+eps)
-
-    return norm_x * w + b
-
 @jt.var_scope('conv')
 def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1, init_method=None):
     Kw = kernel_size

@@ -147,6 +116,9 @@ class SGD(object):
         self.parameters = []
         self.values = []
         for p in parameters:
+            # broadcast parameter from 0 node when init
+            if jt.mpi:
+                p.assign(p.mpi_broadcast().detach())
             if p.is_stop_grad():
                 self.no_grad_parameters.append(p)
                 continue

@@ -156,9 +128,9 @@ class SGD(object):
     def step(self, loss):
         ps = self.parameters
         gs = jt.grad(loss, ps)
-        if jt.compile_extern.inside_mpi():
+        if jt.mpi:
             for g in gs:
-                g.assign(jt.compile_extern.mpi_ops.mpi_all_reduce(g))
+                g.assign(g.mpi_all_reduce("mean"))
         for p, g, v in zip(ps, gs, self.values):
             dp = p * self.weight_decay + g
             v.assign(self.momentum * v + dp * (1 - self.dampening))

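Note: averaging gradients with mpi_all_reduce("mean") before the update is the standard data-parallel recipe: every rank computes gradients on its own shard, the all-reduce makes the averaged gradient identical everywhere, and since parameters start identical (the broadcast in __init__ above), all ranks stay in lockstep with no further synchronization. A runnable sketch of the resulting loop (nn.Linear stands in for the Model2 used by the tests below; single-process runs simply skip the all-reduce):

    import jittor as jt
    from jittor import nn

    model = nn.Linear(1, 1)
    opt = nn.SGD(model.parameters(), 0.1)
    for i in range(100):                   # each rank would iterate its own data shard
        x = jt.random([4, 1])
        y = x * 3
        loss = ((model(x) - y)**2).mean()
        opt.step(loss)                     # inside: grads all-reduced with "mean" when jt.mpi
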
@@ -191,6 +163,8 @@ class Adam(object):
         self.values = []
         self.m = []
         for p in parameters:
+            if jt.mpi:
+                p.assign(p.mpi_broadcast().detach())
             if p.is_stop_grad():
                 self.no_grad_parameters.append(p)
                 continue

@@ -201,9 +175,9 @@ class Adam(object):
     def step(self, loss):
         ps = self.parameters
         gs = jt.grad(loss, ps)
-        if jt.compile_extern.inside_mpi():
+        if jt.mpi:
             for g in gs:
-                g.assign(jt.compile_extern.mpi_ops.mpi_all_reduce(g))
+                g.assign(g.mpi_all_reduce("mean"))
         self.adam_step += 1
         n, (b0, b1) = float(self.adam_step), self.betas
         for p, g, v, m in zip(ps, gs, self.values, self.m):

@@ -270,13 +244,12 @@ class BatchNorm(Module):
         self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
 
     def execute(self, x):
-        mpi = jt.compile_extern.mpi
         if self.is_train:
             xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
             x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-            if self.sync and jt.compile_extern.mpi_ops is not None:
-                xmean = jt.compile_extern.mpi_ops.mpi_all_reduce(xmean)/jt.compile_extern.mpi.world_size()
-                x2mean = jt.compile_extern.mpi_ops.mpi_all_reduce(x2mean)/jt.compile_extern.mpi.world_size()
+            if self.sync and jt.mpi:
+                xmean = xmean.mpi_all_reduce("mean")
+                x2mean = x2mean.mpi_all_reduce("mean")
 
             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)

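Note: the "mean" all-reduce of the per-rank statistics gives exact global batch statistics as long as every rank sees the same batch size: the mean of the per-rank means is the global mean, likewise for E[x^2], and xvar = E[x^2] - E[x]^2 then recovers the global (biased) variance. A quick numpy check of the identity:

    import numpy as np
    a, b = np.random.rand(4), np.random.rand(4)      # two equal-size "ranks"
    x = np.concatenate([a, b])
    assert np.allclose((a.mean() + b.mean()) / 2, x.mean())
    assert np.allclose(((a**2).mean() + (b**2).mean()) / 2, (x**2).mean())
    assert np.allclose((x**2).mean() - x.mean()**2, x.var())
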
@@ -23,14 +23,14 @@ class TestMpiOps(unittest.TestCase):
 
     def test_all_reduce(self):
         x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
+        y = x.mpi_all_reduce()
         assert np.allclose(y.data, (x*3).data)
         g = jt.grad(y,x)
         assert np.allclose(g.data, np.ones([5,5])*3)
 
     def test_all_reduce_mean(self):
         x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_all_reduce(x, "mean")
+        y = x.mpi_all_reduce("mean")
         assert np.allclose(y.data, x.data)
         g = jt.grad(y,x)
         assert np.allclose(g.data, np.ones([5,5]))

@@ -41,7 +41,7 @@ class TestMpiOps(unittest.TestCase):
             x = data
         else:
             x = jt.zeros([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
+        y = x.mpi_broadcast(0)
         assert np.allclose(y.data, data.data)
         g = jt.grad(y,x)
         if mpi.world_rank() == 0:

@@ -51,7 +51,7 @@ class TestMpiOps(unittest.TestCase):
 
     def test_reduce(self):
         x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_reduce(x, root=0)
+        y = x.mpi_reduce(root=0)
         y.sync()
         if mpi.world_rank() == 0:
             assert np.allclose(y.data, (x*3).data)

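Note: the expected values in these tests bake in the launch configuration: a sum all-reduce returning x*3 only holds with exactly three processes, so this suite is meant to run under an MPI launcher such as mpirun -np 3 (the test module's file path is not shown in this diff).
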
@@ -16,7 +16,7 @@ from jittor import nn, Module
 import copy
 from jittor.test.test_log import find_log_with_re
 from jittor.test.test_mpi import run_mpi_test
-from jittor.compile_extern import mpi, mpi_ops, nccl_ops
+from jittor.compile_extern import mpi, nccl_ops
 n = 2
 
 @unittest.skipIf(nccl_ops is None, "nccl not found")

@@ -32,7 +32,7 @@ class TestNcclOps(unittest.TestCase):
             log_v=1, log_vprefix="op.cc=100,exe=1000"
         ) as raw_log:
             x = jt.random([5, 5])
-            y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
+            y = x.mpi_all_reduce()
             assert np.allclose(y.data, (x*n).data)
             g = jt.grad(y,x)
             assert np.allclose(g.data, np.ones([5,5])*n)

@@ -50,7 +50,7 @@ class TestNcclOps(unittest.TestCase):
             x = data
         else:
             x = jt.zeros([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
+        y = x.mpi_broadcast(0)
         assert np.allclose(y.data, data.data)
         g = jt.grad(y.sum(),x)
         g_ = g.data

@@ -65,7 +65,7 @@ class TestNcclOps(unittest.TestCase):
             log_v=1, log_vprefix="op.cc=100,exe=1000"
         ) as raw_log:
             x = jt.random([5, 5])
-            y = jt.compile_extern.mpi_ops.mpi_reduce(x, root=0)
+            y = x.mpi_reduce(root=0)
             y_ = y.data
             x_ = (x*n).data
             if mpi.world_rank() == 0:

@@ -96,7 +96,7 @@ class TestNcclOps(unittest.TestCase):
         net.linear1.weight += 1
         net.linear2.weight += 1
         net.linear1.bias += 1
-        net.mpi_sync()
+        net.mpi_param_broadcast()
         assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data)
         assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
         assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)

@@ -122,16 +122,16 @@ class TestNcclOps(unittest.TestCase):
 
         num = 2000
         model = Model2(1)
-        model.mpi_sync()
-        optimizer = nn.SGD(model.parameters(), 0.05)
+        model.mpi_param_broadcast()
+        optimizer = nn.SGD(model.parameters(), 0.1)
         dataset = list(enumerate(get_data(num)))
         for i in range(mpi.world_rank(), num, n):
             id, (x, y) = dataset[i]
             pred_y = model(x)
-            loss = (pred_y - y)*(pred_y - y)
+            loss = (pred_y - y)**2
             loss_mean = loss.mean()
             optimizer.step(loss_mean)
-        assert loss_mean.data < 0.0025
+        assert loss_mean.data < 0.0025, loss_mean.data
         jt.clean()
 
 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")

@@ -51,7 +51,7 @@ Init() {
         checkCudaErrors(cudaStreamDestroy(stream));
         checkCudaErrors(cudaEventDestroy(event));
     }
-} init;
+};
 
 }
 using namespace fetcher_local;

@@ -59,6 +59,9 @@ using namespace fetcher_local;
 #endif
 
 void fetch(const vector<VarHolder*>& vh, FetchFunc&& func) {
+    #ifdef HAS_CUDA
+    static Init init;
+    #endif
     sync(vh);
     vector<Allocation> allocations(vh.size());
     vector<ArrayArgs> arrays(vh.size());

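Note: this pair of hunks is the segfault fix the commit message names. "} init;" previously defined a namespace-scope static whose constructor ran CUDA stream/event creation at library-load time, before the CUDA runtime was necessarily ready; "};" now merely closes the struct, and the "static Init init;" inside fetch() constructs it lazily on the first actual fetch, with C++11 guaranteeing thread-safe one-time initialization of function-local statics (the destructor visible above still releases the stream and event at exit). A rough Python analogue of the same eager-to-lazy move (helper names hypothetical):

    _initialized = False

    def _lazy_init():
        # stand-in for Init(): do one-time setup on first use, not at import time
        global _initialized
        if not _initialized:
            _initialized = True

    def fetch(vars, func):
        _lazy_init()   # was: initialization at module load
        return func(vars)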