mirror of https://github.com/Jittor/Jittor
polish setitem grad nullptr
This commit is contained in:
parent
ba266fa99c
commit
643ca5bbb4
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.4.6'
|
||||
__version__ = '1.3.4.7'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -622,8 +622,11 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
|
|||
check_nan(var);
|
||||
#endif
|
||||
#ifdef JT_SYNC
|
||||
#ifdef HAS_CUDA
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
#endif
|
||||
#endif
|
||||
LOGvvv << "Finished Op(" >> op->name() << rid >>
|
||||
"/" >> queue.size() >> ") output:" << op->outputs();
|
||||
if (is_fused_op) {
|
||||
|
|
|
@ -419,12 +419,33 @@ unordered_set<string> binary_ops = {
|
|||
};
|
||||
|
||||
BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
|
||||
auto xdim = x->shape.size();
|
||||
auto ydim = y->shape.size();
|
||||
bool need_broadcast = xdim != ydim;
|
||||
for (size_t i=0; i<xdim && i<ydim; i++) {
|
||||
auto xshape = x->shape[xdim-i-1];
|
||||
auto yshape = y->shape[ydim-i-1];
|
||||
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
|
||||
need_broadcast = true;
|
||||
continue;
|
||||
}
|
||||
CHECKop(xshape,==,yshape) << "Shape not match, x:" >> x->to_string()
|
||||
<< " y:" >> y->to_string();
|
||||
}
|
||||
if (need_broadcast) {
|
||||
auto xp = make_broadcast_to(x, y, {});
|
||||
auto yp = make_broadcast_to(y, x, {});
|
||||
auto zp = make_binary(xp, yp, op);
|
||||
forward(zp);
|
||||
return;
|
||||
}
|
||||
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
set_type(OpType::element);
|
||||
ns = op;
|
||||
ASSERT(ns.is_binary());
|
||||
z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns));
|
||||
z = create_output(x->shape, binary_dtype_infer(op, x->ns, y->ns));
|
||||
bool bin = ns.get(NanoString::_no_need_back_in);
|
||||
bool bout = ns.get(NanoString::_no_need_back_out);
|
||||
if (bin || bout) {
|
||||
|
@ -516,31 +537,6 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void BinaryOp::infer_shape() {
|
||||
auto xdim = x->shape.size();
|
||||
auto ydim = y->shape.size();
|
||||
bool need_broadcast = xdim != ydim;
|
||||
for (size_t i=0; i<xdim && i<ydim; i++) {
|
||||
auto xshape = x->shape[xdim-i-1];
|
||||
auto yshape = y->shape[ydim-i-1];
|
||||
// -1 1 need b
|
||||
// has 1, b, both 1, not b, 0, error
|
||||
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
|
||||
// CHECK(xshape && yshape) << "Shape can not broadcast to 0.";
|
||||
need_broadcast = true;
|
||||
continue;
|
||||
}
|
||||
if (xshape<0 || yshape<0 ) continue;
|
||||
CHECKop(xshape,==,yshape) << "Shape not match, x:" >> x->to_string()
|
||||
<< " y:" >> y->to_string();
|
||||
}
|
||||
if (need_broadcast) {
|
||||
auto xp = make_broadcast_to(x, y, {});
|
||||
auto yp = make_broadcast_to(y, x, {});
|
||||
set_inputs({x=xp, y=yp});
|
||||
// infer shape again
|
||||
infer_shape();
|
||||
} else
|
||||
z->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void BinaryOp::jit_prepare(JK& jk) {
|
||||
|
|
|
@ -132,6 +132,7 @@ void SetitemOp::infer_shape() {
|
|||
}
|
||||
|
||||
void SetitemOp::grads(Var** dout, VarPtr* dins) {
|
||||
if (!dout[0]) return;
|
||||
auto outs = make_getitem2(dout[0], VarSlices(vs, true), 0);
|
||||
dins[0] = move(outs[1]);
|
||||
dins[1] = move(outs[0]);
|
||||
|
|
|
@ -244,6 +244,13 @@ class TestCore(unittest.TestCase):
|
|||
for i in range(10):
|
||||
assert orders[i] <= 14+i*3
|
||||
|
||||
def test_bc_bug(self):
|
||||
a = jt.zeros((1,1))
|
||||
b = a * 0.5
|
||||
b.sync()
|
||||
da = jt.grad(b, a)
|
||||
da.sync()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue