v 1.3.3 memory optimization

This commit is contained in:
Dun Liang 2022-04-22 15:04:06 +08:00
parent 9048f3fd41
commit 0666456a2f
70 changed files with 496 additions and 111 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.3.2.6' __version__ = '1.3.3.0'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int
@ -355,10 +355,10 @@ def array64(data, dtype=None):
with jt.flag_scope(auto_convert_64_to_32=0): with jt.flag_scope(auto_convert_64_to_32=0):
return array(data, dtype) return array(data, dtype)
def grad(loss, targets): def grad(loss, targets, retain_graph=True):
if type(targets) == core.Var: if type(targets) == core.Var:
return core.grad(loss, [targets])[0] return core.grad(loss, [targets], retain_graph)[0]
return core.grad(loss, targets) return core.grad(loss, targets, retain_graph)
def liveness_info(): def liveness_info():
return { return {

View File

@ -28,6 +28,8 @@ CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdim
ASSERT(offsets->dtype()==ns_int32); ASSERT(offsets->dtype()==ns_int32);
y = create_output(nullptr, ns_int32); y = create_output(nullptr, ns_int32);
y_key = create_output(nullptr, x->dtype()); y_key = create_output(nullptr, x->dtype());
flags.set(NodeFlags::_manual_set_vnbb);
y->flags.set(NodeFlags::_needed_by_backward);
} }
VarPtr CubArgReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { VarPtr CubArgReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {

View File

@ -27,6 +27,8 @@ CubArgsortOp::CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending,
ASSERT(offsets->dtype()==ns_int32); ASSERT(offsets->dtype()==ns_int32);
y = create_output(nullptr, dtype); y = create_output(nullptr, dtype);
y_key = create_output(nullptr, x->dtype()); y_key = create_output(nullptr, x->dtype());
flags.set(NodeFlags::_manual_set_vnbb);
y->flags.set(NodeFlags::_needed_by_backward);
} }
VarPtr CubArgsortOp::grad(Var* out, Var* dout, Var* v, int v_index) { VarPtr CubArgsortOp::grad(Var* out, Var* dout, Var* v, int v_index) {

View File

@ -26,7 +26,6 @@ namespace jittor {
CubWhereOp::CubWhereOp(Var* cond, NanoString dtype) : cond(cond) { CubWhereOp::CubWhereOp(Var* cond, NanoString dtype) : cond(cond) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_vary_shape);
auto ndim = cond->shape.size(); auto ndim = cond->shape.size();
outs.reset(new Var*[ndim]); outs.reset(new Var*[ndim]);
for (uint i=0; i<ndim; i++) for (uint i=0; i<ndim; i++)
@ -35,8 +34,7 @@ CubWhereOp::CubWhereOp(Var* cond, NanoString dtype) : cond(cond) {
void CubWhereOp::infer_shape() { void CubWhereOp::infer_shape() {
auto ndim = cond->shape.size(); auto ndim = cond->shape.size();
auto num = cond->num; auto num = -cond->num;
if (num>0) num = -num;
for (uint i=0; i<ndim; i++) for (uint i=0; i<ndim; i++)
outs[i]->set_shape({num}); outs[i]->set_shape({num});
} }

View File

@ -35,6 +35,9 @@ CublasBatchedMatmulOp::CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool
c = create_output(nullptr, a->dtype()); c = create_output(nullptr, a->dtype());
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_manual_set_vnbb);
a->flags.set(NodeFlags::_needed_by_backward);
b->flags.set(NodeFlags::_needed_by_backward);
} }

View File

@ -24,6 +24,9 @@ CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
: a(a), b(b), trans_a(trans_a), trans_b(trans_b) { : a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
a->flags.set(NodeFlags::_needed_by_backward);
b->flags.set(NodeFlags::_needed_by_backward);
// TODO: support int8 * int8 // TODO: support int8 * int8
ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same";
// TODO: support diffrent input type // TODO: support diffrent input type

View File

@ -29,6 +29,9 @@ CudnnConv3dBackwardWOp::CudnnConv3dBackwardWOp(Var* x, Var* dy, int kd, int kh,
xformat(move(xformat)) { xformat(move(xformat)) {
flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
x->flags.set(NodeFlags::_needed_by_backward);
dy->flags.set(NodeFlags::_needed_by_backward);
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns)); dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
} }

View File

@ -29,6 +29,9 @@ CudnnConv3dBackwardXOp::CudnnConv3dBackwardXOp(Var* w, Var* dy, int depth, int h
xformat(move(xformat)) { xformat(move(xformat)) {
flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
w->flags.set(NodeFlags::_needed_by_backward);
dy->flags.set(NodeFlags::_needed_by_backward);
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
} }

View File

@ -26,6 +26,9 @@ CudnnConv3dOp::CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int strid
xformat(move(xformat)) { xformat(move(xformat)) {
flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
x->flags.set(NodeFlags::_needed_by_backward);
w->flags.set(NodeFlags::_needed_by_backward);
y = create_output(nullptr, dtype_infer(x->ns, w->ns)); y = create_output(nullptr, dtype_infer(x->ns, w->ns));
} }

View File

@ -52,6 +52,9 @@ CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
x->flags.set(NodeFlags::_needed_by_backward);
dy->flags.set(NodeFlags::_needed_by_backward);
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns)); dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
} }

View File

@ -53,6 +53,9 @@ CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int widt
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
w->flags.set(NodeFlags::_needed_by_backward);
dy->flags.set(NodeFlags::_needed_by_backward);
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
} }

View File

@ -50,6 +50,9 @@ CudnnConvOp::CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh,
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1); flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
x->flags.set(NodeFlags::_needed_by_backward);
w->flags.set(NodeFlags::_needed_by_backward);
y = create_output(nullptr, dtype_infer(x->ns, w->ns)); y = create_output(nullptr, dtype_infer(x->ns, w->ns));
if (!this->yformat.size()) if (!this->yformat.size())
this->yformat = this->xformat; this->yformat = this->xformat;

View File

@ -801,6 +801,10 @@ def print_tree(now, max_memory_size, prefix1, prefix2, build_by):
tab = ' ' tab = ' '
out += prefix1+now['name']+'('+now['type']+')\n' out += prefix1+now['name']+'('+now['type']+')\n'
out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%; cnt:'+format_size(now['cnt'],'') + ']\n' out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%; cnt:'+format_size(now['cnt'],'') + ']\n'
if len(now['children']) == 0 and len(now['vinfo']):
out += prefix2+now['vinfo'][0]
if len(now['vinfo']) > 1: out += "..."
out += '\n'
if (build_by == 0): if (build_by == 0):
for p in now['path']: for p in now['path']:
out += prefix2+p+'\n' out += prefix2+p+'\n'
@ -866,7 +870,8 @@ Output::
vars_ = vars_[1:] vars_ = vars_[1:]
for v_ in vars_: for v_ in vars_:
v__ = v_.split(div2) v__ = v_.split(div2)
var = {'size':int(v__[1]), 'stack':[], 'cnt':1} vinfo = v__[0].split("{")[0]
var = {'size':int(v__[1]), 'stack':[], 'cnt':1, "vinfo":vinfo}
v__ = v__[2:-1] v__ = v__[2:-1]
for s_ in v__: for s_ in v__:
s__ = s_.split(div3) s__ = s_.split(div3)
@ -874,7 +879,7 @@ Output::
var['stack'].append(s) var['stack'].append(s)
vars.append(var) vars.append(var)
if (build_by == 0): # build tree by name if (build_by == 0): # build tree by name
tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':''} tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':'', 'vinfo':[]}
def find_child(now, key): def find_child(now, key):
for c in now['children']: for c in now['children']:
@ -885,6 +890,7 @@ Output::
now = tree now = tree
now['size'] += v['size'] now['size'] += v['size']
now['cnt'] += v['cnt'] now['cnt'] += v['cnt']
now['vinfo'].append(v['vinfo'])
for s in v['stack']: for s in v['stack']:
ch = find_child(now, s['name']) ch = find_child(now, s['name'])
if (ch is not None): if (ch is not None):
@ -894,12 +900,13 @@ Output::
now = ch now = ch
now['size'] += v['size'] now['size'] += v['size']
now['cnt'] += v['cnt'] now['cnt'] += v['cnt']
now['vinfo'].append(v['vinfo'])
else: else:
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type']} now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type'], 'vinfo':[v['vinfo']]}
now['children'].append(now_) now['children'].append(now_)
now = now_ now = now_
elif (build_by == 1): # build tree by path elif (build_by == 1): # build tree by path
tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':''} tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':'', 'vinfo':[]}
def find_child(now, key): def find_child(now, key):
for c in now['children']: for c in now['children']:
@ -910,14 +917,16 @@ Output::
now = tree now = tree
now['size'] += v['size'] now['size'] += v['size']
now['cnt'] += v['cnt'] now['cnt'] += v['cnt']
now['vinfo'].append(v['vinfo'])
for s in v['stack']: for s in v['stack']:
ch = find_child(now, s['path']) ch = find_child(now, s['path'])
if (ch is not None): if (ch is not None):
now = ch now = ch
now['size'] += v['size'] now['size'] += v['size']
now['cnt'] += v['cnt'] now['cnt'] += v['cnt']
now['vinfo'].append(v['vinfo'])
else: else:
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type']} now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type'], 'vinfo':[v['vinfo']]}
now['children'].append(now_) now['children'].append(now_)
now = now_ now = now_
else: else:

View File

@ -173,7 +173,8 @@ def relu(x):
>>> nn.relu(a) >>> nn.relu(a)
jt.Var([0. 1.1338731 6.128115 ], dtype=float32) jt.Var([0. 1.1338731 6.128115 ], dtype=float32)
''' '''
return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x)) cond = x>0.0
return jt.ternary_out_hint(cond, x, 0.0)
def leaky_relu(x, scale=0.01): def leaky_relu(x, scale=0.01):

View File

@ -98,7 +98,7 @@ class Optimizer(object):
def zero_grad(self): def zero_grad(self):
self.__zero_grad = True self.__zero_grad = True
def pre_step(self, loss): def pre_step(self, loss, retain_graph=False):
""" something should be done before step, such as calc gradients, mpi sync, and so on. """ something should be done before step, such as calc gradients, mpi sync, and so on.
Example:: Example::
@ -118,7 +118,7 @@ class Optimizer(object):
params_has_grad.append(p) params_has_grad.append(p)
# get gradient # get gradient
grads = jt.grad(loss, params_has_grad) grads = jt.grad(loss, params_has_grad, retain_graph)
# sync grads and model if in mpi # sync grads and model if in mpi
if jt.in_mpi: if jt.in_mpi:
@ -153,7 +153,7 @@ class Optimizer(object):
pid += 1 pid += 1
self.__zero_grad = False self.__zero_grad = False
def backward(self, loss): def backward(self, loss, retain_graph=False):
''' '''
optimize.backward(loss) is used for accumulate multiple step, optimize.backward(loss) is used for accumulate multiple step,
it can be used as following: it can be used as following:
@ -186,11 +186,11 @@ class Optimizer(object):
''' '''
self.pre_step(loss) self.pre_step(loss, retain_graph)
def step(self, loss=None): def step(self, loss=None, retain_graph=False):
if loss is not None: if loss is not None:
self.pre_step(loss) self.pre_step(loss, retain_graph)
for pg in self.param_groups: for pg in self.param_groups:
lr = pg.get("lr", self.lr) lr = pg.get("lr", self.lr)
for p, g in zip(pg["params"], pg["grads"]): for p, g in zip(pg["params"], pg["grads"]):

View File

@ -33,6 +33,6 @@ inline static void __print_trace() {
} }
// @pyjt(grad) // @pyjt(grad)
vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets); vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets, bool retain_graph=true);
} // jittor } // jittor

View File

@ -30,6 +30,7 @@
#include "memory_profiler.h" #include "memory_profiler.h"
#include "utils/seh.h" #include "utils/seh.h"
#include "utils/cache_compile.h" #include "utils/cache_compile.h"
#include "var_holder.h"
namespace jittor { namespace jittor {
@ -102,6 +103,23 @@ void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, i
} }
} }
// Propagate the _needed_by_backward flag backwards through the ops of a
// fused op: if any output of an op inside the fusion is needed by backward
// (and its custom_data bit 0 is clear — presumably marking vars internal to
// the fused op, see FusedOp::update_ops; confirm), then all inputs of that
// op must also be kept for backward.
// Ops are visited in reverse order so that flags set on an op's inputs are
// observed when earlier (producer) ops are examined.
// NOTE(review): function name keeps the original spelling ("propergate")
// because the existing call site uses it.
static inline void propergate_needed_flags(FusedOp& fused_op) {
    auto& ops = fused_op.ops;
    for (int i=ops.size()-1; i>=0; i--) {
        auto op = ops[i];
        bool has_need = false;   // was `= 0`; use a bool literal
        for (auto o : op->outputs())
            if (o->flags.get(NodeFlags::_needed_by_backward) &&
                !(o->custom_data&1)) {
                has_need = true;
                break;           // one qualifying output is enough
            }
        if (has_need)
            // renamed from `i`, which shadowed the outer loop index
            for (auto in : op->inputs())
                in->flags.set(NodeFlags::_needed_by_backward);
    }
}
void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jittor::Log& logf) { void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jittor::Log& logf) {
vector<Stack> stack; vector<Stack> stack;
if (is_fused_op) { if (is_fused_op) {
@ -151,7 +169,30 @@ void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jit
jittor::LogFatalVoidify() && logf; jittor::LogFatalVoidify() && logf;
} }
void Executor::run_sync(vector<Var*> vars, bool device_sync) { static void top_weak_sync(vector<Var*>& vars) {
auto t = ++Node::tflag_count;
int64 max_id=0;
for (auto v : vars) {
max_id = std::max(v->id, max_id);
v->tflag = t;
}
while (true) {
if (sync_ptr == hold_vars.begin())
break;
auto next_ptr = std::prev(sync_ptr);
auto v = (*next_ptr)->var;
if (v->id > max_id) break;
sync_ptr = next_ptr;
if (v->tflag == t) continue;
if (v->_outputs.size()) continue;
if (v->is_finished()) continue;
vars.push_back(v);
}
}
void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
if (weak_sync)
top_weak_sync(vars);
auto allocator = get_allocator(); auto allocator = get_allocator();
auto temp_allocator = get_allocator(true); auto temp_allocator = get_allocator(true);
this->allocator = allocator; this->allocator = allocator;
@ -287,6 +328,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
// output: // output:
// queue: toplogical order of fused op // queue: toplogical order of fused op
{ {
// queue.clear();
#ifndef JT_bfs_executor
map<int64, int> p_queue;
#endif
for (int root : roots) { for (int root : roots) {
for (int i=root; i>=0; i=next[i]) { for (int i=root; i>=0; i=next[i]) {
Op* op = ops[i]; Op* op = ops[i];
@ -299,24 +344,48 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
} }
} }
} }
#ifdef JT_bfs_executor
if (deps[root] == 0) if (deps[root] == 0)
queue.push_back(root); queue.push_back(root);
#else
if (deps[root] == 0)
p_queue[ops[root]->id] = root;
#endif
} }
for (uint s=0; s<queue.size(); s++) { #ifdef JT_bfs_executor
for (uint s=0; s<queue.size(); s++)
#else
while (p_queue.size())
#endif
{
#ifdef JT_bfs_executor
int op_id = queue[s]; int op_id = queue[s];
#else
int op_id = p_queue.begin()->second;
p_queue.erase(p_queue.begin());
queue.push_back(op_id);
#endif
for (int i=op_id; i>=0; i=next[i]) { for (int i=op_id; i>=0; i=next[i]) {
Op* op = ops[i]; Op* op = ops[i];
for (Var* v : op->outputs()) for (Var* v : op->outputs())
{
if (v->tflag == tt) if (v->tflag == tt)
for (Op* op2 : v->outputs()) { for (Op* op2 : v->outputs())
{
if (op2->tflag != tt) continue; if (op2->tflag != tt) continue;
int op2_id = father[op2->custom_data]; int op2_id = father[op2->custom_data];
// continue if those two ops are fused // continue if those two ops are fused
if (op2_id == op_id) continue; if (op2_id == op_id) continue;
deps[op2_id]--; deps[op2_id]--;
#ifdef JT_bfs_executor
if (deps[op2_id] == 0) if (deps[op2_id] == 0)
queue.push_back(op2_id); queue.push_back(op2_id);
#else
if (deps[op2_id] == 0)
p_queue[op2->id] = op2_id;
#endif
} }
}
} }
} }
ASSERTop(queue.size(),==,roots.size()); ASSERTop(queue.size(),==,roots.size());
@ -479,9 +548,6 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt); load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
} }
LOGvvv << "Run" << op; LOGvvv << "Run" << op;
if (op->flags.get(NodeFlags::_has_vary_input)) op->init();
ASSERT(!op->flags.get(NodeFlags::_has_vary_input))
<< "Shape of(" >> op->name() >> ") not solved.";
for (auto* var : op->outputs()) { for (auto* var : op->outputs()) {
var->alloc(allocator); var->alloc(allocator);
} }
@ -557,6 +623,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
LOGvvv << "Finished Op(" >> op->name() << rid >> LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> queue.size() >> ") output:" << op->outputs(); "/" >> queue.size() >> ") output:" << op->outputs();
if (is_fused_op) { if (is_fused_op) {
propergate_needed_flags(fused_op);
for (Var* var : op->outputs()) for (Var* var : op->outputs())
var->finish_pending_liveness(); var->finish_pending_liveness();
continue; continue;
@ -596,7 +663,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
} }
} }
LOGvv << "All" << op_num << "ops finished, return vars:" << vars; LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
for (Var* v : vars) ASSERT(v->mem_ptr); for (Var* v : vars) ASSERT(v->mem_ptr || !v->backward_liveness);
// clean fetcher free buffer // clean fetcher free buffer
fetcher_to_free.clear(); fetcher_to_free.clear();
#ifdef HAS_CUDA #ifdef HAS_CUDA

View File

@ -21,7 +21,7 @@ struct Executor {
Allocator* allocator; Allocator* allocator;
Allocator* temp_allocator; Allocator* temp_allocator;
bool last_is_cuda = false; bool last_is_cuda = false;
void run_sync(vector<Var*> vars, bool device_sync); void run_sync(vector<Var*> vars, bool device_sync, bool weak_sync=true);
inline Allocation alloc_temp(size_t size) { inline Allocation alloc_temp(size_t size) {
return Allocation(temp_allocator, size); return Allocation(temp_allocator, size);

View File

@ -45,6 +45,7 @@ void FusedOp::update_ops() {
_inputs.clear(); _inputs.clear();
_outputs.clear(); _outputs.clear();
vars.clear();
for (Op* op : ops) { for (Op* op : ops) {
for (Var* o : op->outputs()) { for (Var* o : op->outputs()) {
if (o->loop_options) { if (o->loop_options) {
@ -93,10 +94,7 @@ void FusedOp::update_ops() {
o->custom_data &= 1; o->custom_data &= 1;
} }
} }
vars.clear();
bool has_vary_input = 0;
for (Op* opi : ops) { for (Op* opi : ops) {
has_vary_input |= opi->flags.get(NodeFlags::_has_vary_input);
for (Var* i : opi->inputs()) { for (Var* i : opi->inputs()) {
auto &c = i->custom_data; auto &c = i->custom_data;
// if not visited // if not visited
@ -116,7 +114,6 @@ void FusedOp::update_ops() {
} }
} }
} }
flags.set(NodeFlags::_has_vary_input, has_vary_input);
LOGvvvv << "Var info" << vars; LOGvvvv << "Var info" << vars;
} }
@ -144,12 +141,9 @@ FusedOp::~FusedOp() {
} }
void FusedOp::infer_shape() { void FusedOp::infer_shape() {
bool has_vary_input = 0;
for (Op* op : ops) { for (Op* op : ops) {
op->init(); op->init();
has_vary_input |= op->flags.get(NodeFlags::_has_vary_input);
} }
flags.set(NodeFlags::_has_vary_input, has_vary_input);
} }
void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) { void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {

View File

@ -10,6 +10,7 @@
#include "op.h" #include "op.h"
#include "graph.h" #include "graph.h"
#include "ops/op_register.h" #include "ops/op_register.h"
#include "var_holder.h"
namespace jittor { namespace jittor {
@ -76,7 +77,7 @@ void warn_grad_break(int i, Var* v) {
LOGw << "grads[">>i>>"] '">> v->name>>"' doesn't have gradient. It will be set to zero:" << v; LOGw << "grads[">>i>>"] '">> v->name>>"' doesn't have gradient. It will be set to zero:" << v;
} }
vector<VarPtr> grad(Var* loss, vector<Var*> targets) { vector<VarPtr> grad(Var* loss, vector<Var*> targets, bool retain_graph) {
LOGvv << "loss:" >> loss << "targets:" >> targets; LOGvv << "loss:" >> loss << "targets:" >> targets;
CHECK(loss->is_float()) << "Loss should be float"; CHECK(loss->is_float()) << "Loss should be float";
for (Var* var : targets) for (Var* var : targets)
@ -259,6 +260,20 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
assign_attrs(grad.ptr, var); assign_attrs(grad.ptr, var);
} }
} }
if (!retain_graph) {
auto t = ++Node::tflag_count;
for (auto& vh : hold_vars)
if (vh->var->tflag != t) {
vh->var->tflag = t;
}
SetupFreeBuffer setup_free_buffer;
for (int i=int(gvars.size())-1; i>=0; i--)
if (gvars[i]->tflag != t && gvars[i]->backward_liveness)
gvars[i]->set_stop_grad();
for (int i=0; i<grads.size(); i++)
if (grads[i])
grads[i]->set_stop_grad();
}
return results; return results;
} }

View File

@ -9,7 +9,7 @@
namespace jittor { namespace jittor {
vector<VarPtr> grad(Var* loss, vector<Var*> targets); vector<VarPtr> grad(Var* loss, vector<Var*> targets, bool retain_graph=true);
// @pyjt(tape_together) // @pyjt(tape_together)
void tape_together( void tape_together(

View File

@ -5,6 +5,7 @@
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
#include <sstream> #include <sstream>
#include <algorithm>
#include "graph.h" #include "graph.h"
#include "var_holder.h" #include "var_holder.h"
#include "var.h" #include "var.h"
@ -32,7 +33,7 @@ void do_graph_check() {
LOGvv << "Check hold_vars size" << queue.size(); LOGvv << "Check hold_vars size" << queue.size();
int vhsize = queue.size(); int vhsize = queue.size();
for (auto* node : queue) { for (auto* node : queue) {
ASSERTop(node->forward_liveness,>,0); // ASSERTop(node->forward_liveness,>,0);
ASSERTop(node->backward_liveness,>,0); ASSERTop(node->backward_liveness,>,0);
} }
for (uint i=0; i<queue.size(); i++) { for (uint i=0; i<queue.size(); i++) {
@ -49,7 +50,7 @@ void do_graph_check() {
LOGvvvv << "Check node" << i << node; LOGvvvv << "Check node" << i << node;
int f=0, b=0, p=0; int f=0, b=0, p=0;
if (i<vhsize) { if (i<vhsize) {
f+=visited.at(node), b+=visited.at(node); f+=visited.at(node), b+=visited.at(node), p+=visited.at(node);
} }
for (auto* i : node->inputs()) { for (auto* i : node->inputs()) {
if (i->is_stop_grad()) continue; if (i->is_stop_grad()) continue;
@ -62,7 +63,7 @@ void do_graph_check() {
if (o->pending_liveness && !o->is_finished()) if (o->pending_liveness && !o->is_finished())
p++; p++;
} }
if (f>0 && b>0 && !node->is_finished()) p++; // if (f>0 && b>0 && !node->is_finished()) p++;
if (f!=node->forward_liveness || b!=node->backward_liveness || p!=node->pending_liveness) { if (f!=node->forward_liveness || b!=node->backward_liveness || p!=node->pending_liveness) {
LOGf << "ERROR" << node << '\n' LOGf << "ERROR" << node << '\n'
<< f << b << p << i << '\n' << f << b << p << i << '\n'
@ -91,6 +92,8 @@ DumpGraphs dump_all_graphs() {
queue.push_back(vh->var); queue.push_back(vh->var);
} }
bfs_both(queue, [](Node*){return true;}); bfs_both(queue, [](Node*){return true;});
std::sort(queue.begin(), queue.end(),
[](Node* a, Node* b) { return a->id < b->id;});
DumpGraphs graphs; DumpGraphs graphs;
for (uint i=0; i<queue.size(); i++) for (uint i=0; i<queue.size(); i++)
queue[i]->custom_data = i; queue[i]->custom_data = i;

View File

@ -76,7 +76,7 @@ void MemoryProfiler::check() {
allocations.clear(); allocations.clear();
size_t memory_size = 0; size_t memory_size = 0;
std::vector<std::pair<std::pair<string, vector<Stack>>, size_t>> live_vars; std::vector<std::pair<std::pair<string, vector<Stack>>, size_t>> live_vars;
vector<Node*> queue; vector<Node*> queue, queue2;
auto t = ++Node::tflag_count; auto t = ++Node::tflag_count;
for (auto& vh : hold_vars) for (auto& vh : hold_vars)
@ -85,6 +85,14 @@ void MemoryProfiler::check() {
queue.push_back(vh->var); queue.push_back(vh->var);
} }
bfs_both(queue, [](Node*){return true;}); bfs_both(queue, [](Node*){return true;});
vector<int> backup_custom_data;
backup_custom_data.resize(queue.size());
for (int i=0; i<queue.size(); i++)
backup_custom_data[i] = queue[i]->custom_data;
toplogical_sort_forward(queue, queue2, [](Node*){});
for (int i=0; i<queue.size(); i++)
queue[i]->custom_data = backup_custom_data[i];
queue.swap(queue2);
for (Node* node : queue) { for (Node* node : queue) {
if (node->is_var()) { if (node->is_var()) {
Var* var = (Var*)node; Var* var = (Var*)node;

View File

@ -137,6 +137,24 @@ static unordered_set<string> white_ops = {
"pow", "pow",
}; };
// Op names for which the backward pass presumably does not need the op's
// *input* vars kept alive (memory optimization; these names are folded into
// NanoString::_no_need_back_in in init_ns below — confirm against each op's
// grad() implementation).
static unordered_set<string> no_need_back_in = {
    "void",
    "cast",
    "negative",
    "add",
    "subtract",
    "mean",
};
// Op names for which the backward pass presumably does not need the op's
// *output* var kept alive (folded into NanoString::_no_need_back_out in
// init_ns below — confirm against each op's grad() implementation).
static unordered_set<string> no_need_back_out = {
    "void",
    "cast",
    "negative",
    "add",
    "subtract",
    "multiply",
};
#define DEFINE_NS(T) NanoString ns_##T; #define DEFINE_NS(T) NanoString ns_##T;
FOR_ALL_NS(DEFINE_NS); FOR_ALL_NS(DEFINE_NS);
@ -172,6 +190,8 @@ static void init_ns() {
ns.set(NanoString::_float, float_ops.count(name)); ns.set(NanoString::_float, float_ops.count(name));
} }
ns.set(NanoString::_white_list, white_ops.count(name)); ns.set(NanoString::_white_list, white_ops.count(name));
ns.set(NanoString::_no_need_back_in, no_need_back_in.count(name));
ns.set(NanoString::_no_need_back_out, no_need_back_out.count(name));
__string_to_ns[name] = ns; __string_to_ns[name] = ns;
auto name2 = ns.to_cstring(); auto name2 = ns.to_cstring();
int len=0; int len=0;

View File

@ -98,7 +98,7 @@ EXTERN_LIB int __ns_len[];
// @pyjt(NanoString) // @pyjt(NanoString)
struct NanoString { struct NanoString {
typedef uint16 ns_t; typedef uint32 ns_t;
enum Flags { enum Flags {
// bit0~7: index // bit0~7: index
_index=0, _index_nbits=7, _index=0, _index_nbits=7,
@ -119,6 +119,9 @@ struct NanoString {
_dsize=_n+6, _dsize_nbits=2, _dsize=_n+6, _dsize_nbits=2,
// bit8: white list // bit8: white list
_white_list=_n+8, _white_list=_n+8,
// bit9: backward opt
_no_need_back_in=_n+9,
_no_need_back_out=_n+10,
}; };
ns_t data=0; ns_t data=0;

View File

@ -18,7 +18,7 @@ EXTERN_LIB int64 nt;
EXTERN_LIB vector<Node*> free_buffer; EXTERN_LIB vector<Node*> free_buffer;
struct NodeFlags { struct NodeFlags {
typedef uint16 nf_t; typedef uint32 nf_t;
nf_t flags=0; nf_t flags=0;
enum Flags { enum Flags {
// bit0: is_var // bit0: is_var
@ -35,6 +35,8 @@ struct NodeFlags {
_force_fuse=_n+0, _force_fuse=_n+0,
_stop_fuse=_n+1, _stop_fuse=_n+1,
_in_update_queue=_n+2, _in_update_queue=_n+2,
_needed_by_backward=_n+3,
_out_hint=_n+4,
// op related flags // op related flags
// bit0: support cpu // bit0: support cpu
@ -53,13 +55,14 @@ struct NodeFlags {
_has_gopt=_n+7, _has_gopt=_n+7,
// bit8: has vary input // bit8: has vary input
_has_vary_input=_n+8, _has_vary_input=_n+8,
_manual_set_vnbb = _n+9,
// bit9: prefer 32 bit // bit9: prefer 32 bit
_prefer_32=_n+9, _prefer_32=_n+10,
// bit10: force 16 bit // force 16 bit
_prefer_16=_n+10, _prefer_16=_prefer_32+1,
// bit11: reduce keep type unchange // reduce keep type unchange
_reduce_keep=_n+11, _reduce_keep=_prefer_32+2,
_custom_flag=_reduce_keep, _custom_flag = _reduce_keep,
}; };
inline void set(Flags f, int a=1, int nbits=1) { inline void set(Flags f, int a=1, int nbits=1) {
@ -89,8 +92,8 @@ struct Node {
}; };
struct output_t { struct output_t {
Node* node; Node* node;
int index;
list<input_t>::iterator back; list<input_t>::iterator back;
int index;
output_t(Node* n, int i) : node(n), index(i) {} output_t(Node* n, int i) : node(n), index(i) {}
operator Node*() { return node; } operator Node*() { return node; }
operator Op*() { return (Op*)node; } operator Op*() { return (Op*)node; }
@ -120,14 +123,15 @@ struct Node {
inline bool need_free() inline bool need_free()
{ return !pending_liveness && (!forward_liveness || !backward_liveness); } { return !pending_liveness && (!forward_liveness || !backward_liveness); }
int64_t tflag = 0; int custom_data;
int64_t custom_data; int64 tflag = 0;
int64 id;
list<input_t> _inputs; list<input_t> _inputs;
list<output_t> _outputs; list<output_t> _outputs;
#ifdef NODE_MEMCHECK #ifdef NODE_MEMCHECK
inline Node() { inline Node() {
lived_nodes[(void*)this] = ++total_node; lived_nodes[(void*)this] = id = ++total_node;
} }
inline virtual ~Node() { inline virtual ~Node() {
@ -135,7 +139,7 @@ struct Node {
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this); if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);
} }
#else #else
inline Node() {}; inline Node() { id = ++total_node; };
inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);}; inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);};
#endif #endif
inline Var* var() { return (Var*)this; } inline Var* var() { return (Var*)this; }

View File

@ -14,6 +14,8 @@
#include "mem/allocator.h" #include "mem/allocator.h"
#include "misc/cuda_flags.h" #include "misc/cuda_flags.h"
#include "pybind/py_var_tracer.h" #include "pybind/py_var_tracer.h"
#include "executor.h"
#include "var_holder.h"
namespace jittor { namespace jittor {
@ -65,14 +67,26 @@ Var* Op::create_output(NanoVector shape, NanoString dtype) {
} }
void Op::init() { void Op::init() {
bool has_vary_input = 0;
for (Var* v : inputs())
if (v->num < 0) {
has_vary_input = 1;
break;
}
flags.set(NodeFlags::_has_vary_input, has_vary_input);
infer_shape(); infer_shape();
bool manual_set_vnbb = flags.get(NodeFlags::_manual_set_vnbb)
|| _inputs.size()==0
|| (_outputs.size()==1 && _outputs.front().node->is_stop_grad());
for (Var* v : inputs()) {
if (!manual_set_vnbb) {
v->flags.set(NodeFlags::_needed_by_backward);
}
}
Var* need_sync = nullptr;
for (Var* v : outputs()) {
if (!manual_set_vnbb)
v->flags.set(NodeFlags::_needed_by_backward);
if (v->num < 0)
need_sync = v;
}
if (need_sync) {
exe.run_sync(vector<Var*>({need_sync}), false);
CHECK(need_sync->num >= 0) << need_sync << "'s shape is error";
}
} }
void Op::compile_optimize(string& src) {} void Op::compile_optimize(string& src) {}
@ -84,7 +98,7 @@ void Op::graph_optimize() {}
string Op::name_ex() const { string Op::name_ex() const {
string a=name(); string a=name();
if (ns!=ns_void) { if (ns.data) {
a += '.'; a += '.';
a += ns.to_cstring(); a += ns.to_cstring();
} }
@ -266,7 +280,7 @@ void Op::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
std::ostream& operator<<(std::ostream& os, const Op* op) { std::ostream& operator<<(std::ostream& os, const Op* op) {
if (!op) return os << "Op(0)"; if (!op) return os << "Op(0)";
os << "Op(" << (void*)op os << "Op(" << op->id
<< ':' << op->forward_liveness << ':' << op->forward_liveness
<< ':' << op->backward_liveness << ':' << op->backward_liveness
<< ':' << op->pending_liveness << ':' << op->pending_liveness

View File

@ -46,7 +46,6 @@ ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims)
get_op_info("cub_arg_reduce").get_constructor<std::vector<VarPtr>, Var*, Var*, NanoString, bool>() get_op_info("cub_arg_reduce").get_constructor<std::vector<VarPtr>, Var*, Var*, NanoString, bool>()
: nullptr; : nullptr;
if (cub_arg_reduce) { if (cub_arg_reduce) {
if (x->num<0) exe.run_sync(vector<Var*>({x}), true);
int dims = x->shape.size(); int dims = x->shape.size();
vector<int64> axes; vector<int64> axes;
axes.reserve(dims); axes.reserve(dims);
@ -88,6 +87,8 @@ ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims)
#endif #endif
y = create_output(nullptr, ns_int32); y = create_output(nullptr, ns_int32);
y_key = create_output(nullptr, x->dtype()); y_key = create_output(nullptr, x->dtype());
flags.set(NodeFlags::_manual_set_vnbb);
y->flags.set(NodeFlags::_needed_by_backward);
} }
VarPtr ArgReduceOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) { VarPtr ArgReduceOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) {
// Do not have grad to extras input // Do not have grad to extras input

View File

@ -45,7 +45,6 @@ ArgsortOp::ArgsortOp(Var* x, int dim, bool descending, NanoString dtype)
.get_constructor<std::vector<VarPtr>, Var*, Var*, Var*, bool, NanoString>(); .get_constructor<std::vector<VarPtr>, Var*, Var*, Var*, bool, NanoString>();
} }
if (cub_argsort) { if (cub_argsort) {
if (x->num<0) exe.run_sync(vector<Var*>({x}), true);
int dims = x->shape.size(); int dims = x->shape.size();
vector<int64> axes; vector<int64> axes;
axes.reserve(dims); axes.reserve(dims);
@ -81,6 +80,8 @@ ArgsortOp::ArgsortOp(Var* x, int dim, bool descending, NanoString dtype)
#endif #endif
y = create_output(nullptr, dtype); y = create_output(nullptr, dtype);
y_key = create_output(nullptr, x->dtype()); y_key = create_output(nullptr, x->dtype());
flags.set(NodeFlags::_manual_set_vnbb);
y->flags.set(NodeFlags::_needed_by_backward);
} }
VarPtr ArgsortOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) { VarPtr ArgsortOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) {

View File

@ -425,6 +425,18 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
ns = op; ns = op;
ASSERT(ns.is_binary()); ASSERT(ns.is_binary());
z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns)); z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns));
bool bin = ns.get(NanoString::_no_need_back_in);
bool bout = ns.get(NanoString::_no_need_back_out);
if (bin || bout) {
flags.set(NodeFlags::_manual_set_vnbb);
if (!bin) {
x->flags.set(NodeFlags::_needed_by_backward);
y->flags.set(NodeFlags::_needed_by_backward);
}
if (!bout) {
z->flags.set(NodeFlags::_needed_by_backward);
}
}
} }
VarPtr dirty_clone_broadcast(Var* v) { VarPtr dirty_clone_broadcast(Var* v) {

View File

@ -29,6 +29,7 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
} }
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_manual_set_vnbb);
set_type(OpType::broadcast); set_type(OpType::broadcast);
z = create_output(NanoVector(), x->dtype()); z = create_output(NanoVector(), x->dtype());
bcast_mask = 0; bcast_mask = 0;

View File

@ -16,12 +16,12 @@ namespace jittor {
CandidateOp::CandidateOp(Var* x, string&& fail_cond, NanoString dtype) : x(x), fail_cond(move(fail_cond)) { CandidateOp::CandidateOp(Var* x, string&& fail_cond, NanoString dtype) : x(x), fail_cond(move(fail_cond)) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_vary_shape); flags.set(NodeFlags::_manual_set_vnbb);
y = create_output(nullptr, dtype); y = create_output(nullptr, dtype);
} }
void CandidateOp::infer_shape() { void CandidateOp::infer_shape() {
y->set_shape({-std::abs(x->shape[0])}); y->set_shape({-x->shape[0]});
} }
void CandidateOp::jit_prepare(JK& jk) { void CandidateOp::jit_prepare(JK& jk) {

View File

@ -19,6 +19,7 @@ static auto make_clone = get_op_info("clone")
CloneOp::CloneOp(Var* x) : x(x) { CloneOp::CloneOp(Var* x) : x(x) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_manual_set_vnbb);
y = create_output(nullptr, x->dtype()); y = create_output(nullptr, x->dtype());
if (x->name.ptr) if (x->name.ptr)
y->name = x->name; y->name = x->name;

View File

@ -37,7 +37,6 @@ CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
_outputs.push_back(create_output(shape, dtype)); _outputs.push_back(create_output(shape, dtype));
if (_outputs[0]->num < 0) { if (_outputs[0]->num < 0) {
flags.set(NodeFlags::_vary_shape);
check_vary_shape(_outputs[0]->shape); check_vary_shape(_outputs[0]->shape);
} }
} }
@ -58,7 +57,6 @@ CodeOp::CodeOp(
for (int i=0; i<shapes.size(); i++) { for (int i=0; i<shapes.size(); i++) {
_outputs[i] = create_output(shapes[i], dtypes[i]); _outputs[i] = create_output(shapes[i], dtypes[i]);
if (_outputs[i]->num < 0) { if (_outputs[i]->num < 0) {
flags.set(NodeFlags::_vary_shape);
check_vary_shape(_outputs[i]->shape); check_vary_shape(_outputs[i]->shape);
} }
} }

View File

@ -20,6 +20,7 @@ namespace jittor {
CopyOp::CopyOp(Var* x) { CopyOp::CopyOp(Var* x) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_manual_set_vnbb);
auto y = create_output(nullptr, x->dtype()); auto y = create_output(nullptr, x->dtype());
if (x->name.ptr) if (x->name.ptr)
y->name = x->name; y->name = x->name;

View File

@ -33,6 +33,7 @@ FuseTransposeOp::FuseTransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
set_type(tp); set_type(tp);
flags.set(NodeFlags::_manual_set_vnbb);
int i=0; int i=0;
for (; i<axes.size(); i++) for (; i<axes.size(); i++)
if (i!=axes[i]) break; if (i!=axes[i]) break;

View File

@ -38,6 +38,10 @@ GetitemOp::GetitemOp(Var* x, VarSlices&& slices)
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_has_gopt); flags.set(NodeFlags::_has_gopt);
flags.set(NodeFlags::_manual_set_vnbb);
for (int i=0; i<vs.n; i++)
if (vs.slices[i].is_var())
vs.slices[i].var->flags.set(NodeFlags::_needed_by_backward);
create_output(nullptr, x->dtype()); create_output(nullptr, x->dtype());
} }
@ -48,6 +52,10 @@ GetitemOp::GetitemOp(Var* x, VarSlices&& slices, int _)
flags.set(NodeFlags::_has_gopt); flags.set(NodeFlags::_has_gopt);
flags.set(NodeFlags::_custom_flag); flags.set(NodeFlags::_custom_flag);
flags.set(NodeFlags::_grads); flags.set(NodeFlags::_grads);
flags.set(NodeFlags::_manual_set_vnbb);
for (int i=0; i<vs.n; i++)
if (vs.slices[i].is_var())
vs.slices[i].var->flags.set(NodeFlags::_needed_by_backward);
create_output(nullptr, x->dtype()); create_output(nullptr, x->dtype());
auto out2 = create_output(nullptr, x->dtype()); auto out2 = create_output(nullptr, x->dtype());
out2->share_with(x); out2->share_with(x);
@ -421,7 +429,6 @@ void GetitemOp::grads(Var** dout, VarPtr* dins) {
VarPtr y = dout[0]; VarPtr y = dout[0];
if (!x) { if (!x) {
auto in = inputs().front(); auto in = inputs().front();
if (in->num<0) exe.run_sync(vector<Var*>({in}), true);
// ns.data represents this is the last split var // ns.data represents this is the last split var
if (ns.data) if (ns.data)
x = make_empty(in->shape, in->dtype()); x = make_empty(in->shape, in->dtype());

View File

@ -31,6 +31,7 @@ IndexOp::IndexOp(Var* a, int64 dim, NanoString dtype) : dim(dim) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
set_type(OpType::element); set_type(OpType::element);
flags.set(NodeFlags::_manual_set_vnbb);
x.reset(new Var*[1]); x.reset(new Var*[1]);
x[0] = create_output(nullptr, dtype); x[0] = create_output(nullptr, dtype);
} }
@ -38,6 +39,7 @@ IndexOp::IndexOp(Var* a, int64 dim, NanoString dtype) : dim(dim) {
IndexOp::IndexOp(Var* a, NanoString dtype) : dim(a->shape.size()) { IndexOp::IndexOp(Var* a, NanoString dtype) : dim(a->shape.size()) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_manual_set_vnbb);
set_type(OpType::element); set_type(OpType::element);
x.reset(new Var*[dim]); x.reset(new Var*[dim]);
for (int i=0; i<dim; i++) for (int i=0; i<dim; i++)

View File

@ -253,6 +253,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
set_type(OpType::reduce); set_type(OpType::reduce);
if (op.get(NanoString::_no_need_back_in))
flags.set(NodeFlags::_manual_set_vnbb);
ns = op; ns = op;
ASSERT(ns.is_binary()); ASSERT(ns.is_binary());
auto xdim = x->shape.size(); auto xdim = x->shape.size();
@ -279,6 +281,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
set_type(OpType::reduce); set_type(OpType::reduce);
if (op.get(NanoString::_no_need_back_in))
flags.set(NodeFlags::_manual_set_vnbb);
ns = op; ns = op;
ASSERT(ns.is_binary()); ASSERT(ns.is_binary());
reduce_mask = dims_mask; reduce_mask = dims_mask;
@ -319,12 +323,6 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_binary(b, v, ns_divide); return make_binary(b, v, ns_divide);
} }
if (ns == ns_mean) { if (ns == ns_mean) {
if (v->num < 0) {
// TODO: Dynamic shape of mean grad was not supported yet
LOGw << "Dynamic shape of mean grad cause synchronize.";
exe.run_sync({v}, 0);
ASSERT(v->num>=0);
}
VarPtr a = make_broadcast_to(dout, v, reduce_mask, keepdims_mask); VarPtr a = make_broadcast_to(dout, v, reduce_mask, keepdims_mask);
VarPtr n = make_number(1.0f*out->num / v->num, a); VarPtr n = make_number(1.0f*out->num / v->num, a);
return make_binary(a, n, ns_multiply); return make_binary(a, n, ns_multiply);

View File

@ -27,6 +27,7 @@ ReindexOp::ReindexOp(Var* x, NanoVector shape, vector<string>&& indexes, float64
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
set_type(OpType::broadcast); set_type(OpType::broadcast);
flags.set(NodeFlags::_manual_set_vnbb);
y = create_output(nullptr, x->dtype()); y = create_output(nullptr, x->dtype());
} }

View File

@ -28,6 +28,8 @@ ReindexReduceOp::ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
set_type(OpType::reduce); set_type(OpType::reduce);
if (op.get(NanoString::_no_need_back_in))
flags.set(NodeFlags::_manual_set_vnbb);
ns = op; ns = op;
ASSERT(ns.is_binary() && ns!=ns_mean); ASSERT(ns.is_binary() && ns!=ns_mean);
x = create_output(nullptr, y->dtype()); x = create_output(nullptr, y->dtype());

View File

@ -20,6 +20,7 @@ static auto make_reshape = get_op_info("reshape")
ReshapeOp::ReshapeOp(Var* x, NanoVector shape) : x(x), shape(shape) { ReshapeOp::ReshapeOp(Var* x, NanoVector shape) : x(x), shape(shape) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_manual_set_vnbb);
y = create_output(nullptr, x->dtype()); y = create_output(nullptr, x->dtype());
ASSERT(shape.size() > 0) << "input target shape of reshape can't be empty."; ASSERT(shape.size() > 0) << "input target shape of reshape can't be empty.";
} }

View File

@ -16,6 +16,7 @@ namespace jittor {
SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right) : x(x), left(left), right(right) { SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right) : x(x), left(left), right(right) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_manual_set_vnbb);
set_type(OpType::element); set_type(OpType::element);
y = create_output(nullptr, x->dtype()); y = create_output(nullptr, x->dtype());
} }

View File

@ -42,6 +42,12 @@ SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op)
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_has_gopt); flags.set(NodeFlags::_has_gopt);
if (op.get(NanoString::_no_need_back_in)) {
flags.set(NodeFlags::_manual_set_vnbb);
for (int i=0; i<vs.n; i++)
if (vs.slices[i].is_var())
vs.slices[i].var->flags.set(NodeFlags::_needed_by_backward);
}
ASSERT(op == ns_void || op.is_binary()); ASSERT(op == ns_void || op.is_binary());
create_output(nullptr, x->dtype()); create_output(nullptr, x->dtype());
if (flags.get(NodeFlags::_custom_flag)) { if (flags.get(NodeFlags::_custom_flag)) {

View File

@ -16,6 +16,7 @@ namespace jittor {
TapeOp::TapeOp(Var* x) { TapeOp::TapeOp(Var* x) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_manual_set_vnbb);
create_output(nullptr, x->dtype()); create_output(nullptr, x->dtype());
} }
@ -53,6 +54,7 @@ Tapes::Tapes(
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_grads); flags.set(NodeFlags::_grads);
flags.set(NodeFlags::_manual_set_vnbb);
callback = move(grad_callback); callback = move(grad_callback);

View File

@ -13,13 +13,26 @@ namespace jittor {
#ifndef JIT #ifndef JIT
static auto make_ternary = get_op_info("ternary") static auto make_ternary = get_op_info("ternary")
.get_constructor<VarPtr, Var*, Var*, Var*>(); .get_constructor<VarPtr, Var*, Var*, Var*>();
static auto make_broadcast = get_op_info("broadcast_to")
.get_constructor<VarPtr, Var*, Var*, NanoVector>();
static auto make_number = get_op_info("number") static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>(); .get_constructor<VarPtr, float, Var*>();
TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) { TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) {
bool bx = cond->shape.size() > x->shape.size() || cond->num > x->num;
bool by = cond->shape.size() > y->shape.size() || cond->num > y->num;
if (bx || by) {
VarPtr xx, yy;
if (bx) xx = make_broadcast(x, cond, NanoVector()), x = xx;
if (by) yy = make_broadcast(y, cond, NanoVector()), y = yy;
forward(make_ternary(cond, x, y));
return;
}
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
set_type(OpType::element); set_type(OpType::element);
flags.set(NodeFlags::_manual_set_vnbb);
cond->flags.set(NodeFlags::_needed_by_backward);
z = create_output(nullptr, dtype_infer(x->ns, y->ns)); z = create_output(nullptr, dtype_infer(x->ns, y->ns));
} }

View File

@ -49,6 +49,7 @@ TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
} }
#endif #endif
y = create_output(nullptr, x->dtype()); y = create_output(nullptr, x->dtype());
flags.set(NodeFlags::_manual_set_vnbb);
} }
void TransposeOp::infer_shape() { void TransposeOp::infer_shape() {

View File

@ -544,6 +544,17 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
} else } else
dtype = unary_dtype_infer(ns, x->ns); dtype = unary_dtype_infer(ns, x->ns);
y = create_output(nullptr, dtype); y = create_output(nullptr, dtype);
bool bin = ns.get(NanoString::_no_need_back_in);
bool bout = ns.get(NanoString::_no_need_back_out);
if (bin || bout) {
flags.set(NodeFlags::_manual_set_vnbb);
if (!bin) {
x->flags.set(NodeFlags::_needed_by_backward);
}
if (!bout) {
y->flags.set(NodeFlags::_needed_by_backward);
}
}
} }
VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) { VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {

View File

@ -21,7 +21,7 @@ namespace jittor {
WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) { WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) {
flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda); flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_vary_shape); flags.set(NodeFlags::_manual_set_vnbb);
auto ndim = cond->shape.size(); auto ndim = cond->shape.size();
#ifdef HAS_CUDA #ifdef HAS_CUDA
if (use_cuda) { if (use_cuda) {
@ -48,8 +48,7 @@ WhereOp::WhereOp(Var* cond, Var* x, Var* y) {
void WhereOp::infer_shape() { void WhereOp::infer_shape() {
auto ndim = cond->shape.size(); auto ndim = cond->shape.size();
auto num = cond->num; auto num = -cond->num;
if (num>0) num = -num;
for (uint i=0; i<ndim; i++) for (uint i=0; i<ndim; i++)
outs[i]->set_shape({num}); outs[i]->set_shape({num});
} }

View File

@ -124,9 +124,6 @@ static void getitem_inplace(GetitemOp* op) {
// return if out is all ready inplaced // return if out is all ready inplaced
if (ou->allocator) if (ou->allocator)
return; return;
// return if input or output's shape is variable
if (in->num <= 0 || ou->num <= 0)
return;
VarSlices vs = op->vs; VarSlices vs = op->vs;
auto in_shape = in->shape; auto in_shape = in->shape;

View File

@ -17,11 +17,11 @@ SEH_HOOK;
// Those function is generated by python // Those function is generated by python
EXTERN_LIB void pyjt_def_all(PyObject* m); EXTERN_LIB void pyjt_def_all(PyObject* m);
vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets) { vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets, bool retain_graph) {
vector<Var*> vs; vector<Var*> vs;
vs.reserve(targets.size()); vs.reserve(targets.size());
for (auto* v : targets) vs.push_back(v->var); for (auto* v : targets) vs.push_back(v->var);
auto grads = grad(loss->var, vs); auto grads = grad(loss->var, vs, retain_graph);
vector<VarHolder*> grads_hold; vector<VarHolder*> grads_hold;
grads_hold.reserve(targets.size()); grads_hold.reserve(targets.size());
for (auto& grad : grads) for (auto& grad : grads)

View File

@ -604,7 +604,7 @@ void system_with_check(const char* cmd, const char* cwd) {
auto ret = system_popen(cmd, cwd); auto ret = system_popen(cmd, cwd);
CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd << CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd <<
"\nreturn ">> ret >> ". This might be an overcommit issue or out of memory." "\nreturn ">> ret >> ". This might be an overcommit issue or out of memory."
<< "Try : sudo sysctl vm.overcommit_memory=1"; << "Try : sudo sysctl vm.overcommit_memory=1, or set enviroment variable `export DISABLE_MULTIPROCESSING=1`";
CHECKop(ret,==,0) << "Run cmd failed:" << cmd; CHECKop(ret,==,0) << "Run cmd failed:" << cmd;
} }

View File

@ -22,6 +22,7 @@ DEFINE_FLAG(bool, no_grad, 0,
"No grad for all jittor Var creation"); "No grad for all jittor Var creation");
DEFINE_FLAG(bool, no_fuse, 0, DEFINE_FLAG(bool, no_fuse, 0,
"No fusion optimization for all jittor Var creation"); "No fusion optimization for all jittor Var creation");
// TODO: fuse multiple flags
DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too"); DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16"); DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
@ -91,13 +92,14 @@ bool Var::alloc(Allocator* allocator) {
std::ostream& operator<<(std::ostream& os, const Var& var) { std::ostream& operator<<(std::ostream& os, const Var& var) {
os << "Var" << '(' << (void*)&var os << "Var" << '(' << var.id
<< ':' << var.forward_liveness << ':' << var.forward_liveness
<< ':' << var.backward_liveness << ':' << var.backward_liveness
<< ':' << var.pending_liveness << ':' << var.pending_liveness
<< ":i" << var._inputs.size() << ":i" << var._inputs.size()
<< ":o" << var._outputs.size() << ":o" << var._outputs.size()
<< ":s" << var.is_finished() << ":s" << var.is_finished()
<< ":n" << var.flags.get(NodeFlags::_needed_by_backward)
<< ',' << ','
<< var.dtype().to_cstring() << ',' << var.name << ',' << var.mem_ptr << var.dtype().to_cstring() << ',' << var.name << ',' << var.mem_ptr
<< ')' << var.shape; << ')' << var.shape;

View File

@ -22,6 +22,7 @@ namespace jittor {
DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance."); DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance.");
list<VarHolder*> hold_vars; list<VarHolder*> hold_vars;
list<VarHolder*>::iterator sync_ptr = hold_vars.end();
void add_hold_vars(VarHolder* self) { void add_hold_vars(VarHolder* self) {
hold_vars.push_front(self); hold_vars.push_front(self);
@ -79,6 +80,8 @@ VarHolder::VarHolder(PyObject* obj, NanoString dtype) {
VarHolder::~VarHolder() { VarHolder::~VarHolder() {
if (PREDICT_BRANCH_NOT_TAKEN(!var)) return; if (PREDICT_BRANCH_NOT_TAKEN(!var)) return;
if (iter == sync_ptr)
sync_ptr = std::next(sync_ptr);
hold_vars.erase(iter); hold_vars.erase(iter);
var->release_both_liveness(); var->release_both_liveness();
} }
@ -100,7 +103,6 @@ void VarHolder::operator=(VarPtr&& v) {
} }
string VarHolder::to_string() { string VarHolder::to_string() {
if (var->num<0) sync();
return var->to_string(); return var->to_string();
} }
@ -131,8 +133,8 @@ VarHolder* VarHolder::_update(VarHolder* v) {
EXTERN_LIB Executor exe; EXTERN_LIB Executor exe;
void VarHolder::sync(bool device_sync) { void VarHolder::sync(bool device_sync, bool weak_sync) {
jittor::sync({this}, device_sync); jittor::sync({this}, device_sync, weak_sync);
} }
ArrayArgs VarHolder::fetch_sync() { ArrayArgs VarHolder::fetch_sync() {
@ -178,12 +180,12 @@ void sync_all(bool device_sync) {
graph_check(); graph_check();
} }
void sync(const vector<VarHolder*>& vh, bool device_sync) { void sync(const vector<VarHolder*>& vh, bool device_sync, bool weak_sync) {
vector<Var*> vars; vector<Var*> vars;
vars.reserve(vh.size()); vars.reserve(vh.size());
for (auto v : vh) vars.push_back(v->var); for (auto v : vh) vars.push_back(v->var);
graph_check(); graph_check();
exe.run_sync(vars, device_sync); //need sync at last exe.run_sync(vars, device_sync, weak_sync); //need sync at last
graph_check(); graph_check();
} }
@ -226,4 +228,17 @@ your code as below::
return 0; return 0;
} }
static auto make_ternary = get_op_info("ternary")
.get_constructor<VarPtr, Var*, Var*, Var*>();
extern int no_grad;
VarHolder* ternary_out_hint(VarHolder* cond, VarHolder* x, VarHolder* y) {
if (!no_grad)
cond->var->flags.set(NodeFlags::_out_hint);
return new VarHolder(make_ternary(cond->var, x->var, y->var));
}
} // jittor } // jittor

View File

@ -31,6 +31,7 @@ struct ItemData {
typedef struct _object PyObject; typedef struct _object PyObject;
EXTERN_LIB list<VarHolder*> hold_vars; EXTERN_LIB list<VarHolder*> hold_vars;
EXTERN_LIB list<VarHolder*>::iterator sync_ptr;
// @pyjt(Var) // @pyjt(Var)
// @attrs(heaptype) // @attrs(heaptype)
@ -47,7 +48,7 @@ struct VarHolder {
~VarHolder(); ~VarHolder();
string to_string(); string to_string();
// @pyjt(sync) // @pyjt(sync)
void sync(bool device_sync = false); void sync(bool device_sync = false, bool weak_sync = true);
// @pyjt(fetch_sync,numpy) // @pyjt(fetch_sync,numpy)
ArrayArgs fetch_sync(); ArrayArgs fetch_sync();
@ -108,7 +109,6 @@ struct VarHolder {
*/ */
// @pyjt(numel) // @pyjt(numel)
inline int64 numel() { inline int64 numel() {
if (var->num<0) sync();
return var->num; return var->num;
} }
@ -155,12 +155,21 @@ struct VarHolder {
return var->flags.get(NodeFlags::_stop_fuse); return var->flags.get(NodeFlags::_stop_fuse);
} }
/**
* output hint for training optimization
*/
// @pyjt(out_hint)
// @attrs(return_self)
inline VarHolder* out_hint() {
var->flags.set(NodeFlags::_out_hint);
return this;
}
/** /**
* return the shape of the Var. * return the shape of the Var.
*/ */
// @pyjt(__get__shape) // @pyjt(__get__shape)
inline NanoVector shape() { inline NanoVector shape() {
if (var->num<0) sync();
return var->shape; return var->shape;
} }
@ -324,7 +333,7 @@ struct VarHolder {
}; };
// @pyjt(sync) // @pyjt(sync)
void sync(const vector<VarHolder*>& vh=vector<VarHolder*>(), bool device_sync=false); void sync(const vector<VarHolder*>& vh=vector<VarHolder*>(), bool device_sync=false, bool weak_sync=true);
// @pyjt(fetch_sync) // @pyjt(fetch_sync)
vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh); vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh);
@ -347,4 +356,7 @@ inline vector<VarHolder*> make_vh_vector(vector<VarPtr>&& vps) {
return a; return a;
} }
// @pyjt(ternary_out_hint)
VarHolder* ternary_out_hint(VarHolder* cond, VarHolder* x, VarHolder* y);
} // jittor } // jittor

View File

@ -11,7 +11,7 @@ import jittor as jt
import numpy as np import numpy as np
class TestClone(unittest.TestCase): class TestClone(unittest.TestCase):
def test(self): def test_mid_stop_grad(self):
jt.clean() jt.clean()
b = a = jt.array(1.0) b = a = jt.array(1.0)
for i in range(10): for i in range(10):
@ -19,8 +19,11 @@ class TestClone(unittest.TestCase):
if i==5: c=b if i==5: c=b
b.sync() b.sync()
assert jt.number_of_lived_vars()==11 assert jt.number_of_lived_vars()==11
c.name("c")
c.stop_grad() c.stop_grad()
assert jt.number_of_lived_vars()==3 for n in jt.dump_all_graphs().nodes_info:
print(n)
assert jt.number_of_lived_vars()==3, jt.number_of_lived_vars()
def test2(self): def test2(self):
a = jt.array([1,2]) a = jt.array([1,2])

View File

@ -17,6 +17,7 @@ def expect_error(func):
raise Exception("Expect an error, but nothing catched.") raise Exception("Expect an error, but nothing catched.")
class TestCore(unittest.TestCase): class TestCore(unittest.TestCase):
def test_number_of_hold_vars(self): def test_number_of_hold_vars(self):
assert jt.random([1,2,3]).peek() == "float32[1,2,3,]" assert jt.random([1,2,3]).peek() == "float32[1,2,3,]"
assert jt.core.number_of_hold_vars() == 0 assert jt.core.number_of_hold_vars() == 0
@ -76,6 +77,7 @@ class TestCore(unittest.TestCase):
def test_var_holder(self): def test_var_holder(self):
jt.clean() jt.clean()
self.assertEqual(jt.number_of_lived_vars(), 0)
expect_error(lambda: jt.matmul(1,1)) expect_error(lambda: jt.matmul(1,1))
expect_error(lambda: jt.matmul([1],[1])) expect_error(lambda: jt.matmul([1],[1]))
expect_error(lambda: jt.matmul([[1]],[1])) expect_error(lambda: jt.matmul([[1]],[1]))
@ -119,5 +121,103 @@ class TestCore(unittest.TestCase):
assert a._parameters['a'] is a.a assert a._parameters['a'] is a.a
assert a._parameters['b'] is a.b assert a._parameters['b'] is a.b
def test_copy_memopt(self):
# exe: post run
# remove pending done
# add hold pending done
# pending release mem done
a = jt.rand(10)
b = a.copy().copy().copy()
a.name("aa")
b.name("bb")
cnt = 0
graphs = jt.dump_all_graphs()
for x in graphs.nodes_info:
if "Var" not in x: continue
print(x)
if ",aa," in x:
assert ":2:i" in x, x
elif ",bb," in x:
assert ":1:i" in x
else:
assert ":1:i" in x
b.sync()
cnt = 0
graphs = jt.dump_all_graphs()
for x in graphs.nodes_info:
# print(x)
if "Var" in x and ",0)" in x:
cnt += 1
assert cnt == 2
def test_fuse_memopt(self):
def check():
a = jt.rand(10)
b = (a.copy().name("copy_out1") + 1).sqr() + a.copy().name("copy_out2")
b.sync()
for n in jt.dump_all_graphs().nodes_info:
if "Var" not in n: continue
# print(n)
if "copy_out1" in n:
# copy out1 is not free
assert ",0)" not in n
if "copy_out2" in n:
# copy out2 is freeed
assert ",0)" in n
da = jt.grad(b, a)
da.sync()
check()
jt.gc()
assert jt.liveness_info()['lived_vars'] == 0
def test_out_hint1(self):
a = jt.rand(10)
b = jt.rand(10)
c = jt.ternary_out_hint((a<b).out_hint(), a, b).clone()
c.sync()
da, db = jt.grad(c, [a, b])
jt.sync_all()
for n in jt.dump_all_graphs().nodes_info:
if "Var" in n and "bool" in n:
print(n)
assert ",0)" not in n
jt.ternary_out_hint((a<b).out_hint(), a, 0).sync()
def test_out_hint2(self):
a = jt.rand(10)
b = jt.rand(10)
c = jt.ternary(a<b, a, b).clone()
# c.sync()
da, db = jt.grad(c, [a, b])
jt.sync_all()
for n in jt.dump_all_graphs().nodes_info:
if "Var" in n and "bool" in n:
print(n)
assert ",0)" not in n
def test_relu_memopt(self):
x = a = jt.rand(10,10)
for i in range(10):
# a = jt.nn.relu(a)
a = jt.ternary_out_hint((a>0.0).name("b"+str(i)), a, 0.0)
a = jt.matmul(a.name("m1"),jt.rand(10,10).name("m2")).name("m3-"+str(i))
da = jt.grad(a, x, True)
# jt.clean_graph()
da.sync()
cnt1 = 0
cnt2 = 0
for n in jt.dump_all_graphs().nodes_info:
if "Var" in n and ",0)" not in n:
cnt1 +=1
if "bool" in n:
cnt2 += 1
print(cnt1, cnt2)
assert cnt2 == 10
assert cnt1 <= 33, cnt1
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -55,6 +55,7 @@ class TestExample(unittest.TestCase):
model = Model(input_size=1) model = Model(input_size=1)
ps = model.parameters() ps = model.parameters()
for p in reversed(ps): p.sync(0,0)
for i,(x,y) in enumerate(get_data(n)): for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x).name("pred_y") pred_y = model(x).name("pred_y")

View File

@ -63,6 +63,7 @@ class TestExample(unittest.TestCase):
model = Model(input_size=1) model = Model(input_size=1)
ps = model.parameters() ps = model.parameters()
for p in reversed(ps): p.sync(0,0)
opt = Optimizer(ps, lr) opt = Optimizer(ps, lr)
all_loss = 0 all_loss = 0

View File

@ -258,6 +258,7 @@ class TestFunction(unittest.TestCase):
g = jt.grad(c+d*3, [a, b]) g = jt.grad(c+d*3, [a, b])
test() test()
jt.clean() jt.clean()
jt.dump_all_graphs()
self.assertEqual(jt.liveness_info()["lived_vars"], 0) self.assertEqual(jt.liveness_info()["lived_vars"], 0)
@unittest.skipIf(True, "skip memleak test") @unittest.skipIf(True, "skip memleak test")

View File

@ -78,9 +78,11 @@ class TestFusedOp(unittest.TestCase):
def test_add(self): def test_add(self):
jt.clean() jt.clean()
def check(hv, lv, lo): def check(hv, lv, lo):
self.assertEqual(jt.number_of_hold_vars(), hv) self.assertEqual((
self.assertEqual(jt.number_of_lived_vars(), lv) jt.number_of_hold_vars(),
self.assertEqual(jt.number_of_lived_ops(), lo) jt.number_of_lived_vars(),
jt.number_of_lived_ops()),
(hv, lv, lo))
for i in range(8): for i in range(8):
check(0,0,0) check(0,0,0)
a = jt.array(1.0).name('a').stop_fuse() a = jt.array(1.0).name('a').stop_fuse()
@ -88,7 +90,14 @@ class TestFusedOp(unittest.TestCase):
c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c') c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c')
check(3,5,5) check(3,5,5)
graph = jt.dump_all_graphs() graph = jt.dump_all_graphs()
# for n in graph.nodes_info:
# print(n)
self.assertEqual(c.data, 3) self.assertEqual(c.data, 3)
graph2 = jt.dump_all_graphs()
print("check", i)
for n in graph2.nodes_info:
print(n)
print(jt.liveness_info())
check(3,5,2) check(3,5,2)
graph = jt.dump_all_graphs() graph = jt.dump_all_graphs()
for node in graph.nodes_info: for node in graph.nodes_info:

View File

@ -39,7 +39,6 @@ class TestIndexOp(unittest.TestCase):
def test_vary_shape_dep(self): def test_vary_shape_dep(self):
a, = jt.where([1,0,1]) a, = jt.where([1,0,1])
b, = a.index_var() b, = a.index_var()
assert a.uncertain_shape==[-3] and b.uncertain_shape==[-3]
assert (b.data==[0,1]).all() assert (b.data==[0,1]).all()
def test_vary_shape_dep2(self): def test_vary_shape_dep2(self):
@ -48,7 +47,6 @@ class TestIndexOp(unittest.TestCase):
index0 = index0.broadcast([1,3], dims=[1]) # [[1,1,1],[2,2,2]] index0 = index0.broadcast([1,3], dims=[1]) # [[1,1,1],[2,2,2]]
index1 = index0.index_var(1) # [[0,1,2],[0,1,2]] index1 = index0.index_var(1) # [[0,1,2],[0,1,2]]
b = a.reindex_var([index0, index1]) b = a.reindex_var([index0, index1])
assert b.uncertain_shape==[-3,3]
assert (b.data==[[4,5,6],[7,8,9]]).all() assert (b.data==[[4,5,6],[7,8,9]]).all()
assert (index0.data==[[1,1,1],[2,2,2]]).all() assert (index0.data==[[1,1,1],[2,2,2]]).all()
assert (index1.data==[[0,1,2],[0,1,2]]).all() assert (index1.data==[[0,1,2],[0,1,2]]).all()

View File

@ -130,6 +130,7 @@ class TestMatmul(unittest.TestCase):
np.random.seed(0) np.random.seed(0)
jt.set_seed(3) jt.set_seed(3)
model = Model() model = Model()
for p in reversed(model.parameters()): p.sync(0,0)
SGD = jt.nn.SGD(model.parameters(), 0.05, 0.9, 0) SGD = jt.nn.SGD(model.parameters(), 0.05, 0.9, 0)
n = 1000 n = 1000
batch_size = 50 batch_size = 50

View File

@ -141,6 +141,13 @@ class TestNode(unittest.TestCase):
# noded opt: build(0.44),execute(0.11) # noded opt: build(0.44),execute(0.11)
# for i in range(20): # for i in range(20):
# run() # run()
# version 1.3.2.6 retest(laptop)
# mode1:
# origin 0.296 exec(0.11)
# int32flag 0.298 exec(0.11)
# add order 0.299 exec(0.11)
# rm p1 rule 0.299 exec(0.11)
for i in range(20): for i in range(20):
run() run()
import gc import gc

View File

@ -75,10 +75,7 @@ def conv_transpose_naive(x, w):
def is_fused(x): def is_fused(x):
x.name('_x') return 's0' in x.debug_msg()
graph = jt.dump_all_graphs()
node_a = [ node for node in graph.nodes_info if ",_x," in node ]
return 's0' in node_a[0]
def check_fused(dim): def check_fused(dim):
jt.clean() jt.clean()

View File

@ -91,7 +91,10 @@ class TestResnetFp32(unittest.TestCase):
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
.format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev)) .format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev))
# prev = time.time() # prev = time.time()
# async version
jt.fetch(epoch_id, batch_id, loss, output, target, callback) jt.fetch(epoch_id, batch_id, loss, output, target, callback)
# sync version
# callback(epoch_id, batch_id, loss.numpy(), output.numpy(), target.numpy())
# log_conv = find_log_with_re(logs, # log_conv = find_log_with_re(logs,
# "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*") # "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")

View File

@ -66,8 +66,9 @@ def test_ring_buffer():
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
test_send_recv(np.random.rand(10,10)) test_send_recv(np.random.rand(10,10))
n_byte += 1 + 16 + 2 + 10*10*8 n_byte += 1 + 16 + 4 + 10*10*8
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() assert n_byte == buffer.total_pop() and n_byte == buffer.total_push(), \
(n_byte, buffer.total_pop(), n_byte, buffer.total_push())
test_send_recv(test_ring_buffer) test_send_recv(test_ring_buffer)
test_send_recv(jt.array(np.random.rand(10,10))) test_send_recv(jt.array(np.random.rand(10,10)))

View File

@ -381,6 +381,19 @@ class TestSetitem(unittest.TestCase):
for i in range(n): for i in range(n):
np.testing.assert_allclose(g.data[i*m:(i+1)*m], yy[i].data) np.testing.assert_allclose(g.data[i*m:(i+1)*m], yy[i].data)
def test_dfs_memopt(self):
with jt.flag_scope(profile_memory_enable=1):
n = 1024
b = []
for i in range(n):
a = jt.rand(n).copy().copy()
a = a.sum()
# a.sync()
b.append(a)
jt.sync_all()
jt.get_max_memory_treemap()

View File

@ -15,7 +15,7 @@ class TestWhereOp(unittest.TestCase):
def test(self): def test(self):
assert (self.where([0,1,0,1])[0].data == [1,3]).all() assert (self.where([0,1,0,1])[0].data == [1,3]).all()
a, = self.where([0,1,0,1]) a, = self.where([0,1,0,1])
assert a.uncertain_shape==[-4] assert a.uncertain_shape==[2]
a.data a.data
assert a.uncertain_shape==[2] assert a.uncertain_shape==[2]
a,b = self.where([[0,0,1],[1,0,0]]) a,b = self.where([[0,0,1],[1,0,0]])

Binary file not shown.