Polish SetitemOp gradient: guard against nullptr dout

This commit is contained in:
Dun Liang 2022-05-19 11:34:50 +08:00
parent ba266fa99c
commit 643ca5bbb4
5 changed files with 34 additions and 27 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.4.6'
__version__ = '1.3.4.7'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -622,8 +622,11 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
check_nan(var);
#endif
#ifdef JT_SYNC
#ifdef HAS_CUDA
checkCudaErrors(cudaGetLastError());
checkCudaErrors(cudaDeviceSynchronize());
#endif
#endif
LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> queue.size() >> ") output:" << op->outputs();
if (is_fused_op) {

View File

@ -419,12 +419,33 @@ unordered_set<string> binary_ops = {
};
BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
auto xdim = x->shape.size();
auto ydim = y->shape.size();
bool need_broadcast = xdim != ydim;
for (size_t i=0; i<xdim && i<ydim; i++) {
auto xshape = x->shape[xdim-i-1];
auto yshape = y->shape[ydim-i-1];
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
need_broadcast = true;
continue;
}
CHECKop(xshape,==,yshape) << "Shape not match, x:" >> x->to_string()
<< " y:" >> y->to_string();
}
if (need_broadcast) {
auto xp = make_broadcast_to(x, y, {});
auto yp = make_broadcast_to(y, x, {});
auto zp = make_binary(xp, yp, op);
forward(zp);
return;
}
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::element);
ns = op;
ASSERT(ns.is_binary());
z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns));
z = create_output(x->shape, binary_dtype_infer(op, x->ns, y->ns));
bool bin = ns.get(NanoString::_no_need_back_in);
bool bout = ns.get(NanoString::_no_need_back_out);
if (bin || bout) {
@ -516,31 +537,6 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
}
// Infer the output shape of this binary op, applying numpy-style
// right-aligned broadcasting when x and y have different shapes.
void BinaryOp::infer_shape() {
auto xdim = x->shape.size();
auto ydim = y->shape.size();
// Differing ranks always require broadcasting.
bool need_broadcast = xdim != ydim;
// Compare trailing dimensions pairwise (right-aligned).
for (size_t i=0; i<xdim && i<ydim; i++) {
auto xshape = x->shape[xdim-i-1];
auto yshape = y->shape[ydim-i-1];
// -1 1 need b
// has 1, b, both 1, not b, 0, error
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
// A size-1 dimension broadcasts against the other side.
// CHECK(xshape && yshape) << "Shape can not broadcast to 0.";
need_broadcast = true;
continue;
}
// Negative sizes are not yet known here; skip the equality check
// for them and only validate fully-known dimensions.
if (xshape<0 || yshape<0 ) continue;
CHECKop(xshape,==,yshape) << "Shape not match, x:" >> x->to_string()
<< " y:" >> y->to_string();
}
if (need_broadcast) {
// Rewrite both inputs as explicit broadcast ops, then re-run
// inference on the broadcasted inputs.
auto xp = make_broadcast_to(x, y, {});
auto yp = make_broadcast_to(y, x, {});
set_inputs({x=xp, y=yp});
// infer shape again
infer_shape();
} else
z->set_shape(x->shape);
}
void BinaryOp::jit_prepare(JK& jk) {

View File

@ -132,6 +132,7 @@ void SetitemOp::infer_shape() {
}
void SetitemOp::grads(Var** dout, VarPtr* dins) {
if (!dout[0]) return;
auto outs = make_getitem2(dout[0], VarSlices(vs, true), 0);
dins[0] = move(outs[1]);
dins[1] = move(outs[0]);

View File

@ -244,6 +244,13 @@ class TestCore(unittest.TestCase):
for i in range(10):
assert orders[i] <= 14+i*3
def test_bc_bug(self):
    # Regression test: multiplying a (1,1) variable by a scalar and
    # then taking its gradient must run without error.
    x = jt.zeros((1,1))
    y = x * 0.5
    y.sync()
    dx = jt.grad(y, x)
    dx.sync()
if __name__ == "__main__":
unittest.main()