diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 23124737..a7e2c7fa 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.3.2.6' +__version__ = '1.3.3.0' from jittor_utils import lock with lock.lock_scope(): ori_int = int @@ -355,10 +355,10 @@ def array64(data, dtype=None): with jt.flag_scope(auto_convert_64_to_32=0): return array(data, dtype) -def grad(loss, targets): +def grad(loss, targets, retain_graph=True): if type(targets) == core.Var: - return core.grad(loss, [targets])[0] - return core.grad(loss, targets) + return core.grad(loss, [targets], retain_graph)[0] + return core.grad(loss, targets, retain_graph) def liveness_info(): return { diff --git a/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.cc b/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.cc index 1dd1fe8b..28bbeded 100644 --- a/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.cc +++ b/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.cc @@ -28,6 +28,8 @@ CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdim ASSERT(offsets->dtype()==ns_int32); y = create_output(nullptr, ns_int32); y_key = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); + y->flags.set(NodeFlags::_needed_by_backward); } VarPtr CubArgReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { diff --git a/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc b/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc index f1e2a0c4..09260855 100644 --- a/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc +++ b/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc @@ -27,6 +27,8 @@ CubArgsortOp::CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending, ASSERT(offsets->dtype()==ns_int32); y = create_output(nullptr, dtype); y_key = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); + y->flags.set(NodeFlags::_needed_by_backward); } VarPtr CubArgsortOp::grad(Var* out, Var* dout, Var* v, int v_index) { diff --git a/python/jittor/extern/cuda/cub/ops/cub_where_op.cc b/python/jittor/extern/cuda/cub/ops/cub_where_op.cc index d31144e2..ade98be8 100644 --- a/python/jittor/extern/cuda/cub/ops/cub_where_op.cc +++ b/python/jittor/extern/cuda/cub/ops/cub_where_op.cc @@ -26,7 +26,6 @@ namespace jittor { CubWhereOp::CubWhereOp(Var* cond, NanoString dtype) : cond(cond) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); - flags.set(NodeFlags::_vary_shape); auto ndim = cond->shape.size(); outs.reset(new Var*[ndim]); for (uint i=0; ishape.size(); - auto num = cond->num; - if (num>0) num = -num; + auto num = -cond->num; for (uint i=0; iset_shape({num}); } diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc index 75fc6c84..8c1eada1 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc @@ -35,6 +35,9 @@ CublasBatchedMatmulOp::CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool c = create_output(nullptr, a->dtype()); flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_manual_set_vnbb); + a->flags.set(NodeFlags::_needed_by_backward); + b->flags.set(NodeFlags::_needed_by_backward); } diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc 
b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc index 2bb3a3bb..9dde34e9 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc @@ -24,6 +24,9 @@ CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b) : a(a), b(b), trans_a(trans_a), trans_b(trans_b) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + a->flags.set(NodeFlags::_needed_by_backward); + b->flags.set(NodeFlags::_needed_by_backward); // TODO: support int8 * int8 ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; // TODO: support diffrent input type diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc index 9679069e..11564932 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc @@ -29,6 +29,9 @@ CudnnConv3dBackwardWOp::CudnnConv3dBackwardWOp(Var* x, Var* dy, int kd, int kh, xformat(move(xformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + x->flags.set(NodeFlags::_needed_by_backward); + dy->flags.set(NodeFlags::_needed_by_backward); dw = create_output(nullptr, dtype_infer(dy->ns, x->ns)); } diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc index 8a72b886..14a02ac7 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc @@ -29,6 +29,9 @@ CudnnConv3dBackwardXOp::CudnnConv3dBackwardXOp(Var* w, Var* dy, int depth, int h xformat(move(xformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + w->flags.set(NodeFlags::_needed_by_backward); + dy->flags.set(NodeFlags::_needed_by_backward); dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); } diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc index 1bc1a866..e20c90eb 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc @@ -26,6 +26,9 @@ CudnnConv3dOp::CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int strid xformat(move(xformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + x->flags.set(NodeFlags::_needed_by_backward); + w->flags.set(NodeFlags::_needed_by_backward); y = create_output(nullptr, dtype_infer(x->ns, w->ns)); } diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc index d63a84a7..4073a086 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc @@ -52,6 +52,9 @@ CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + x->flags.set(NodeFlags::_needed_by_backward); + dy->flags.set(NodeFlags::_needed_by_backward); dw = create_output(nullptr, dtype_infer(dy->ns, x->ns)); } diff 
--git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc index 77769d4c..d8fff026 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc @@ -53,6 +53,9 @@ CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int widt xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + w->flags.set(NodeFlags::_needed_by_backward); + dy->flags.set(NodeFlags::_needed_by_backward); dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); } diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc index c495fdb6..7db49e82 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc @@ -50,6 +50,9 @@ CudnnConvOp::CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + x->flags.set(NodeFlags::_needed_by_backward); + w->flags.set(NodeFlags::_needed_by_backward); y = create_output(nullptr, dtype_infer(x->ns, w->ns)); if (!this->yformat.size()) this->yformat = this->xformat; diff --git a/python/jittor/misc.py b/python/jittor/misc.py index c2a9235e..f93b0e2c 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -801,6 +801,10 @@ def print_tree(now, max_memory_size, prefix1, prefix2, build_by): tab = ' ' out += prefix1+now['name']+'('+now['type']+')\n' out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%; cnt:'+format_size(now['cnt'],'') + ']\n' + if len(now['children']) == 0 and len(now['vinfo']): + out += prefix2+now['vinfo'][0] + if len(now['vinfo']) > 1: out += "..." 
+ out += '\n' if (build_by == 0): for p in now['path']: out += prefix2+p+'\n' @@ -866,7 +870,8 @@ Output:: vars_ = vars_[1:] for v_ in vars_: v__ = v_.split(div2) - var = {'size':int(v__[1]), 'stack':[], 'cnt':1} + vinfo = v__[0].split("{")[0] + var = {'size':int(v__[1]), 'stack':[], 'cnt':1, "vinfo":vinfo} v__ = v__[2:-1] for s_ in v__: s__ = s_.split(div3) @@ -874,7 +879,7 @@ Output:: var['stack'].append(s) vars.append(var) if (build_by == 0): # build tree by name - tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':''} + tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':'', 'vinfo':[]} def find_child(now, key): for c in now['children']: @@ -885,6 +890,7 @@ Output:: now = tree now['size'] += v['size'] now['cnt'] += v['cnt'] + now['vinfo'].append(v['vinfo']) for s in v['stack']: ch = find_child(now, s['name']) if (ch is not None): @@ -894,12 +900,13 @@ Output:: now = ch now['size'] += v['size'] now['cnt'] += v['cnt'] + now['vinfo'].append(v['vinfo']) else: - now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type']} + now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type'], 'vinfo':[v['vinfo']]} now['children'].append(now_) now = now_ elif (build_by == 1): # build tree by path - tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':''} + tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':'', 'vinfo':[]} def find_child(now, key): for c in now['children']: @@ -910,14 +917,16 @@ Output:: now = tree now['size'] += v['size'] now['cnt'] += v['cnt'] + now['vinfo'].append(v['vinfo']) for s in v['stack']: ch = find_child(now, s['path']) if (ch is not None): now = ch now['size'] += v['size'] now['cnt'] += v['cnt'] + now['vinfo'].append(v['vinfo']) else: - now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type']} + now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type'], 'vinfo':[v['vinfo']]} now['children'].append(now_) now = now_ else: diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 3a921b29..58e0e46f 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -173,7 +173,8 @@ def relu(x): >>> nn.relu(a) jt.Var([0. 1.1338731 6.128115 ], dtype=float32) ''' - return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x)) + cond = x>0.0 + return jt.ternary_out_hint(cond, x, 0.0) def leaky_relu(x, scale=0.01): diff --git a/python/jittor/optim.py b/python/jittor/optim.py index 6af98f02..d7383f8b 100644 --- a/python/jittor/optim.py +++ b/python/jittor/optim.py @@ -98,7 +98,7 @@ class Optimizer(object): def zero_grad(self): self.__zero_grad = True - def pre_step(self, loss): + def pre_step(self, loss, retain_graph=False): """ something should be done before step, such as calc gradients, mpi sync, and so on. 
Example:: @@ -118,7 +118,7 @@ class Optimizer(object): params_has_grad.append(p) # get gradient - grads = jt.grad(loss, params_has_grad) + grads = jt.grad(loss, params_has_grad, retain_graph) # sync grads and model if in mpi if jt.in_mpi: @@ -153,7 +153,7 @@ class Optimizer(object): pid += 1 self.__zero_grad = False - def backward(self, loss): + def backward(self, loss, retain_graph=False): ''' optimize.backward(loss) is used for accumulate multiple step, it can be used as following: @@ -186,11 +186,11 @@ class Optimizer(object): ''' - self.pre_step(loss) + self.pre_step(loss, retain_graph) - def step(self, loss=None): + def step(self, loss=None, retain_graph=False): if loss is not None: - self.pre_step(loss) + self.pre_step(loss, retain_graph) for pg in self.param_groups: lr = pg.get("lr", self.lr) for p, g in zip(pg["params"], pg["grads"]): diff --git a/python/jittor/src/core.h b/python/jittor/src/core.h index 4d50358c..b725a026 100644 --- a/python/jittor/src/core.h +++ b/python/jittor/src/core.h @@ -33,6 +33,6 @@ inline static void __print_trace() { } // @pyjt(grad) -vector _grad(VarHolder* loss, const vector& targets); +vector _grad(VarHolder* loss, const vector& targets, bool retain_graph=true); } // jittor diff --git a/python/jittor/src/executor.cc b/python/jittor/src/executor.cc index 5deacd53..da10acd9 100644 --- a/python/jittor/src/executor.cc +++ b/python/jittor/src/executor.cc @@ -30,6 +30,7 @@ #include "memory_profiler.h" #include "utils/seh.h" #include "utils/cache_compile.h" +#include "var_holder.h" namespace jittor { @@ -102,6 +103,23 @@ void load_fused_op(FusedOp& fused_op, vector& fuse_ops, vector& ops, i } } +static inline void propergate_needed_flags(FusedOp& fused_op) { + auto& ops = fused_op.ops; + for (int i=ops.size()-1; i>=0; i--) { + bool has_need = 0; + auto op = ops[i]; + for (auto o : op->outputs()) + if (o->flags.get(NodeFlags::_needed_by_backward) && + !(o->custom_data&1)) { + has_need = 1; + } + if (has_need) + for (auto i : op->inputs()) { + i->flags.set(NodeFlags::_needed_by_backward); + } + } +} + void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jittor::Log& logf) { vector stack; if (is_fused_op) { @@ -151,7 +169,30 @@ void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jit jittor::LogFatalVoidify() && logf; } -void Executor::run_sync(vector vars, bool device_sync) { +static void top_weak_sync(vector& vars) { + auto t = ++Node::tflag_count; + int64 max_id=0; + for (auto v : vars) { + max_id = std::max(v->id, max_id); + v->tflag = t; + } + while (true) { + if (sync_ptr == hold_vars.begin()) + break; + auto next_ptr = std::prev(sync_ptr); + auto v = (*next_ptr)->var; + if (v->id > max_id) break; + sync_ptr = next_ptr; + if (v->tflag == t) continue; + if (v->_outputs.size()) continue; + if (v->is_finished()) continue; + vars.push_back(v); + } +} + +void Executor::run_sync(vector vars, bool device_sync, bool weak_sync) { + if (weak_sync) + top_weak_sync(vars); auto allocator = get_allocator(); auto temp_allocator = get_allocator(true); this->allocator = allocator; @@ -287,6 +328,10 @@ void Executor::run_sync(vector vars, bool device_sync) { // output: // queue: toplogical order of fused op { + // queue.clear(); + #ifndef JT_bfs_executor + map p_queue; + #endif for (int root : roots) { for (int i=root; i>=0; i=next[i]) { Op* op = ops[i]; @@ -299,24 +344,48 @@ void Executor::run_sync(vector vars, bool device_sync) { } } } + #ifdef JT_bfs_executor if (deps[root] == 0) queue.push_back(root); + #else + if 
(deps[root] == 0) + p_queue[ops[root]->id] = root; + #endif } - for (uint s=0; ssecond; + p_queue.erase(p_queue.begin()); + queue.push_back(op_id); + #endif for (int i=op_id; i>=0; i=next[i]) { Op* op = ops[i]; for (Var* v : op->outputs()) + { if (v->tflag == tt) - for (Op* op2 : v->outputs()) { + for (Op* op2 : v->outputs()) + { if (op2->tflag != tt) continue; int op2_id = father[op2->custom_data]; // continue if those two ops are fused if (op2_id == op_id) continue; deps[op2_id]--; + #ifdef JT_bfs_executor if (deps[op2_id] == 0) queue.push_back(op2_id); + #else + if (deps[op2_id] == 0) + p_queue[op2->id] = op2_id; + #endif } + } } } ASSERTop(queue.size(),==,roots.size()); @@ -479,9 +548,6 @@ void Executor::run_sync(vector vars, bool device_sync) { load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt); } LOGvvv << "Run" << op; - if (op->flags.get(NodeFlags::_has_vary_input)) op->init(); - ASSERT(!op->flags.get(NodeFlags::_has_vary_input)) - << "Shape of(" >> op->name() >> ") not solved."; for (auto* var : op->outputs()) { var->alloc(allocator); } @@ -557,6 +623,7 @@ void Executor::run_sync(vector vars, bool device_sync) { LOGvvv << "Finished Op(" >> op->name() << rid >> "/" >> queue.size() >> ") output:" << op->outputs(); if (is_fused_op) { + propergate_needed_flags(fused_op); for (Var* var : op->outputs()) var->finish_pending_liveness(); continue; @@ -596,7 +663,7 @@ void Executor::run_sync(vector vars, bool device_sync) { } } LOGvv << "All" << op_num << "ops finished, return vars:" << vars; - for (Var* v : vars) ASSERT(v->mem_ptr); + for (Var* v : vars) ASSERT(v->mem_ptr || !v->backward_liveness); // clean fetcher free buffer fetcher_to_free.clear(); #ifdef HAS_CUDA diff --git a/python/jittor/src/executor.h b/python/jittor/src/executor.h index dc21d096..fcabff4e 100644 --- a/python/jittor/src/executor.h +++ b/python/jittor/src/executor.h @@ -21,7 +21,7 @@ struct Executor { Allocator* allocator; Allocator* temp_allocator; bool last_is_cuda = false; - void run_sync(vector vars, bool device_sync); + void run_sync(vector vars, bool device_sync, bool weak_sync=true); inline Allocation alloc_temp(size_t size) { return Allocation(temp_allocator, size); diff --git a/python/jittor/src/fused_op.cc b/python/jittor/src/fused_op.cc index 924a1113..35105164 100644 --- a/python/jittor/src/fused_op.cc +++ b/python/jittor/src/fused_op.cc @@ -45,6 +45,7 @@ void FusedOp::update_ops() { _inputs.clear(); _outputs.clear(); + vars.clear(); for (Op* op : ops) { for (Var* o : op->outputs()) { if (o->loop_options) { @@ -93,10 +94,7 @@ void FusedOp::update_ops() { o->custom_data &= 1; } } - vars.clear(); - bool has_vary_input = 0; for (Op* opi : ops) { - has_vary_input |= opi->flags.get(NodeFlags::_has_vary_input); for (Var* i : opi->inputs()) { auto &c = i->custom_data; // if not visited @@ -116,7 +114,6 @@ void FusedOp::update_ops() { } } } - flags.set(NodeFlags::_has_vary_input, has_vary_input); LOGvvvv << "Var info" << vars; } @@ -144,12 +141,9 @@ FusedOp::~FusedOp() { } void FusedOp::infer_shape() { - bool has_vary_input = 0; for (Op* op : ops) { op->init(); - has_vary_input |= op->flags.get(NodeFlags::_has_vary_input); } - flags.set(NodeFlags::_has_vary_input, has_vary_input); } void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) { diff --git a/python/jittor/src/grad.cc b/python/jittor/src/grad.cc index 5e891714..f5809b1b 100644 --- a/python/jittor/src/grad.cc +++ b/python/jittor/src/grad.cc @@ -10,6 +10,7 @@ #include "op.h" #include "graph.h" #include "ops/op_register.h" +#include 
"var_holder.h" namespace jittor { @@ -76,7 +77,7 @@ void warn_grad_break(int i, Var* v) { LOGw << "grads[">>i>>"] '">> v->name>>"' doesn't have gradient. It will be set to zero:" << v; } -vector grad(Var* loss, vector targets) { +vector grad(Var* loss, vector targets, bool retain_graph) { LOGvv << "loss:" >> loss << "targets:" >> targets; CHECK(loss->is_float()) << "Loss should be float"; for (Var* var : targets) @@ -259,6 +260,20 @@ vector grad(Var* loss, vector targets) { assign_attrs(grad.ptr, var); } } + if (!retain_graph) { + auto t = ++Node::tflag_count; + for (auto& vh : hold_vars) + if (vh->var->tflag != t) { + vh->var->tflag = t; + } + SetupFreeBuffer setup_free_buffer; + for (int i=int(gvars.size())-1; i>=0; i--) + if (gvars[i]->tflag != t && gvars[i]->backward_liveness) + gvars[i]->set_stop_grad(); + for (int i=0; iset_stop_grad(); + } return results; } diff --git a/python/jittor/src/grad.h b/python/jittor/src/grad.h index 69f67158..2e0b5758 100644 --- a/python/jittor/src/grad.h +++ b/python/jittor/src/grad.h @@ -9,7 +9,7 @@ namespace jittor { -vector grad(Var* loss, vector targets); +vector grad(Var* loss, vector targets, bool retain_graph=true); // @pyjt(tape_together) void tape_together( diff --git a/python/jittor/src/graph.cc b/python/jittor/src/graph.cc index 27f52ada..7a6f4995 100644 --- a/python/jittor/src/graph.cc +++ b/python/jittor/src/graph.cc @@ -5,6 +5,7 @@ // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** #include +#include #include "graph.h" #include "var_holder.h" #include "var.h" @@ -32,7 +33,7 @@ void do_graph_check() { LOGvv << "Check hold_vars size" << queue.size(); int vhsize = queue.size(); for (auto* node : queue) { - ASSERTop(node->forward_liveness,>,0); + // ASSERTop(node->forward_liveness,>,0); ASSERTop(node->backward_liveness,>,0); } for (uint i=0; iinputs()) { if (i->is_stop_grad()) continue; @@ -62,7 +63,7 @@ void do_graph_check() { if (o->pending_liveness && !o->is_finished()) p++; } - if (f>0 && b>0 && !node->is_finished()) p++; + // if (f>0 && b>0 && !node->is_finished()) p++; if (f!=node->forward_liveness || b!=node->backward_liveness || p!=node->pending_liveness) { LOGf << "ERROR" << node << '\n' << f << b << p << i << '\n' @@ -91,6 +92,8 @@ DumpGraphs dump_all_graphs() { queue.push_back(vh->var); } bfs_both(queue, [](Node*){return true;}); + std::sort(queue.begin(), queue.end(), + [](Node* a, Node* b) { return a->id < b->id;}); DumpGraphs graphs; for (uint i=0; icustom_data = i; diff --git a/python/jittor/src/memory_profiler.cc b/python/jittor/src/memory_profiler.cc index 1741073a..aa500867 100644 --- a/python/jittor/src/memory_profiler.cc +++ b/python/jittor/src/memory_profiler.cc @@ -76,7 +76,7 @@ void MemoryProfiler::check() { allocations.clear(); size_t memory_size = 0; std::vector>, size_t>> live_vars; - vector queue; + vector queue, queue2; auto t = ++Node::tflag_count; for (auto& vh : hold_vars) @@ -85,6 +85,14 @@ void MemoryProfiler::check() { queue.push_back(vh->var); } bfs_both(queue, [](Node*){return true;}); + vector backup_custom_data; + backup_custom_data.resize(queue.size()); + for (int i=0; icustom_data; + toplogical_sort_forward(queue, queue2, [](Node*){}); + for (int i=0; icustom_data = backup_custom_data[i]; + queue.swap(queue2); for (Node* node : queue) { if (node->is_var()) { Var* var = (Var*)node; diff --git a/python/jittor/src/misc/nano_string.cc b/python/jittor/src/misc/nano_string.cc index 4a9c5773..0cd3595f 100644 --- 
a/python/jittor/src/misc/nano_string.cc +++ b/python/jittor/src/misc/nano_string.cc @@ -137,6 +137,24 @@ static unordered_set white_ops = { "pow", }; +static unordered_set no_need_back_in = { + "void", + "cast", + "negative", + "add", + "subtract", + "mean", +}; + +static unordered_set no_need_back_out = { + "void", + "cast", + "negative", + "add", + "subtract", + "multiply", +}; + #define DEFINE_NS(T) NanoString ns_##T; FOR_ALL_NS(DEFINE_NS); @@ -172,6 +190,8 @@ static void init_ns() { ns.set(NanoString::_float, float_ops.count(name)); } ns.set(NanoString::_white_list, white_ops.count(name)); + ns.set(NanoString::_no_need_back_in, no_need_back_in.count(name)); + ns.set(NanoString::_no_need_back_out, no_need_back_out.count(name)); __string_to_ns[name] = ns; auto name2 = ns.to_cstring(); int len=0; diff --git a/python/jittor/src/misc/nano_string.h b/python/jittor/src/misc/nano_string.h index 481c3d6c..55517c4d 100644 --- a/python/jittor/src/misc/nano_string.h +++ b/python/jittor/src/misc/nano_string.h @@ -98,7 +98,7 @@ EXTERN_LIB int __ns_len[]; // @pyjt(NanoString) struct NanoString { - typedef uint16 ns_t; + typedef uint32 ns_t; enum Flags { // bit0~7: index _index=0, _index_nbits=7, @@ -119,6 +119,9 @@ struct NanoString { _dsize=_n+6, _dsize_nbits=2, // bit8: white list _white_list=_n+8, + // bit9: backward opt + _no_need_back_in=_n+9, + _no_need_back_out=_n+10, }; ns_t data=0; diff --git a/python/jittor/src/node.h b/python/jittor/src/node.h index c976b0b8..21c58d5b 100644 --- a/python/jittor/src/node.h +++ b/python/jittor/src/node.h @@ -18,7 +18,7 @@ EXTERN_LIB int64 nt; EXTERN_LIB vector free_buffer; struct NodeFlags { - typedef uint16 nf_t; + typedef uint32 nf_t; nf_t flags=0; enum Flags { // bit0: is_var @@ -35,6 +35,8 @@ struct NodeFlags { _force_fuse=_n+0, _stop_fuse=_n+1, _in_update_queue=_n+2, + _needed_by_backward=_n+3, + _out_hint=_n+4, // op related flags // bit0: support cpu @@ -53,13 +55,14 @@ struct NodeFlags { _has_gopt=_n+7, // bit8: has vary input _has_vary_input=_n+8, + _manual_set_vnbb = _n+9, // bit9: prefer 32 bit - _prefer_32=_n+9, - // bit10: force 16 bit - _prefer_16=_n+10, - // bit11: reduce keep type unchange - _reduce_keep=_n+11, - _custom_flag=_reduce_keep, + _prefer_32=_n+10, + // force 16 bit + _prefer_16=_prefer_32+1, + // reduce keep type unchange + _reduce_keep=_prefer_32+2, + _custom_flag = _reduce_keep, }; inline void set(Flags f, int a=1, int nbits=1) { @@ -89,8 +92,8 @@ struct Node { }; struct output_t { Node* node; - int index; list::iterator back; + int index; output_t(Node* n, int i) : node(n), index(i) {} operator Node*() { return node; } operator Op*() { return (Op*)node; } @@ -120,14 +123,15 @@ struct Node { inline bool need_free() { return !pending_liveness && (!forward_liveness || !backward_liveness); } - int64_t tflag = 0; - int64_t custom_data; + int custom_data; + int64 tflag = 0; + int64 id; list _inputs; list _outputs; #ifdef NODE_MEMCHECK inline Node() { - lived_nodes[(void*)this] = ++total_node; + lived_nodes[(void*)this] = id = ++total_node; } inline virtual ~Node() { @@ -135,7 +139,7 @@ struct Node { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this); } #else - inline Node() {}; + inline Node() { id = ++total_node; }; inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);}; #endif inline Var* var() { return (Var*)this; } diff --git a/python/jittor/src/op.cc b/python/jittor/src/op.cc index 35a42f06..88c48f61 100644 --- a/python/jittor/src/op.cc +++ 
b/python/jittor/src/op.cc @@ -14,6 +14,8 @@ #include "mem/allocator.h" #include "misc/cuda_flags.h" #include "pybind/py_var_tracer.h" +#include "executor.h" +#include "var_holder.h" namespace jittor { @@ -65,14 +67,26 @@ Var* Op::create_output(NanoVector shape, NanoString dtype) { } void Op::init() { - bool has_vary_input = 0; - for (Var* v : inputs()) - if (v->num < 0) { - has_vary_input = 1; - break; - } - flags.set(NodeFlags::_has_vary_input, has_vary_input); infer_shape(); + bool manual_set_vnbb = flags.get(NodeFlags::_manual_set_vnbb) + || _inputs.size()==0 + || (_outputs.size()==1 && _outputs.front().node->is_stop_grad()); + for (Var* v : inputs()) { + if (!manual_set_vnbb) { + v->flags.set(NodeFlags::_needed_by_backward); + } + } + Var* need_sync = nullptr; + for (Var* v : outputs()) { + if (!manual_set_vnbb) + v->flags.set(NodeFlags::_needed_by_backward); + if (v->num < 0) + need_sync = v; + } + if (need_sync) { + exe.run_sync(vector({need_sync}), false); + CHECK(need_sync->num >= 0) << need_sync << "'s shape is error"; + } } void Op::compile_optimize(string& src) {} @@ -84,7 +98,7 @@ void Op::graph_optimize() {} string Op::name_ex() const { string a=name(); - if (ns!=ns_void) { + if (ns.data) { a += '.'; a += ns.to_cstring(); } @@ -266,7 +280,7 @@ void Op::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) { std::ostream& operator<<(std::ostream& os, const Op* op) { if (!op) return os << "Op(0)"; - os << "Op(" << (void*)op + os << "Op(" << op->id << ':' << op->forward_liveness << ':' << op->backward_liveness << ':' << op->pending_liveness diff --git a/python/jittor/src/ops/arg_reduce_op.cc b/python/jittor/src/ops/arg_reduce_op.cc index c672a837..d771baa7 100644 --- a/python/jittor/src/ops/arg_reduce_op.cc +++ b/python/jittor/src/ops/arg_reduce_op.cc @@ -46,7 +46,6 @@ ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims) get_op_info("cub_arg_reduce").get_constructor, Var*, Var*, NanoString, bool>() : nullptr; if (cub_arg_reduce) { - if (x->num<0) exe.run_sync(vector({x}), true); int dims = x->shape.size(); vector axes; axes.reserve(dims); @@ -88,6 +87,8 @@ ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims) #endif y = create_output(nullptr, ns_int32); y_key = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); + y->flags.set(NodeFlags::_needed_by_backward); } VarPtr ArgReduceOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) { // Do not have grad to extras input diff --git a/python/jittor/src/ops/argsort_op.cc b/python/jittor/src/ops/argsort_op.cc index 5bf3afc2..d2d098fc 100644 --- a/python/jittor/src/ops/argsort_op.cc +++ b/python/jittor/src/ops/argsort_op.cc @@ -45,7 +45,6 @@ ArgsortOp::ArgsortOp(Var* x, int dim, bool descending, NanoString dtype) .get_constructor, Var*, Var*, Var*, bool, NanoString>(); } if (cub_argsort) { - if (x->num<0) exe.run_sync(vector({x}), true); int dims = x->shape.size(); vector axes; axes.reserve(dims); @@ -81,6 +80,8 @@ ArgsortOp::ArgsortOp(Var* x, int dim, bool descending, NanoString dtype) #endif y = create_output(nullptr, dtype); y_key = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); + y->flags.set(NodeFlags::_needed_by_backward); } VarPtr ArgsortOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) { diff --git a/python/jittor/src/ops/binary_op.cc b/python/jittor/src/ops/binary_op.cc index 197d90ff..4ea7c10f 100644 --- a/python/jittor/src/ops/binary_op.cc +++ b/python/jittor/src/ops/binary_op.cc @@ -425,6 
+425,18 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) { ns = op; ASSERT(ns.is_binary()); z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns)); + bool bin = ns.get(NanoString::_no_need_back_in); + bool bout = ns.get(NanoString::_no_need_back_out); + if (bin || bout) { + flags.set(NodeFlags::_manual_set_vnbb); + if (!bin) { + x->flags.set(NodeFlags::_needed_by_backward); + y->flags.set(NodeFlags::_needed_by_backward); + } + if (!bout) { + z->flags.set(NodeFlags::_needed_by_backward); + } + } } VarPtr dirty_clone_broadcast(Var* v) { diff --git a/python/jittor/src/ops/broadcast_to_op.cc b/python/jittor/src/ops/broadcast_to_op.cc index e8f584bf..4a5453c0 100644 --- a/python/jittor/src/ops/broadcast_to_op.cc +++ b/python/jittor/src/ops/broadcast_to_op.cc @@ -29,6 +29,7 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) { } flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); set_type(OpType::broadcast); z = create_output(NanoVector(), x->dtype()); bcast_mask = 0; diff --git a/python/jittor/src/ops/candidate_op.cc b/python/jittor/src/ops/candidate_op.cc index f50955cb..769c5258 100644 --- a/python/jittor/src/ops/candidate_op.cc +++ b/python/jittor/src/ops/candidate_op.cc @@ -16,12 +16,12 @@ namespace jittor { CandidateOp::CandidateOp(Var* x, string&& fail_cond, NanoString dtype) : x(x), fail_cond(move(fail_cond)) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); - flags.set(NodeFlags::_vary_shape); + flags.set(NodeFlags::_manual_set_vnbb); y = create_output(nullptr, dtype); } void CandidateOp::infer_shape() { - y->set_shape({-std::abs(x->shape[0])}); + y->set_shape({-x->shape[0]}); } void CandidateOp::jit_prepare(JK& jk) { diff --git a/python/jittor/src/ops/clone_op.cc b/python/jittor/src/ops/clone_op.cc index 17bc4893..6f85f599 100644 --- a/python/jittor/src/ops/clone_op.cc +++ b/python/jittor/src/ops/clone_op.cc @@ -19,6 +19,7 @@ static auto make_clone = get_op_info("clone") CloneOp::CloneOp(Var* x) : x(x) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); y = create_output(nullptr, x->dtype()); if (x->name.ptr) y->name = x->name; diff --git a/python/jittor/src/ops/code_op.cc b/python/jittor/src/ops/code_op.cc index 2a6f5662..7db34f03 100644 --- a/python/jittor/src/ops/code_op.cc +++ b/python/jittor/src/ops/code_op.cc @@ -37,7 +37,6 @@ CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector&& inputs, _outputs.push_back(create_output(shape, dtype)); if (_outputs[0]->num < 0) { - flags.set(NodeFlags::_vary_shape); check_vary_shape(_outputs[0]->shape); } } @@ -58,7 +57,6 @@ CodeOp::CodeOp( for (int i=0; inum < 0) { - flags.set(NodeFlags::_vary_shape); check_vary_shape(_outputs[i]->shape); } } diff --git a/python/jittor/src/ops/copy_op.cc b/python/jittor/src/ops/copy_op.cc index 5d62e1c8..771cb928 100644 --- a/python/jittor/src/ops/copy_op.cc +++ b/python/jittor/src/ops/copy_op.cc @@ -20,6 +20,7 @@ namespace jittor { CopyOp::CopyOp(Var* x) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); auto y = create_output(nullptr, x->dtype()); if (x->name.ptr) y->name = x->name; diff --git a/python/jittor/src/ops/fuse_transpose_op.cc b/python/jittor/src/ops/fuse_transpose_op.cc index aec2b2ef..c1facc62 100644 --- a/python/jittor/src/ops/fuse_transpose_op.cc +++ b/python/jittor/src/ops/fuse_transpose_op.cc @@ -33,6 +33,7 @@ FuseTransposeOp::FuseTransposeOp(Var* x, NanoVector axes_) : x(x), 
axes(axes_) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(tp); + flags.set(NodeFlags::_manual_set_vnbb); int i=0; for (; iflags.set(NodeFlags::_needed_by_backward); create_output(nullptr, x->dtype()); } @@ -48,6 +52,10 @@ GetitemOp::GetitemOp(Var* x, VarSlices&& slices, int _) flags.set(NodeFlags::_has_gopt); flags.set(NodeFlags::_custom_flag); flags.set(NodeFlags::_grads); + flags.set(NodeFlags::_manual_set_vnbb); + for (int i=0; iflags.set(NodeFlags::_needed_by_backward); create_output(nullptr, x->dtype()); auto out2 = create_output(nullptr, x->dtype()); out2->share_with(x); @@ -421,7 +429,6 @@ void GetitemOp::grads(Var** dout, VarPtr* dins) { VarPtr y = dout[0]; if (!x) { auto in = inputs().front(); - if (in->num<0) exe.run_sync(vector({in}), true); // ns.data represents this is the last split var if (ns.data) x = make_empty(in->shape, in->dtype()); diff --git a/python/jittor/src/ops/index_op.cc b/python/jittor/src/ops/index_op.cc index 5e99f96e..2a3f1c00 100644 --- a/python/jittor/src/ops/index_op.cc +++ b/python/jittor/src/ops/index_op.cc @@ -31,6 +31,7 @@ IndexOp::IndexOp(Var* a, int64 dim, NanoString dtype) : dim(dim) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(OpType::element); + flags.set(NodeFlags::_manual_set_vnbb); x.reset(new Var*[1]); x[0] = create_output(nullptr, dtype); } @@ -38,6 +39,7 @@ IndexOp::IndexOp(Var* a, int64 dim, NanoString dtype) : dim(dim) { IndexOp::IndexOp(Var* a, NanoString dtype) : dim(a->shape.size()) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); set_type(OpType::element); x.reset(new Var*[dim]); for (int i=0; ishape.size(); @@ -279,6 +281,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask) flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(OpType::reduce); + if (op.get(NanoString::_no_need_back_in)) + flags.set(NodeFlags::_manual_set_vnbb); ns = op; ASSERT(ns.is_binary()); reduce_mask = dims_mask; @@ -319,12 +323,6 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { return make_binary(b, v, ns_divide); } if (ns == ns_mean) { - if (v->num < 0) { - // TODO: Dynamic shape of mean grad was not supported yet - LOGw << "Dynamic shape of mean grad cause synchronize."; - exe.run_sync({v}, 0); - ASSERT(v->num>=0); - } VarPtr a = make_broadcast_to(dout, v, reduce_mask, keepdims_mask); VarPtr n = make_number(1.0f*out->num / v->num, a); return make_binary(a, n, ns_multiply); diff --git a/python/jittor/src/ops/reindex_op.cc b/python/jittor/src/ops/reindex_op.cc index be38b97d..f3cd242f 100644 --- a/python/jittor/src/ops/reindex_op.cc +++ b/python/jittor/src/ops/reindex_op.cc @@ -27,6 +27,7 @@ ReindexOp::ReindexOp(Var* x, NanoVector shape, vector&& indexes, float64 flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(OpType::broadcast); + flags.set(NodeFlags::_manual_set_vnbb); y = create_output(nullptr, x->dtype()); } diff --git a/python/jittor/src/ops/reindex_reduce_op.cc b/python/jittor/src/ops/reindex_reduce_op.cc index 93b958e6..77b84477 100644 --- a/python/jittor/src/ops/reindex_reduce_op.cc +++ b/python/jittor/src/ops/reindex_reduce_op.cc @@ -28,6 +28,8 @@ ReindexReduceOp::ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(OpType::reduce); + if (op.get(NanoString::_no_need_back_in)) + flags.set(NodeFlags::_manual_set_vnbb); ns = op; ASSERT(ns.is_binary() && ns!=ns_mean); x = create_output(nullptr, y->dtype()); 
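Usage sketch for the _needed_by_backward / _manual_set_vnbb flags introduced in the hunks above (not part of the patch): ops now tag only the variables that the backward pass actually reads, so untagged buffers can be released right after the forward pass. The snippet mirrors the test_copy_memopt / test_fuse_memopt tests added later in this diff; the ",0)" and ":n1" substrings are assumptions based on the Var printer updated in var.cc further down, which prints a ':n' needed-by-backward bit and the mem_ptr just before the closing parenthesis.

import jittor as jt

a = jt.rand(10).name("a")
b = (a.copy().name("keep_me") + 1).sqr().name("b")
b.sync()

for info in jt.dump_all_graphs().nodes_info:
    if "Var" not in info:
        continue
    released = ",0)" in info   # mem_ptr printed as 0: buffer already freed
    needed   = ":n1" in info   # assumed _needed_by_backward bit in the new repr
    print(released, needed, info)
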
diff --git a/python/jittor/src/ops/reshape_op.cc b/python/jittor/src/ops/reshape_op.cc index f743a9ad..666cf400 100644 --- a/python/jittor/src/ops/reshape_op.cc +++ b/python/jittor/src/ops/reshape_op.cc @@ -20,6 +20,7 @@ static auto make_reshape = get_op_info("reshape") ReshapeOp::ReshapeOp(Var* x, NanoVector shape) : x(x), shape(shape) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); y = create_output(nullptr, x->dtype()); ASSERT(shape.size() > 0) << "input target shape of reshape can't be empty."; } diff --git a/python/jittor/src/ops/safe_clip_op.cc b/python/jittor/src/ops/safe_clip_op.cc index 1f47bd79..39847e80 100644 --- a/python/jittor/src/ops/safe_clip_op.cc +++ b/python/jittor/src/ops/safe_clip_op.cc @@ -16,6 +16,7 @@ namespace jittor { SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right) : x(x), left(left), right(right) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); set_type(OpType::element); y = create_output(nullptr, x->dtype()); } diff --git a/python/jittor/src/ops/setitem_op.cc b/python/jittor/src/ops/setitem_op.cc index 77447fa0..d16faeac 100644 --- a/python/jittor/src/ops/setitem_op.cc +++ b/python/jittor/src/ops/setitem_op.cc @@ -42,6 +42,12 @@ SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op) flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_has_gopt); + if (op.get(NanoString::_no_need_back_in)) { + flags.set(NodeFlags::_manual_set_vnbb); + for (int i=0; iflags.set(NodeFlags::_needed_by_backward); + } ASSERT(op == ns_void || op.is_binary()); create_output(nullptr, x->dtype()); if (flags.get(NodeFlags::_custom_flag)) { diff --git a/python/jittor/src/ops/tape_op.cc b/python/jittor/src/ops/tape_op.cc index 7202f597..f1b5f6dc 100644 --- a/python/jittor/src/ops/tape_op.cc +++ b/python/jittor/src/ops/tape_op.cc @@ -16,6 +16,7 @@ namespace jittor { TapeOp::TapeOp(Var* x) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); create_output(nullptr, x->dtype()); } @@ -53,6 +54,7 @@ Tapes::Tapes( flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_grads); + flags.set(NodeFlags::_manual_set_vnbb); callback = move(grad_callback); diff --git a/python/jittor/src/ops/ternary_op.cc b/python/jittor/src/ops/ternary_op.cc index 681d017e..70bbae36 100644 --- a/python/jittor/src/ops/ternary_op.cc +++ b/python/jittor/src/ops/ternary_op.cc @@ -13,13 +13,26 @@ namespace jittor { #ifndef JIT static auto make_ternary = get_op_info("ternary") .get_constructor(); +static auto make_broadcast = get_op_info("broadcast_to") + .get_constructor(); static auto make_number = get_op_info("number") .get_constructor(); TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) { + bool bx = cond->shape.size() > x->shape.size() || cond->num > x->num; + bool by = cond->shape.size() > y->shape.size() || cond->num > y->num; + if (bx || by) { + VarPtr xx, yy; + if (bx) xx = make_broadcast(x, cond, NanoVector()), x = xx; + if (by) yy = make_broadcast(y, cond, NanoVector()), y = yy; + forward(make_ternary(cond, x, y)); + return; + } flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); set_type(OpType::element); + flags.set(NodeFlags::_manual_set_vnbb); + cond->flags.set(NodeFlags::_needed_by_backward); z = create_output(nullptr, dtype_infer(x->ns, y->ns)); } diff --git a/python/jittor/src/ops/transpose_op.cc b/python/jittor/src/ops/transpose_op.cc index 
94d9576e..143e76ea 100644 --- a/python/jittor/src/ops/transpose_op.cc +++ b/python/jittor/src/ops/transpose_op.cc @@ -49,6 +49,7 @@ TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) { } #endif y = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); } void TransposeOp::infer_shape() { diff --git a/python/jittor/src/ops/unary_op.cc b/python/jittor/src/ops/unary_op.cc index 9a95d6df..f7ead025 100644 --- a/python/jittor/src/ops/unary_op.cc +++ b/python/jittor/src/ops/unary_op.cc @@ -544,6 +544,17 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) { } else dtype = unary_dtype_infer(ns, x->ns); y = create_output(nullptr, dtype); + bool bin = ns.get(NanoString::_no_need_back_in); + bool bout = ns.get(NanoString::_no_need_back_out); + if (bin || bout) { + flags.set(NodeFlags::_manual_set_vnbb); + if (!bin) { + x->flags.set(NodeFlags::_needed_by_backward); + } + if (!bout) { + y->flags.set(NodeFlags::_needed_by_backward); + } + } } VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) { diff --git a/python/jittor/src/ops/where_op.cc b/python/jittor/src/ops/where_op.cc index ed3ea8d2..947ad04e 100644 --- a/python/jittor/src/ops/where_op.cc +++ b/python/jittor/src/ops/where_op.cc @@ -21,7 +21,7 @@ namespace jittor { WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); - flags.set(NodeFlags::_vary_shape); + flags.set(NodeFlags::_manual_set_vnbb); auto ndim = cond->shape.size(); #ifdef HAS_CUDA if (use_cuda) { @@ -48,8 +48,7 @@ WhereOp::WhereOp(Var* cond, Var* x, Var* y) { void WhereOp::infer_shape() { auto ndim = cond->shape.size(); - auto num = cond->num; - if (num>0) num = -num; + auto num = -cond->num; for (uint i=0; iset_shape({num}); } diff --git a/python/jittor/src/opt/gopt/setitem_gopt.cc b/python/jittor/src/opt/gopt/setitem_gopt.cc index 43ab416d..d6d0bc8a 100644 --- a/python/jittor/src/opt/gopt/setitem_gopt.cc +++ b/python/jittor/src/opt/gopt/setitem_gopt.cc @@ -124,9 +124,6 @@ static void getitem_inplace(GetitemOp* op) { // return if out is all ready inplaced if (ou->allocator) return; - // return if input or output's shape is variable - if (in->num <= 0 || ou->num <= 0) - return; VarSlices vs = op->vs; auto in_shape = in->shape; diff --git a/python/jittor/src/pybind/core.cc b/python/jittor/src/pybind/core.cc index c4a979f0..9d23be9f 100644 --- a/python/jittor/src/pybind/core.cc +++ b/python/jittor/src/pybind/core.cc @@ -17,11 +17,11 @@ SEH_HOOK; // Those function is generated by python EXTERN_LIB void pyjt_def_all(PyObject* m); -vector _grad(VarHolder* loss, const vector& targets) { +vector _grad(VarHolder* loss, const vector& targets, bool retain_graph) { vector vs; vs.reserve(targets.size()); for (auto* v : targets) vs.push_back(v->var); - auto grads = grad(loss->var, vs); + auto grads = grad(loss->var, vs, retain_graph); vector grads_hold; grads_hold.reserve(targets.size()); for (auto& grad : grads) diff --git a/python/jittor/src/utils/log.cc b/python/jittor/src/utils/log.cc index cefba74a..f9e8b41c 100644 --- a/python/jittor/src/utils/log.cc +++ b/python/jittor/src/utils/log.cc @@ -604,7 +604,7 @@ void system_with_check(const char* cmd, const char* cwd) { auto ret = system_popen(cmd, cwd); CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd << "\nreturn ">> ret >> ". This might be an overcommit issue or out of memory." 
- << "Try : sudo sysctl vm.overcommit_memory=1"; + << "Try : sudo sysctl vm.overcommit_memory=1, or set enviroment variable `export DISABLE_MULTIPROCESSING=1`"; CHECKop(ret,==,0) << "Run cmd failed:" << cmd; } diff --git a/python/jittor/src/var.cc b/python/jittor/src/var.cc index 1bc750b6..15c03ea9 100644 --- a/python/jittor/src/var.cc +++ b/python/jittor/src/var.cc @@ -22,6 +22,7 @@ DEFINE_FLAG(bool, no_grad, 0, "No grad for all jittor Var creation"); DEFINE_FLAG(bool, no_fuse, 0, "No fusion optimization for all jittor Var creation"); +// TODO: fuse multiple flags DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too"); DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16"); @@ -91,13 +92,14 @@ bool Var::alloc(Allocator* allocator) { std::ostream& operator<<(std::ostream& os, const Var& var) { - os << "Var" << '(' << (void*)&var + os << "Var" << '(' << var.id << ':' << var.forward_liveness << ':' << var.backward_liveness << ':' << var.pending_liveness << ":i" << var._inputs.size() << ":o" << var._outputs.size() << ":s" << var.is_finished() + << ":n" << var.flags.get(NodeFlags::_needed_by_backward) << ',' << var.dtype().to_cstring() << ',' << var.name << ',' << var.mem_ptr << ')' << var.shape; diff --git a/python/jittor/src/var_holder.cc b/python/jittor/src/var_holder.cc index 0fe00548..61698d9b 100644 --- a/python/jittor/src/var_holder.cc +++ b/python/jittor/src/var_holder.cc @@ -22,6 +22,7 @@ namespace jittor { DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. 
But this flag will raise memory consumption and lower the performance."); list hold_vars; +list::iterator sync_ptr = hold_vars.end(); void add_hold_vars(VarHolder* self) { hold_vars.push_front(self); @@ -79,6 +80,8 @@ VarHolder::VarHolder(PyObject* obj, NanoString dtype) { VarHolder::~VarHolder() { if (PREDICT_BRANCH_NOT_TAKEN(!var)) return; + if (iter == sync_ptr) + sync_ptr = std::next(sync_ptr); hold_vars.erase(iter); var->release_both_liveness(); } @@ -100,7 +103,6 @@ void VarHolder::operator=(VarPtr&& v) { } string VarHolder::to_string() { - if (var->num<0) sync(); return var->to_string(); } @@ -131,8 +133,8 @@ VarHolder* VarHolder::_update(VarHolder* v) { EXTERN_LIB Executor exe; -void VarHolder::sync(bool device_sync) { - jittor::sync({this}, device_sync); +void VarHolder::sync(bool device_sync, bool weak_sync) { + jittor::sync({this}, device_sync, weak_sync); } ArrayArgs VarHolder::fetch_sync() { @@ -178,12 +180,12 @@ void sync_all(bool device_sync) { graph_check(); } -void sync(const vector& vh, bool device_sync) { +void sync(const vector& vh, bool device_sync, bool weak_sync) { vector vars; vars.reserve(vh.size()); for (auto v : vh) vars.push_back(v->var); graph_check(); - exe.run_sync(vars, device_sync); //need sync at last + exe.run_sync(vars, device_sync, weak_sync); //need sync at last graph_check(); } @@ -226,4 +228,17 @@ your code as below:: return 0; } + +static auto make_ternary = get_op_info("ternary") + .get_constructor(); + +extern int no_grad; + +VarHolder* ternary_out_hint(VarHolder* cond, VarHolder* x, VarHolder* y) { + if (!no_grad) + cond->var->flags.set(NodeFlags::_out_hint); + return new VarHolder(make_ternary(cond->var, x->var, y->var)); +} + + } // jittor \ No newline at end of file diff --git a/python/jittor/src/var_holder.h b/python/jittor/src/var_holder.h index f28ca401..71c0d76e 100644 --- a/python/jittor/src/var_holder.h +++ b/python/jittor/src/var_holder.h @@ -31,6 +31,7 @@ struct ItemData { typedef struct _object PyObject; EXTERN_LIB list hold_vars; +EXTERN_LIB list::iterator sync_ptr; // @pyjt(Var) // @attrs(heaptype) @@ -47,7 +48,7 @@ struct VarHolder { ~VarHolder(); string to_string(); // @pyjt(sync) - void sync(bool device_sync = false); + void sync(bool device_sync = false, bool weak_sync = true); // @pyjt(fetch_sync,numpy) ArrayArgs fetch_sync(); @@ -108,7 +109,6 @@ struct VarHolder { */ // @pyjt(numel) inline int64 numel() { - if (var->num<0) sync(); return var->num; } @@ -155,12 +155,21 @@ struct VarHolder { return var->flags.get(NodeFlags::_stop_fuse); } + /** + * output hint for training optimization + */ + // @pyjt(out_hint) + // @attrs(return_self) + inline VarHolder* out_hint() { + var->flags.set(NodeFlags::_out_hint); + return this; + } + /** * return the shape of the Var. 
*/ // @pyjt(__get__shape) inline NanoVector shape() { - if (var->num<0) sync(); return var->shape; } @@ -324,7 +333,7 @@ struct VarHolder { }; // @pyjt(sync) -void sync(const vector& vh=vector(), bool device_sync=false); +void sync(const vector& vh=vector(), bool device_sync=false, bool weak_sync=true); // @pyjt(fetch_sync) vector fetch_sync(const vector& vh); @@ -347,4 +356,7 @@ inline vector make_vh_vector(vector&& vps) { return a; } +// @pyjt(ternary_out_hint) +VarHolder* ternary_out_hint(VarHolder* cond, VarHolder* x, VarHolder* y); + } // jittor \ No newline at end of file diff --git a/python/jittor/test/test_clone.py b/python/jittor/test/test_clone.py index 6d4aca68..75f4fd1d 100644 --- a/python/jittor/test/test_clone.py +++ b/python/jittor/test/test_clone.py @@ -11,7 +11,7 @@ import jittor as jt import numpy as np class TestClone(unittest.TestCase): - def test(self): + def test_mid_stop_grad(self): jt.clean() b = a = jt.array(1.0) for i in range(10): @@ -19,8 +19,11 @@ class TestClone(unittest.TestCase): if i==5: c=b b.sync() assert jt.number_of_lived_vars()==11 + c.name("c") c.stop_grad() - assert jt.number_of_lived_vars()==3 + for n in jt.dump_all_graphs().nodes_info: + print(n) + assert jt.number_of_lived_vars()==3, jt.number_of_lived_vars() def test2(self): a = jt.array([1,2]) diff --git a/python/jittor/test/test_core.py b/python/jittor/test/test_core.py index fd7d888e..f4e875c9 100644 --- a/python/jittor/test/test_core.py +++ b/python/jittor/test/test_core.py @@ -17,6 +17,7 @@ def expect_error(func): raise Exception("Expect an error, but nothing catched.") class TestCore(unittest.TestCase): + def test_number_of_hold_vars(self): assert jt.random([1,2,3]).peek() == "float32[1,2,3,]" assert jt.core.number_of_hold_vars() == 0 @@ -73,9 +74,10 @@ class TestCore(unittest.TestCase): c = np.matmul(a, b) jtc = jt.matmul(jt.array(a), jt.array(b)).data assert np.allclose(jtc, c), np.abs(jtc-c).max() - + def test_var_holder(self): jt.clean() + self.assertEqual(jt.number_of_lived_vars(), 0) expect_error(lambda: jt.matmul(1,1)) expect_error(lambda: jt.matmul([1],[1])) expect_error(lambda: jt.matmul([[1]],[1])) @@ -87,7 +89,7 @@ class TestCore(unittest.TestCase): c = np.matmul(a, b) jtc = jt.matmul(jt.array(a), jt.array(b)).data assert np.all(jtc == c) - + def test_save_load_sub_module(self): class Net(jt.Module): def __init__(self): @@ -119,5 +121,103 @@ class TestCore(unittest.TestCase): assert a._parameters['a'] is a.a assert a._parameters['b'] is a.b + def test_copy_memopt(self): + # exe: post run + # remove pending done + # add hold pending done + # pending release mem done + a = jt.rand(10) + b = a.copy().copy().copy() + a.name("aa") + b.name("bb") + + cnt = 0 + graphs = jt.dump_all_graphs() + for x in graphs.nodes_info: + if "Var" not in x: continue + print(x) + if ",aa," in x: + assert ":2:i" in x, x + elif ",bb," in x: + assert ":1:i" in x + else: + assert ":1:i" in x + + b.sync() + cnt = 0 + graphs = jt.dump_all_graphs() + for x in graphs.nodes_info: + # print(x) + if "Var" in x and ",0)" in x: + cnt += 1 + assert cnt == 2 + + def test_fuse_memopt(self): + def check(): + a = jt.rand(10) + b = (a.copy().name("copy_out1") + 1).sqr() + a.copy().name("copy_out2") + b.sync() + for n in jt.dump_all_graphs().nodes_info: + if "Var" not in n: continue + # print(n) + + if "copy_out1" in n: + # copy out1 is not free + assert ",0)" not in n + if "copy_out2" in n: + # copy out2 is freeed + assert ",0)" in n + da = jt.grad(b, a) + da.sync() + check() + jt.gc() + assert 
jt.liveness_info()['lived_vars'] == 0 + + def test_out_hint1(self): + a = jt.rand(10) + b = jt.rand(10) + c = jt.ternary_out_hint((a0.0).name("b"+str(i)), a, 0.0) + a = jt.matmul(a.name("m1"),jt.rand(10,10).name("m2")).name("m3-"+str(i)) + da = jt.grad(a, x, True) + # jt.clean_graph() + da.sync() + cnt1 = 0 + cnt2 = 0 + for n in jt.dump_all_graphs().nodes_info: + if "Var" in n and ",0)" not in n: + cnt1 +=1 + if "bool" in n: + cnt2 += 1 + print(cnt1, cnt2) + assert cnt2 == 10 + assert cnt1 <= 33, cnt1 + + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_example.py b/python/jittor/test/test_example.py index 2e5a7e82..1013e99d 100644 --- a/python/jittor/test/test_example.py +++ b/python/jittor/test/test_example.py @@ -55,6 +55,7 @@ class TestExample(unittest.TestCase): model = Model(input_size=1) ps = model.parameters() + for p in reversed(ps): p.sync(0,0) for i,(x,y) in enumerate(get_data(n)): pred_y = model(x).name("pred_y") diff --git a/python/jittor/test/test_example_accumulate_grad.py b/python/jittor/test/test_example_accumulate_grad.py index d618312e..89f87d65 100644 --- a/python/jittor/test/test_example_accumulate_grad.py +++ b/python/jittor/test/test_example_accumulate_grad.py @@ -63,6 +63,7 @@ class TestExample(unittest.TestCase): model = Model(input_size=1) ps = model.parameters() + for p in reversed(ps): p.sync(0,0) opt = Optimizer(ps, lr) all_loss = 0 diff --git a/python/jittor/test/test_function.py b/python/jittor/test/test_function.py index 01d49044..8ca98b70 100644 --- a/python/jittor/test/test_function.py +++ b/python/jittor/test/test_function.py @@ -258,6 +258,7 @@ class TestFunction(unittest.TestCase): g = jt.grad(c+d*3, [a, b]) test() jt.clean() + jt.dump_all_graphs() self.assertEqual(jt.liveness_info()["lived_vars"], 0) @unittest.skipIf(True, "skip memleak test") diff --git a/python/jittor/test/test_fused_op.py b/python/jittor/test/test_fused_op.py index 46867a8e..b56460cf 100644 --- a/python/jittor/test/test_fused_op.py +++ b/python/jittor/test/test_fused_op.py @@ -78,9 +78,11 @@ class TestFusedOp(unittest.TestCase): def test_add(self): jt.clean() def check(hv, lv, lo): - self.assertEqual(jt.number_of_hold_vars(), hv) - self.assertEqual(jt.number_of_lived_vars(), lv) - self.assertEqual(jt.number_of_lived_ops(), lo) + self.assertEqual(( + jt.number_of_hold_vars(), + jt.number_of_lived_vars(), + jt.number_of_lived_ops()), + (hv, lv, lo)) for i in range(8): check(0,0,0) a = jt.array(1.0).name('a').stop_fuse() @@ -88,7 +90,14 @@ class TestFusedOp(unittest.TestCase): c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c') check(3,5,5) graph = jt.dump_all_graphs() + # for n in graph.nodes_info: + # print(n) self.assertEqual(c.data, 3) + graph2 = jt.dump_all_graphs() + print("check", i) + for n in graph2.nodes_info: + print(n) + print(jt.liveness_info()) check(3,5,2) graph = jt.dump_all_graphs() for node in graph.nodes_info: diff --git a/python/jittor/test/test_index_op.py b/python/jittor/test/test_index_op.py index 3ab1108e..ca13abbd 100644 --- a/python/jittor/test/test_index_op.py +++ b/python/jittor/test/test_index_op.py @@ -39,7 +39,6 @@ class TestIndexOp(unittest.TestCase): def test_vary_shape_dep(self): a, = jt.where([1,0,1]) b, = a.index_var() - assert a.uncertain_shape==[-3] and b.uncertain_shape==[-3] assert (b.data==[0,1]).all() def test_vary_shape_dep2(self): @@ -48,7 +47,6 @@ class TestIndexOp(unittest.TestCase): index0 = index0.broadcast([1,3], dims=[1]) # [[1,1,1],[2,2,2]] index1 = index0.index_var(1) # 
[[0,1,2],[0,1,2]] b = a.reindex_var([index0, index1]) - assert b.uncertain_shape==[-3,3] assert (b.data==[[4,5,6],[7,8,9]]).all() assert (index0.data==[[1,1,1],[2,2,2]]).all() assert (index1.data==[[0,1,2],[0,1,2]]).all() diff --git a/python/jittor/test/test_matmul.py b/python/jittor/test/test_matmul.py index 8227d838..b80dbe3d 100644 --- a/python/jittor/test/test_matmul.py +++ b/python/jittor/test/test_matmul.py @@ -130,6 +130,7 @@ class TestMatmul(unittest.TestCase): np.random.seed(0) jt.set_seed(3) model = Model() + for p in reversed(model.parameters()): p.sync(0,0) SGD = jt.nn.SGD(model.parameters(), 0.05, 0.9, 0) n = 1000 batch_size = 50 diff --git a/python/jittor/test/test_node.py b/python/jittor/test/test_node.py index 8ab0c421..7dcb078a 100644 --- a/python/jittor/test/test_node.py +++ b/python/jittor/test/test_node.py @@ -141,6 +141,13 @@ class TestNode(unittest.TestCase): # noded opt: build(0.44),execute(0.11) # for i in range(20): # run() + + # version 1.3.2.6 retest(laptop) + # mode1: + # origin 0.296 exec(0.11) + # int32flag 0.298 exec(0.11) + # add order 0.299 exec(0.11) + # rm p1 rule 0.299 exec(0.11) for i in range(20): run() import gc diff --git a/python/jittor/test/test_reindex_op.py b/python/jittor/test/test_reindex_op.py index a98d84ef..b45ca783 100644 --- a/python/jittor/test/test_reindex_op.py +++ b/python/jittor/test/test_reindex_op.py @@ -75,10 +75,7 @@ def conv_transpose_naive(x, w): def is_fused(x): - x.name('_x') - graph = jt.dump_all_graphs() - node_a = [ node for node in graph.nodes_info if ",_x," in node ] - return 's0' in node_a[0] + return 's0' in x.debug_msg() def check_fused(dim): jt.clean() diff --git a/python/jittor/test/test_resnet.py b/python/jittor/test/test_resnet.py index 8de9fc3f..0d40a717 100644 --- a/python/jittor/test/test_resnet.py +++ b/python/jittor/test/test_resnet.py @@ -91,7 +91,10 @@ class TestResnetFp32(unittest.TestCase): print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' .format(epoch_id, batch_id, 600,1. 
* batch_id / 6.0, loss[0], acc, time.time()-prev)) # prev = time.time() + # async version jt.fetch(epoch_id, batch_id, loss, output, target, callback) + # sync version + # callback(epoch_id, batch_id, loss.numpy(), output.numpy(), target.numpy()) # log_conv = find_log_with_re(logs, # "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*") diff --git a/python/jittor/test/test_ring_buffer2.py b/python/jittor/test/test_ring_buffer2.py index 487d65b4..30cddc1c 100644 --- a/python/jittor/test/test_ring_buffer2.py +++ b/python/jittor/test/test_ring_buffer2.py @@ -66,8 +66,9 @@ def test_ring_buffer(): assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() test_send_recv(np.random.rand(10,10)) - n_byte += 1 + 16 + 2 + 10*10*8 - assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + n_byte += 1 + 16 + 4 + 10*10*8 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push(), \ + (n_byte, buffer.total_pop(), n_byte, buffer.total_push()) test_send_recv(test_ring_buffer) test_send_recv(jt.array(np.random.rand(10,10))) diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py index 150a999e..71d13f15 100644 --- a/python/jittor/test/test_setitem.py +++ b/python/jittor/test/test_setitem.py @@ -381,6 +381,19 @@ class TestSetitem(unittest.TestCase): for i in range(n): np.testing.assert_allclose(g.data[i*m:(i+1)*m], yy[i].data) + def test_dfs_memopt(self): + with jt.flag_scope(profile_memory_enable=1): + n = 1024 + b = [] + for i in range(n): + a = jt.rand(n).copy().copy() + a = a.sum() + # a.sync() + b.append(a) + jt.sync_all() + jt.get_max_memory_treemap() + + diff --git a/python/jittor/test/test_where_op.py b/python/jittor/test/test_where_op.py index f4dcf894..f712f783 100644 --- a/python/jittor/test/test_where_op.py +++ b/python/jittor/test/test_where_op.py @@ -15,7 +15,7 @@ class TestWhereOp(unittest.TestCase): def test(self): assert (self.where([0,1,0,1])[0].data == [1,3]).all() a, = self.where([0,1,0,1]) - assert a.uncertain_shape==[-4] + assert a.uncertain_shape==[2] a.data assert a.uncertain_shape==[2] a,b = self.where([[0,0,1],[1,0,0]]) diff --git a/python/jittor/utils/data.gz b/python/jittor/utils/data.gz index f59ad084..faf6559a 100644 Binary files a/python/jittor/utils/data.gz and b/python/jittor/utils/data.gz differ
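
Usage sketch for the retain_graph argument threaded through jt.grad, Optimizer.pre_step/backward/step and the C++ grad() above (not part of the patch): jt.grad keeps the backward graph by default, while the optimizer entry points drop it by default. The toy matmul model and the SGD over a single Var are illustrative only.

import jittor as jt

x = jt.rand(4, 3)
w = jt.rand(3, 1)
loss = jt.matmul(x, w).sqr().sum()

# jt.grad defaults to retain_graph=True, so the same loss can be
# differentiated more than once; pass False on the last call to let the
# intermediate backward graph be released.
g_first = jt.grad(loss, [w])
g_last  = jt.grad(loss, [w], retain_graph=False)

# Optimizer.backward/step default to retain_graph=False; opt in explicitly
# when the same loss will be differentiated again afterwards.
opt = jt.nn.SGD([w], lr=0.1)
loss2 = jt.matmul(x, w).sqr().sum()
opt.backward(loss2, retain_graph=True)   # accumulate grads, keep loss2's graph
g_again = jt.grad(loss2, [w])            # still valid because the graph was kept
opt.step()                               # apply the accumulated gradients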