mirror of https://github.com/Jittor/Jittor
fix: wrong grad of varslice like [::-1]
This commit is contained in:
parent
4bd0c4a2f6
commit
38694a1b6e
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.3.1.18'
|
__version__ = '1.3.1.19'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -391,9 +391,9 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
// need analysis the overlap attr os var slices
|
// need analysis the overlap attr os var slices
|
||||||
for (int i=0; i<vs.n; i++)
|
for (int i=0; i<vs.n; i++)
|
||||||
if (vs.slices[i].is_var()) {
|
if (vs.slices[i].is_var()) {
|
||||||
return make_setitem(zeros, VarSlices(vs), dout, ns_add);
|
return make_setitem(zeros, VarSlices(vs, true), dout, ns_add);
|
||||||
}
|
}
|
||||||
return make_setitem(zeros, VarSlices(vs), dout, ns_void);
|
return make_setitem(zeros, VarSlices(vs, true), dout, ns_void);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetitemOp::jit_prepare(JK& jk) {
|
void GetitemOp::jit_prepare(JK& jk) {
|
||||||
|
|
|
@ -128,41 +128,41 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
if (v_index == 0) {
|
if (v_index == 0) {
|
||||||
float32 number = 0;
|
float32 number = 0;
|
||||||
VarPtr zero = make_array(&number, 1, ns_float32);
|
VarPtr zero = make_array(&number, 1, ns_float32);
|
||||||
return make_setitem(dout, VarSlices(vs), zero, ns_void);
|
return make_setitem(dout, VarSlices(vs, true), zero, ns_void);
|
||||||
} else {
|
} else {
|
||||||
return make_getitem(dout, VarSlices(vs));
|
return make_getitem(dout, VarSlices(vs, true));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (op == ns_add) {
|
if (op == ns_add) {
|
||||||
if (v_index == 0) {
|
if (v_index == 0) {
|
||||||
return dout;
|
return dout;
|
||||||
} else {
|
} else {
|
||||||
return make_getitem(dout, VarSlices(vs));
|
return make_getitem(dout, VarSlices(vs, true));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (op == ns_subtract) {
|
if (op == ns_subtract) {
|
||||||
if (v_index == 0) {
|
if (v_index == 0) {
|
||||||
return dout;
|
return dout;
|
||||||
} else {
|
} else {
|
||||||
return make_unary(make_getitem(dout, VarSlices(vs)), ns_negative);
|
return make_unary(make_getitem(dout, VarSlices(vs, true)), ns_negative);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (op == ns_multiply) {
|
if (op == ns_multiply) {
|
||||||
if (v_index == 0) {
|
if (v_index == 0) {
|
||||||
return make_setitem(dout, VarSlices(vs), input(1), ns_multiply);
|
return make_setitem(dout, VarSlices(vs, true), input(1), ns_multiply);
|
||||||
} else {
|
} else {
|
||||||
return make_binary(
|
return make_binary(
|
||||||
make_getitem(inputs().front(), VarSlices(vs)),
|
make_getitem(inputs().front(), VarSlices(vs, true)),
|
||||||
make_getitem(dout, VarSlices(vs)), ns_multiply);
|
make_getitem(dout, VarSlices(vs, true)), ns_multiply);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (op == ns_divide) {
|
if (op == ns_divide) {
|
||||||
if (v_index == 0) {
|
if (v_index == 0) {
|
||||||
return make_setitem(dout, VarSlices(vs), input(1), ns_divide);
|
return make_setitem(dout, VarSlices(vs, true), input(1), ns_divide);
|
||||||
} else {
|
} else {
|
||||||
// dy = -dz*x / y^2
|
// dy = -dz*x / y^2
|
||||||
auto dout2 = make_getitem(dout, VarSlices(vs));
|
auto dout2 = make_getitem(dout, VarSlices(vs, true));
|
||||||
auto x = make_getitem(inputs().front(), VarSlices(vs));
|
auto x = make_getitem(inputs().front(), VarSlices(vs, true));
|
||||||
auto y = v;
|
auto y = v;
|
||||||
auto ndz = make_unary(dout2, ns_negative);
|
auto ndz = make_unary(dout2, ns_negative);
|
||||||
auto ndzx = make_binary(ndz, x, ns_multiply);
|
auto ndzx = make_binary(ndz, x, ns_multiply);
|
||||||
|
|
|
@ -47,9 +47,16 @@ struct VarSlices {
|
||||||
inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) {
|
inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) {
|
||||||
other.slices = nullptr;
|
other.slices = nullptr;
|
||||||
}
|
}
|
||||||
inline VarSlices(const VarSlices& other) : slices(new VarSlice[other.n]), n(other.n) {
|
inline VarSlices(const VarSlices& other, bool negtive_set_none=false) : slices(new VarSlice[other.n]), n(other.n) {
|
||||||
for (int i=0; i<n; i++)
|
for (int i=0; i<n; i++) {
|
||||||
slices[i] = other.slices[i];
|
slices[i] = other.slices[i];
|
||||||
|
if (negtive_set_none &&
|
||||||
|
slices[i].is_slice() &&
|
||||||
|
slices[i].slice.step < 0 &&
|
||||||
|
slices[i].slice.stop < 0) {
|
||||||
|
slices[i].slice.mask |= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
inline void operator=(VarSlices&& other) {
|
inline void operator=(VarSlices&& other) {
|
||||||
if (slices) delete[] slices;
|
if (slices) delete[] slices;
|
||||||
|
|
|
@ -327,6 +327,19 @@ class TestRNN(unittest.TestCase):
|
||||||
np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
||||||
np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06)
|
np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06)
|
||||||
|
|
||||||
|
def test_twobilinear_lstm(self):
|
||||||
|
x = jt.rand(5, 4, 10)
|
||||||
|
rnn1 = nn.LSTM(10, 20, bidirectional=True)
|
||||||
|
out1, _ = rnn1(x)
|
||||||
|
rnn2 = nn.LSTM(40, 20, bidirectional=True)
|
||||||
|
out2, _ = rnn2(out1)
|
||||||
|
target = jt.zeros_like(out2)
|
||||||
|
loss = nn.mse_loss(out2, target)
|
||||||
|
|
||||||
|
from jittor import optim
|
||||||
|
optimizer = optim.RMSprop(rnn1.parameters())
|
||||||
|
optimizer.step(loss)
|
||||||
|
|
||||||
@skipIf(not jt.has_cuda, "No Cuda found")
|
@skipIf(not jt.has_cuda, "No Cuda found")
|
||||||
@jt.flag_scope(use_cuda=1)
|
@jt.flag_scope(use_cuda=1)
|
||||||
def test_cudnn_rnn(self):
|
def test_cudnn_rnn(self):
|
||||||
|
|
|
@ -225,8 +225,14 @@ class TestSetitem(unittest.TestCase):
|
||||||
b = a[...,:,None,:2]
|
b = a[...,:,None,:2]
|
||||||
assert b.shape == [2,4,1,2]
|
assert b.shape == [2,4,1,2]
|
||||||
np.testing.assert_allclose(b.data, a.data[...,:,None,:2])
|
np.testing.assert_allclose(b.data, a.data[...,:,None,:2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_flip_grad(self):
|
||||||
|
a = jt.rand(10)
|
||||||
|
b = a[::-1]
|
||||||
|
c = b[::-1]
|
||||||
|
d = c.sum()
|
||||||
|
jt.grad(d, [a])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
Loading…
Reference in New Issue