mirror of https://github.com/Jittor/Jittor
fix: wrong grad of varslice like [::-1]
This commit is contained in:
parent
4bd0c4a2f6
commit
38694a1b6e
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.18'
|
||||
__version__ = '1.3.1.19'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -391,9 +391,9 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
// need analysis the overlap attr os var slices
|
||||
for (int i=0; i<vs.n; i++)
|
||||
if (vs.slices[i].is_var()) {
|
||||
return make_setitem(zeros, VarSlices(vs), dout, ns_add);
|
||||
return make_setitem(zeros, VarSlices(vs, true), dout, ns_add);
|
||||
}
|
||||
return make_setitem(zeros, VarSlices(vs), dout, ns_void);
|
||||
return make_setitem(zeros, VarSlices(vs, true), dout, ns_void);
|
||||
}
|
||||
|
||||
void GetitemOp::jit_prepare(JK& jk) {
|
||||
|
|
|
@ -128,41 +128,41 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
if (v_index == 0) {
|
||||
float32 number = 0;
|
||||
VarPtr zero = make_array(&number, 1, ns_float32);
|
||||
return make_setitem(dout, VarSlices(vs), zero, ns_void);
|
||||
return make_setitem(dout, VarSlices(vs, true), zero, ns_void);
|
||||
} else {
|
||||
return make_getitem(dout, VarSlices(vs));
|
||||
return make_getitem(dout, VarSlices(vs, true));
|
||||
}
|
||||
}
|
||||
if (op == ns_add) {
|
||||
if (v_index == 0) {
|
||||
return dout;
|
||||
} else {
|
||||
return make_getitem(dout, VarSlices(vs));
|
||||
return make_getitem(dout, VarSlices(vs, true));
|
||||
}
|
||||
}
|
||||
if (op == ns_subtract) {
|
||||
if (v_index == 0) {
|
||||
return dout;
|
||||
} else {
|
||||
return make_unary(make_getitem(dout, VarSlices(vs)), ns_negative);
|
||||
return make_unary(make_getitem(dout, VarSlices(vs, true)), ns_negative);
|
||||
}
|
||||
}
|
||||
if (op == ns_multiply) {
|
||||
if (v_index == 0) {
|
||||
return make_setitem(dout, VarSlices(vs), input(1), ns_multiply);
|
||||
return make_setitem(dout, VarSlices(vs, true), input(1), ns_multiply);
|
||||
} else {
|
||||
return make_binary(
|
||||
make_getitem(inputs().front(), VarSlices(vs)),
|
||||
make_getitem(dout, VarSlices(vs)), ns_multiply);
|
||||
make_getitem(inputs().front(), VarSlices(vs, true)),
|
||||
make_getitem(dout, VarSlices(vs, true)), ns_multiply);
|
||||
}
|
||||
}
|
||||
if (op == ns_divide) {
|
||||
if (v_index == 0) {
|
||||
return make_setitem(dout, VarSlices(vs), input(1), ns_divide);
|
||||
return make_setitem(dout, VarSlices(vs, true), input(1), ns_divide);
|
||||
} else {
|
||||
// dy = -dz*x / y^2
|
||||
auto dout2 = make_getitem(dout, VarSlices(vs));
|
||||
auto x = make_getitem(inputs().front(), VarSlices(vs));
|
||||
auto dout2 = make_getitem(dout, VarSlices(vs, true));
|
||||
auto x = make_getitem(inputs().front(), VarSlices(vs, true));
|
||||
auto y = v;
|
||||
auto ndz = make_unary(dout2, ns_negative);
|
||||
auto ndzx = make_binary(ndz, x, ns_multiply);
|
||||
|
|
|
@ -47,9 +47,16 @@ struct VarSlices {
|
|||
inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) {
|
||||
other.slices = nullptr;
|
||||
}
|
||||
inline VarSlices(const VarSlices& other) : slices(new VarSlice[other.n]), n(other.n) {
|
||||
for (int i=0; i<n; i++)
|
||||
inline VarSlices(const VarSlices& other, bool negtive_set_none=false) : slices(new VarSlice[other.n]), n(other.n) {
|
||||
for (int i=0; i<n; i++) {
|
||||
slices[i] = other.slices[i];
|
||||
if (negtive_set_none &&
|
||||
slices[i].is_slice() &&
|
||||
slices[i].slice.step < 0 &&
|
||||
slices[i].slice.stop < 0) {
|
||||
slices[i].slice.mask |= 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void operator=(VarSlices&& other) {
|
||||
if (slices) delete[] slices;
|
||||
|
|
|
@ -327,6 +327,19 @@ class TestRNN(unittest.TestCase):
|
|||
np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06)
|
||||
|
||||
def test_twobilinear_lstm(self):
|
||||
x = jt.rand(5, 4, 10)
|
||||
rnn1 = nn.LSTM(10, 20, bidirectional=True)
|
||||
out1, _ = rnn1(x)
|
||||
rnn2 = nn.LSTM(40, 20, bidirectional=True)
|
||||
out2, _ = rnn2(out1)
|
||||
target = jt.zeros_like(out2)
|
||||
loss = nn.mse_loss(out2, target)
|
||||
|
||||
from jittor import optim
|
||||
optimizer = optim.RMSprop(rnn1.parameters())
|
||||
optimizer.step(loss)
|
||||
|
||||
@skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cudnn_rnn(self):
|
||||
|
|
|
@ -226,6 +226,12 @@ class TestSetitem(unittest.TestCase):
|
|||
assert b.shape == [2,4,1,2]
|
||||
np.testing.assert_allclose(b.data, a.data[...,:,None,:2])
|
||||
|
||||
def test_flip_grad(self):
|
||||
a = jt.rand(10)
|
||||
b = a[::-1]
|
||||
c = b[::-1]
|
||||
d = c.sum()
|
||||
jt.grad(d, [a])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue