fix: wrong grad of varslice like [::-1]

This commit is contained in:
lzhengning 2021-11-21 16:50:45 +08:00
parent 4bd0c4a2f6
commit 38694a1b6e
6 changed files with 42 additions and 16 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.3.1.18' __version__ = '1.3.1.19'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -391,9 +391,9 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
// need analysis the overlap attr os var slices // need analysis the overlap attr os var slices
for (int i=0; i<vs.n; i++) for (int i=0; i<vs.n; i++)
if (vs.slices[i].is_var()) { if (vs.slices[i].is_var()) {
return make_setitem(zeros, VarSlices(vs), dout, ns_add); return make_setitem(zeros, VarSlices(vs, true), dout, ns_add);
} }
return make_setitem(zeros, VarSlices(vs), dout, ns_void); return make_setitem(zeros, VarSlices(vs, true), dout, ns_void);
} }
void GetitemOp::jit_prepare(JK& jk) { void GetitemOp::jit_prepare(JK& jk) {

View File

@ -128,41 +128,41 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index == 0) { if (v_index == 0) {
float32 number = 0; float32 number = 0;
VarPtr zero = make_array(&number, 1, ns_float32); VarPtr zero = make_array(&number, 1, ns_float32);
return make_setitem(dout, VarSlices(vs), zero, ns_void); return make_setitem(dout, VarSlices(vs, true), zero, ns_void);
} else { } else {
return make_getitem(dout, VarSlices(vs)); return make_getitem(dout, VarSlices(vs, true));
} }
} }
if (op == ns_add) { if (op == ns_add) {
if (v_index == 0) { if (v_index == 0) {
return dout; return dout;
} else { } else {
return make_getitem(dout, VarSlices(vs)); return make_getitem(dout, VarSlices(vs, true));
} }
} }
if (op == ns_subtract) { if (op == ns_subtract) {
if (v_index == 0) { if (v_index == 0) {
return dout; return dout;
} else { } else {
return make_unary(make_getitem(dout, VarSlices(vs)), ns_negative); return make_unary(make_getitem(dout, VarSlices(vs, true)), ns_negative);
} }
} }
if (op == ns_multiply) { if (op == ns_multiply) {
if (v_index == 0) { if (v_index == 0) {
return make_setitem(dout, VarSlices(vs), input(1), ns_multiply); return make_setitem(dout, VarSlices(vs, true), input(1), ns_multiply);
} else { } else {
return make_binary( return make_binary(
make_getitem(inputs().front(), VarSlices(vs)), make_getitem(inputs().front(), VarSlices(vs, true)),
make_getitem(dout, VarSlices(vs)), ns_multiply); make_getitem(dout, VarSlices(vs, true)), ns_multiply);
} }
} }
if (op == ns_divide) { if (op == ns_divide) {
if (v_index == 0) { if (v_index == 0) {
return make_setitem(dout, VarSlices(vs), input(1), ns_divide); return make_setitem(dout, VarSlices(vs, true), input(1), ns_divide);
} else { } else {
// dy = -dz*x / y^2 // dy = -dz*x / y^2
auto dout2 = make_getitem(dout, VarSlices(vs)); auto dout2 = make_getitem(dout, VarSlices(vs, true));
auto x = make_getitem(inputs().front(), VarSlices(vs)); auto x = make_getitem(inputs().front(), VarSlices(vs, true));
auto y = v; auto y = v;
auto ndz = make_unary(dout2, ns_negative); auto ndz = make_unary(dout2, ns_negative);
auto ndzx = make_binary(ndz, x, ns_multiply); auto ndzx = make_binary(ndz, x, ns_multiply);

View File

@ -47,9 +47,16 @@ struct VarSlices {
inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) { inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) {
other.slices = nullptr; other.slices = nullptr;
} }
inline VarSlices(const VarSlices& other) : slices(new VarSlice[other.n]), n(other.n) { inline VarSlices(const VarSlices& other, bool negtive_set_none=false) : slices(new VarSlice[other.n]), n(other.n) {
for (int i=0; i<n; i++) for (int i=0; i<n; i++) {
slices[i] = other.slices[i]; slices[i] = other.slices[i];
if (negtive_set_none &&
slices[i].is_slice() &&
slices[i].slice.step < 0 &&
slices[i].slice.stop < 0) {
slices[i].slice.mask |= 2;
}
}
} }
inline void operator=(VarSlices&& other) { inline void operator=(VarSlices&& other) {
if (slices) delete[] slices; if (slices) delete[] slices;

View File

@ -327,6 +327,19 @@ class TestRNN(unittest.TestCase):
np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06) np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06)
def test_twobilinear_lstm(self):
x = jt.rand(5, 4, 10)
rnn1 = nn.LSTM(10, 20, bidirectional=True)
out1, _ = rnn1(x)
rnn2 = nn.LSTM(40, 20, bidirectional=True)
out2, _ = rnn2(out1)
target = jt.zeros_like(out2)
loss = nn.mse_loss(out2, target)
from jittor import optim
optimizer = optim.RMSprop(rnn1.parameters())
optimizer.step(loss)
@skipIf(not jt.has_cuda, "No Cuda found") @skipIf(not jt.has_cuda, "No Cuda found")
@jt.flag_scope(use_cuda=1) @jt.flag_scope(use_cuda=1)
def test_cudnn_rnn(self): def test_cudnn_rnn(self):

View File

@ -225,8 +225,14 @@ class TestSetitem(unittest.TestCase):
b = a[...,:,None,:2] b = a[...,:,None,:2]
assert b.shape == [2,4,1,2] assert b.shape == [2,4,1,2]
np.testing.assert_allclose(b.data, a.data[...,:,None,:2]) np.testing.assert_allclose(b.data, a.data[...,:,None,:2])
def test_flip_grad(self):
a = jt.rand(10)
b = a[::-1]
c = b[::-1]
d = c.sum()
jt.grad(d, [a])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()