mirror of https://github.com/Jittor/Jittor

v 1.3.3 memory optimization

This commit is contained in:
parent 9048f3fd41
commit 0666456a2f
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.3.2.6'
+__version__ = '1.3.3.0'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int

@@ -355,10 +355,10 @@ def array64(data, dtype=None):
     with jt.flag_scope(auto_convert_64_to_32=0):
         return array(data, dtype)

-def grad(loss, targets):
+def grad(loss, targets, retain_graph=True):
     if type(targets) == core.Var:
-        return core.grad(loss, [targets])[0]
-    return core.grad(loss, targets)
+        return core.grad(loss, [targets], retain_graph)[0]
+    return core.grad(loss, targets, retain_graph)

 def liveness_info():
     return {
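Note: a minimal sketch of the new retain_graph semantics at the Python level (values are illustrative, not from the commit):

    import jittor as jt

    x = jt.array([2.0, 3.0])
    loss = (x * x).sum()
    g1 = jt.grad(loss, x)                      # retain_graph defaults to True: graph kept
    g2 = jt.grad(loss, x)                      # so a second backward still works
    g3 = jt.grad(loss, x, retain_graph=False)  # releases the backward graph, letting
                                               # intermediate buffers be freed earlier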
@@ -28,6 +28,8 @@ CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdim
     ASSERT(offsets->dtype()==ns_int32);
     y = create_output(nullptr, ns_int32);
     y_key = create_output(nullptr, x->dtype());
+    flags.set(NodeFlags::_manual_set_vnbb);
+    y->flags.set(NodeFlags::_needed_by_backward);
 }

 VarPtr CubArgReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {

@@ -27,6 +27,8 @@ CubArgsortOp::CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending,
     ASSERT(offsets->dtype()==ns_int32);
     y = create_output(nullptr, dtype);
     y_key = create_output(nullptr, x->dtype());
+    flags.set(NodeFlags::_manual_set_vnbb);
+    y->flags.set(NodeFlags::_needed_by_backward);
 }

 VarPtr CubArgsortOp::grad(Var* out, Var* dout, Var* v, int v_index) {

@@ -26,7 +26,6 @@ namespace jittor {
 CubWhereOp::CubWhereOp(Var* cond, NanoString dtype) : cond(cond) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
-    flags.set(NodeFlags::_vary_shape);
     auto ndim = cond->shape.size();
     outs.reset(new Var*[ndim]);
     for (uint i=0; i<ndim; i++)

@@ -35,8 +34,7 @@ CubWhereOp::CubWhereOp(Var* cond, NanoString dtype) : cond(cond) {

 void CubWhereOp::infer_shape() {
     auto ndim = cond->shape.size();
-    auto num = cond->num;
-    if (num>0) num = -num;
+    auto num = -cond->num;
     for (uint i=0; i<ndim; i++)
         outs[i]->set_shape({num});
 }

@@ -35,6 +35,9 @@ CublasBatchedMatmulOp::CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool
     c = create_output(nullptr, a->dtype());
     flags.set(NodeFlags::_cpu, 0);
     flags.set(NodeFlags::_cuda, 1);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    a->flags.set(NodeFlags::_needed_by_backward);
+    b->flags.set(NodeFlags::_needed_by_backward);
 }


@@ -24,6 +24,9 @@ CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
     : a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
     flags.set(NodeFlags::_cuda, 1);
     flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    a->flags.set(NodeFlags::_needed_by_backward);
+    b->flags.set(NodeFlags::_needed_by_backward);
     // TODO: support int8 * int8
     ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same";
     // TODO: support diffrent input type

@@ -29,6 +29,9 @@ CudnnConv3dBackwardWOp::CudnnConv3dBackwardWOp(Var* x, Var* dy, int kd, int kh,
     xformat(move(xformat)) {
     flags.set(NodeFlags::_cuda, 1);
     flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    x->flags.set(NodeFlags::_needed_by_backward);
+    dy->flags.set(NodeFlags::_needed_by_backward);
     dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
 }

@@ -29,6 +29,9 @@ CudnnConv3dBackwardXOp::CudnnConv3dBackwardXOp(Var* w, Var* dy, int depth, int h
     xformat(move(xformat)) {
     flags.set(NodeFlags::_cuda, 1);
     flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    w->flags.set(NodeFlags::_needed_by_backward);
+    dy->flags.set(NodeFlags::_needed_by_backward);
     dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
 }

@@ -26,6 +26,9 @@ CudnnConv3dOp::CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int strid
     xformat(move(xformat)) {
     flags.set(NodeFlags::_cuda, 1);
     flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    x->flags.set(NodeFlags::_needed_by_backward);
+    w->flags.set(NodeFlags::_needed_by_backward);
     y = create_output(nullptr, dtype_infer(x->ns, w->ns));
 }

@@ -52,6 +52,9 @@ CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int
     xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
     flags.set(NodeFlags::_cuda, 1);
     flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    x->flags.set(NodeFlags::_needed_by_backward);
+    dy->flags.set(NodeFlags::_needed_by_backward);
     dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
 }

@@ -53,6 +53,9 @@ CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int widt
     xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
     flags.set(NodeFlags::_cuda, 1);
     flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    w->flags.set(NodeFlags::_needed_by_backward);
+    dy->flags.set(NodeFlags::_needed_by_backward);
     dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
 }

@@ -50,6 +50,9 @@ CudnnConvOp::CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh,
     xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
     flags.set(NodeFlags::_cuda, 1);
     flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    x->flags.set(NodeFlags::_needed_by_backward);
+    w->flags.set(NodeFlags::_needed_by_backward);
     y = create_output(nullptr, dtype_infer(x->ns, w->ns));
     if (!this->yformat.size())
         this->yformat = this->xformat;
@@ -801,6 +801,10 @@ def print_tree(now, max_memory_size, prefix1, prefix2, build_by):
     tab = ' '
     out += prefix1+now['name']+'('+now['type']+')\n'
     out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%; cnt:'+format_size(now['cnt'],'') + ']\n'
+    if len(now['children']) == 0 and len(now['vinfo']):
+        out += prefix2+now['vinfo'][0]
+        if len(now['vinfo']) > 1: out += "..."
+        out += '\n'
     if (build_by == 0):
         for p in now['path']:
             out += prefix2+p+'\n'

@@ -866,7 +870,8 @@ Output::
     vars_ = vars_[1:]
     for v_ in vars_:
         v__ = v_.split(div2)
-        var = {'size':int(v__[1]), 'stack':[], 'cnt':1}
+        vinfo = v__[0].split("{")[0]
+        var = {'size':int(v__[1]), 'stack':[], 'cnt':1, "vinfo":vinfo}
         v__ = v__[2:-1]
         for s_ in v__:
             s__ = s_.split(div3)

@@ -874,7 +879,7 @@ Output::
             var['stack'].append(s)
         vars.append(var)
     if (build_by == 0): # build tree by name
-        tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':''}
+        tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':'', 'vinfo':[]}

        def find_child(now, key):
            for c in now['children']:

@@ -885,6 +890,7 @@ Output::
            now = tree
            now['size'] += v['size']
            now['cnt'] += v['cnt']
+           now['vinfo'].append(v['vinfo'])
            for s in v['stack']:
                ch = find_child(now, s['name'])
                if (ch is not None):

@@ -894,12 +900,13 @@ Output::
                    now = ch
                    now['size'] += v['size']
                    now['cnt'] += v['cnt']
+                   now['vinfo'].append(v['vinfo'])
                else:
-                   now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type']}
+                   now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type'], 'vinfo':[v['vinfo']]}
                    now['children'].append(now_)
                    now = now_
    elif (build_by == 1): # build tree by path
-       tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':''}
+       tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':'', 'vinfo':[]}

        def find_child(now, key):
            for c in now['children']:

@@ -910,14 +917,16 @@ Output::
            now = tree
            now['size'] += v['size']
            now['cnt'] += v['cnt']
+           now['vinfo'].append(v['vinfo'])
            for s in v['stack']:
                ch = find_child(now, s['path'])
                if (ch is not None):
                    now = ch
                    now['size'] += v['size']
                    now['cnt'] += v['cnt']
+                   now['vinfo'].append(v['vinfo'])
                else:
-                   now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type']}
+                   now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type'], 'vinfo':[v['vinfo']]}
                    now['children'].append(now_)
                    now = now_
    else:
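Note: the new 'vinfo' entries annotate leaves of the memory treemap with per-variable info. A rough usage sketch, borrowing the profile_memory_enable flag from the test added later in this commit:

    import jittor as jt

    with jt.flag_scope(profile_memory_enable=1):
        a = jt.rand(1024).copy().sum()
        a.sync()
    # renders the tree built by print_tree, now including vinfo lines
    jt.get_max_memory_treemap()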
@@ -173,7 +173,8 @@ def relu(x):
         >>> nn.relu(a)
         jt.Var([0.        1.1338731 6.128115 ], dtype=float32)
     '''
-    return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
+    cond = x>0.0
+    return jt.ternary_out_hint(cond, x, 0.0)


 def leaky_relu(x, scale=0.01):
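Note: relu now keeps only the boolean mask alive for backward instead of a broadcast zero tensor; the forward result is unchanged. A small sanity sketch:

    import jittor as jt

    x = jt.array([-1.0, 0.5, 2.0])
    y = jt.nn.relu(x)             # [0., 0.5, 2.]
    dx = jt.grad(y.sum(), x)      # [0., 1., 1.]; backward reads the bool mask (x>0.0), not x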
@@ -98,7 +98,7 @@ class Optimizer(object):
     def zero_grad(self):
         self.__zero_grad = True

-    def pre_step(self, loss):
+    def pre_step(self, loss, retain_graph=False):
         """ something should be done before step, such as calc gradients, mpi sync, and so on.

         Example::

@@ -118,7 +118,7 @@ class Optimizer(object):
                 params_has_grad.append(p)

         # get gradient
-        grads = jt.grad(loss, params_has_grad)
+        grads = jt.grad(loss, params_has_grad, retain_graph)

         # sync grads and model if in mpi
         if jt.in_mpi:

@@ -153,7 +153,7 @@ class Optimizer(object):
             pid += 1
         self.__zero_grad = False

-    def backward(self, loss):
+    def backward(self, loss, retain_graph=False):
         '''
         optimize.backward(loss) is used for accumulate multiple step,
         it can be used as following:

@@ -186,11 +186,11 @@ class Optimizer(object):


         '''
-        self.pre_step(loss)
+        self.pre_step(loss, retain_graph)

-    def step(self, loss=None):
+    def step(self, loss=None, retain_graph=False):
         if loss is not None:
-            self.pre_step(loss)
+            self.pre_step(loss, retain_graph)
         for pg in self.param_groups:
             lr = pg.get("lr", self.lr)
             for p, g in zip(pg["params"], pg["grads"]):
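Note: a sketch of how the threaded-through retain_graph is used for gradient accumulation; the model and data names are placeholders, only the optimizer calls come from this commit:

    import jittor as jt

    model = jt.nn.Linear(10, 1)
    opt = jt.nn.SGD(model.parameters(), 0.1)
    for batch in batches:              # 'batches' is a placeholder iterable
        loss = (model(batch) ** 2).mean()
        opt.backward(loss)             # accumulate; graph freed since retain_graph=False
    opt.step()                         # apply the accumulated gradients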
@@ -33,6 +33,6 @@ inline static void __print_trace() {
 }

 // @pyjt(grad)
-vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets);
+vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets, bool retain_graph=true);

 } // jittor

@@ -30,6 +30,7 @@
 #include "memory_profiler.h"
 #include "utils/seh.h"
 #include "utils/cache_compile.h"
+#include "var_holder.h"

 namespace jittor {

@@ -102,6 +103,23 @@ void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, i
     }
 }

+static inline void propergate_needed_flags(FusedOp& fused_op) {
+    auto& ops = fused_op.ops;
+    for (int i=ops.size()-1; i>=0; i--) {
+        bool has_need = 0;
+        auto op = ops[i];
+        for (auto o : op->outputs())
+            if (o->flags.get(NodeFlags::_needed_by_backward) &&
+                !(o->custom_data&1)) {
+                has_need = 1;
+            }
+        if (has_need)
+            for (auto i : op->inputs()) {
+                i->flags.set(NodeFlags::_needed_by_backward);
+            }
+    }
+}
+
 void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jittor::Log& logf) {
     vector<Stack> stack;
     if (is_fused_op) {

@@ -151,7 +169,30 @@ void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jit
     jittor::LogFatalVoidify() && logf;
 }

-void Executor::run_sync(vector<Var*> vars, bool device_sync) {
+static void top_weak_sync(vector<Var*>& vars) {
+    auto t = ++Node::tflag_count;
+    int64 max_id=0;
+    for (auto v : vars) {
+        max_id = std::max(v->id, max_id);
+        v->tflag = t;
+    }
+    while (true) {
+        if (sync_ptr == hold_vars.begin())
+            break;
+        auto next_ptr = std::prev(sync_ptr);
+        auto v = (*next_ptr)->var;
+        if (v->id > max_id) break;
+        sync_ptr = next_ptr;
+        if (v->tflag == t) continue;
+        if (v->_outputs.size()) continue;
+        if (v->is_finished()) continue;
+        vars.push_back(v);
+    }
+}
+
+void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
+    if (weak_sync)
+        top_weak_sync(vars);
     auto allocator = get_allocator();
     auto temp_allocator = get_allocator(true);
     this->allocator = allocator;

@@ -287,6 +328,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     // output:
     //     queue: toplogical order of fused op
     {
+        // queue.clear();
+        #ifndef JT_bfs_executor
+        map<int64, int> p_queue;
+        #endif
         for (int root : roots) {
             for (int i=root; i>=0; i=next[i]) {
                 Op* op = ops[i];

@@ -299,23 +344,47 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
                 }
             }
         }
+        #ifdef JT_bfs_executor
         if (deps[root] == 0)
             queue.push_back(root);
+        #else
+        if (deps[root] == 0)
+            p_queue[ops[root]->id] = root;
+        #endif
     }
-    for (uint s=0; s<queue.size(); s++) {
+    #ifdef JT_bfs_executor
+    for (uint s=0; s<queue.size(); s++)
+    #else
+    while (p_queue.size())
+    #endif
+    {
+        #ifdef JT_bfs_executor
         int op_id = queue[s];
+        #else
+        int op_id = p_queue.begin()->second;
+        p_queue.erase(p_queue.begin());
+        queue.push_back(op_id);
+        #endif
         for (int i=op_id; i>=0; i=next[i]) {
             Op* op = ops[i];
             for (Var* v : op->outputs())
+            {
                 if (v->tflag == tt)
-                    for (Op* op2 : v->outputs()) {
+                for (Op* op2 : v->outputs())
+                {
                     if (op2->tflag != tt) continue;
                     int op2_id = father[op2->custom_data];
                     // continue if those two ops are fused
                     if (op2_id == op_id) continue;
                     deps[op2_id]--;
+                    #ifdef JT_bfs_executor
                     if (deps[op2_id] == 0)
                         queue.push_back(op2_id);
+                    #else
+                    if (deps[op2_id] == 0)
+                        p_queue[op2->id] = op2_id;
+                    #endif
                 }
+            }
         }
+    }

@@ -479,9 +548,6 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
             load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
         }
         LOGvvv << "Run" << op;
-        if (op->flags.get(NodeFlags::_has_vary_input)) op->init();
-        ASSERT(!op->flags.get(NodeFlags::_has_vary_input))
-            << "Shape of(" >> op->name() >> ") not solved.";
         for (auto* var : op->outputs()) {
             var->alloc(allocator);
         }

@@ -557,6 +623,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
         LOGvvv << "Finished Op(" >> op->name() << rid >>
             "/" >> queue.size() >> ") output:" << op->outputs();
         if (is_fused_op) {
+            propergate_needed_flags(fused_op);
            for (Var* var : op->outputs())
                var->finish_pending_liveness();
            continue;

@@ -596,7 +663,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
         }
     }
     LOGvv << "All" << op_num << "ops finished, return vars:" << vars;
-    for (Var* v : vars) ASSERT(v->mem_ptr);
+    for (Var* v : vars) ASSERT(v->mem_ptr || !v->backward_liveness);
     // clean fetcher free buffer
     fetcher_to_free.clear();
 #ifdef HAS_CUDA

@@ -21,7 +21,7 @@ struct Executor {
     Allocator* allocator;
     Allocator* temp_allocator;
     bool last_is_cuda = false;
-    void run_sync(vector<Var*> vars, bool device_sync);
+    void run_sync(vector<Var*> vars, bool device_sync, bool weak_sync=true);

     inline Allocation alloc_temp(size_t size) {
         return Allocation(temp_allocator, size);

@@ -45,6 +45,7 @@ void FusedOp::update_ops() {

     _inputs.clear();
     _outputs.clear();
+    vars.clear();
     for (Op* op : ops) {
         for (Var* o : op->outputs()) {
             if (o->loop_options) {

@@ -93,10 +94,7 @@ void FusedOp::update_ops() {
             o->custom_data &= 1;
         }
     }
-    vars.clear();
-    bool has_vary_input = 0;
     for (Op* opi : ops) {
-        has_vary_input |= opi->flags.get(NodeFlags::_has_vary_input);
         for (Var* i : opi->inputs()) {
             auto &c = i->custom_data;
             // if not visited

@@ -116,7 +114,6 @@ void FusedOp::update_ops() {
             }
         }
     }
-    flags.set(NodeFlags::_has_vary_input, has_vary_input);
     LOGvvvv << "Var info" << vars;
 }

@@ -144,12 +141,9 @@ FusedOp::~FusedOp() {
 }

 void FusedOp::infer_shape() {
-    bool has_vary_input = 0;
     for (Op* op : ops) {
         op->init();
-        has_vary_input |= op->flags.get(NodeFlags::_has_vary_input);
     }
-    flags.set(NodeFlags::_has_vary_input, has_vary_input);
 }

 void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {

@@ -10,6 +10,7 @@
 #include "op.h"
 #include "graph.h"
 #include "ops/op_register.h"
+#include "var_holder.h"

 namespace jittor {

@@ -76,7 +77,7 @@ void warn_grad_break(int i, Var* v) {
     LOGw << "grads[">>i>>"] '">> v->name>>"' doesn't have gradient. It will be set to zero:" << v;
 }

-vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
+vector<VarPtr> grad(Var* loss, vector<Var*> targets, bool retain_graph) {
     LOGvv << "loss:" >> loss << "targets:" >> targets;
     CHECK(loss->is_float()) << "Loss should be float";
     for (Var* var : targets)

@@ -259,6 +260,20 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
             assign_attrs(grad.ptr, var);
         }
     }
+    if (!retain_graph) {
+        auto t = ++Node::tflag_count;
+        for (auto& vh : hold_vars)
+            if (vh->var->tflag != t) {
+                vh->var->tflag = t;
+            }
+        SetupFreeBuffer setup_free_buffer;
+        for (int i=int(gvars.size())-1; i>=0; i--)
+            if (gvars[i]->tflag != t && gvars[i]->backward_liveness)
+                gvars[i]->set_stop_grad();
+        for (int i=0; i<grads.size(); i++)
+            if (grads[i])
+                grads[i]->set_stop_grad();
+    }
     return results;
 }

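Note: seen from Python, this block stop-grads the released graph after a non-retaining call; a later grad over the same loss should then warn and fall back to zero gradients rather than recompute (a sketch; exact warning behavior is not guaranteed here):

    import jittor as jt

    x = jt.array([1.0, 2.0])
    loss = (x * x).sum()
    g = jt.grad(loss, x, retain_graph=False)   # graph vars are stop-gradded afterwards
    g2 = jt.grad(loss, x)                      # likely zeros plus a "doesn't have gradient" warning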
@@ -9,7 +9,7 @@

 namespace jittor {

-vector<VarPtr> grad(Var* loss, vector<Var*> targets);
+vector<VarPtr> grad(Var* loss, vector<Var*> targets, bool retain_graph=true);

 // @pyjt(tape_together)
 void tape_together(

@@ -5,6 +5,7 @@
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
 #include <sstream>
+#include <algorithm>
 #include "graph.h"
 #include "var_holder.h"
 #include "var.h"

@@ -32,7 +33,7 @@ void do_graph_check() {
     LOGvv << "Check hold_vars size" << queue.size();
     int vhsize = queue.size();
     for (auto* node : queue) {
-        ASSERTop(node->forward_liveness,>,0);
+        // ASSERTop(node->forward_liveness,>,0);
         ASSERTop(node->backward_liveness,>,0);
     }
     for (uint i=0; i<queue.size(); i++) {

@@ -49,7 +50,7 @@ void do_graph_check() {
         LOGvvvv << "Check node" << i << node;
         int f=0, b=0, p=0;
         if (i<vhsize) {
-            f+=visited.at(node), b+=visited.at(node);
+            f+=visited.at(node), b+=visited.at(node), p+=visited.at(node);
         }
         for (auto* i : node->inputs()) {
             if (i->is_stop_grad()) continue;

@@ -62,7 +63,7 @@ void do_graph_check() {
             if (o->pending_liveness && !o->is_finished())
                 p++;
         }
-        if (f>0 && b>0 && !node->is_finished()) p++;
+        // if (f>0 && b>0 && !node->is_finished()) p++;
         if (f!=node->forward_liveness || b!=node->backward_liveness || p!=node->pending_liveness) {
             LOGf << "ERROR" << node << '\n'
                 << f << b << p << i << '\n'

@@ -91,6 +92,8 @@ DumpGraphs dump_all_graphs() {
             queue.push_back(vh->var);
     }
     bfs_both(queue, [](Node*){return true;});
+    std::sort(queue.begin(), queue.end(),
+        [](Node* a, Node* b) { return a->id < b->id;});
     DumpGraphs graphs;
     for (uint i=0; i<queue.size(); i++)
         queue[i]->custom_data = i;

@@ -76,7 +76,7 @@ void MemoryProfiler::check() {
     allocations.clear();
     size_t memory_size = 0;
     std::vector<std::pair<std::pair<string, vector<Stack>>, size_t>> live_vars;
-    vector<Node*> queue;
+    vector<Node*> queue, queue2;

     auto t = ++Node::tflag_count;
     for (auto& vh : hold_vars)

@@ -85,6 +85,14 @@ void MemoryProfiler::check() {
             queue.push_back(vh->var);
     }
     bfs_both(queue, [](Node*){return true;});
+    vector<int> backup_custom_data;
+    backup_custom_data.resize(queue.size());
+    for (int i=0; i<queue.size(); i++)
+        backup_custom_data[i] = queue[i]->custom_data;
+    toplogical_sort_forward(queue, queue2, [](Node*){});
+    for (int i=0; i<queue.size(); i++)
+        queue[i]->custom_data = backup_custom_data[i];
+    queue.swap(queue2);
     for (Node* node : queue) {
         if (node->is_var()) {
             Var* var = (Var*)node;

@@ -137,6 +137,24 @@ static unordered_set<string> white_ops = {
     "pow",
 };

+static unordered_set<string> no_need_back_in = {
+    "void",
+    "cast",
+    "negative",
+    "add",
+    "subtract",
+    "mean",
+};
+
+static unordered_set<string> no_need_back_out = {
+    "void",
+    "cast",
+    "negative",
+    "add",
+    "subtract",
+    "multiply",
+};
+
 #define DEFINE_NS(T) NanoString ns_##T;
 FOR_ALL_NS(DEFINE_NS);

@@ -172,6 +190,8 @@ static void init_ns() {
         ns.set(NanoString::_float, float_ops.count(name));
     }
     ns.set(NanoString::_white_list, white_ops.count(name));
+    ns.set(NanoString::_no_need_back_in, no_need_back_in.count(name));
+    ns.set(NanoString::_no_need_back_out, no_need_back_out.count(name));
     __string_to_ns[name] = ns;
     auto name2 = ns.to_cstring();
     int len=0;
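Note: these sets encode which ops' backward passes reference their inputs or outputs; e.g. d(a+b)/da == 1, so add needs neither input kept, while d(a*b)/da == b, so multiply must keep its inputs but not its output:

    import jittor as jt

    a, b = jt.rand(10), jt.rand(10)
    c = a + b    # "add" is in both lists: inputs and output can all be freed early
    d = a * b    # "multiply" is only in no_need_back_out: inputs stay, output d may be freed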
@@ -98,7 +98,7 @@ EXTERN_LIB int __ns_len[];

 // @pyjt(NanoString)
 struct NanoString {
-    typedef uint16 ns_t;
+    typedef uint32 ns_t;
     enum Flags {
         // bit0~7: index
         _index=0, _index_nbits=7,

@@ -119,6 +119,9 @@ struct NanoString {
         _dsize=_n+6, _dsize_nbits=2,
         // bit8: white list
         _white_list=_n+8,
+        // bit9: backward opt
+        _no_need_back_in=_n+9,
+        _no_need_back_out=_n+10,
     };
     ns_t data=0;

@@ -18,7 +18,7 @@ EXTERN_LIB int64 nt;
 EXTERN_LIB vector<Node*> free_buffer;

 struct NodeFlags {
-    typedef uint16 nf_t;
+    typedef uint32 nf_t;
     nf_t flags=0;
     enum Flags {
         // bit0: is_var

@@ -35,6 +35,8 @@ struct NodeFlags {
         _force_fuse=_n+0,
         _stop_fuse=_n+1,
         _in_update_queue=_n+2,
+        _needed_by_backward=_n+3,
+        _out_hint=_n+4,

         // op related flags
         // bit0: support cpu

@@ -53,13 +55,14 @@ struct NodeFlags {
         _has_gopt=_n+7,
         // bit8: has vary input
         _has_vary_input=_n+8,
+        _manual_set_vnbb = _n+9,
         // bit9: prefer 32 bit
-        _prefer_32=_n+9,
-        // bit10: force 16 bit
-        _prefer_16=_n+10,
-        // bit11: reduce keep type unchange
-        _reduce_keep=_n+11,
-        _custom_flag=_reduce_keep,
+        _prefer_32=_n+10,
+        // force 16 bit
+        _prefer_16=_prefer_32+1,
+        // reduce keep type unchange
+        _reduce_keep=_prefer_32+2,
+        _custom_flag = _reduce_keep,
     };

     inline void set(Flags f, int a=1, int nbits=1) {

@@ -89,8 +92,8 @@ struct Node {
     };
     struct output_t {
         Node* node;
-        int index;
         list<input_t>::iterator back;
+        int index;
         output_t(Node* n, int i) : node(n), index(i) {}
         operator Node*() { return node; }
         operator Op*() { return (Op*)node; }

@@ -120,14 +123,15 @@ struct Node {
     inline bool need_free()
     { return !pending_liveness && (!forward_liveness || !backward_liveness); }

-    int64_t tflag = 0;
-    int64_t custom_data;
+    int custom_data;
+    int64 tflag = 0;
+    int64 id;
     list<input_t> _inputs;
     list<output_t> _outputs;

 #ifdef NODE_MEMCHECK
     inline Node() {
-        lived_nodes[(void*)this] = ++total_node;
+        lived_nodes[(void*)this] = id = ++total_node;
     }

     inline virtual ~Node() {

@@ -135,7 +139,7 @@ struct Node {
         if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);
     }
 #else
-    inline Node() {};
+    inline Node() { id = ++total_node; };
     inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);};
 #endif
     inline Var* var() { return (Var*)this; }

@@ -14,6 +14,8 @@
 #include "mem/allocator.h"
 #include "misc/cuda_flags.h"
 #include "pybind/py_var_tracer.h"
+#include "executor.h"
+#include "var_holder.h"

 namespace jittor {

@@ -65,14 +67,26 @@ Var* Op::create_output(NanoVector shape, NanoString dtype) {
 }

 void Op::init() {
+    bool has_vary_input = 0;
+    for (Var* v : inputs())
+        if (v->num < 0) {
+            has_vary_input = 1;
+            break;
+        }
+    flags.set(NodeFlags::_has_vary_input, has_vary_input);
     infer_shape();
+    bool manual_set_vnbb = flags.get(NodeFlags::_manual_set_vnbb)
+        || _inputs.size()==0
+        || (_outputs.size()==1 && _outputs.front().node->is_stop_grad());
+    for (Var* v : inputs()) {
+        if (!manual_set_vnbb) {
+            v->flags.set(NodeFlags::_needed_by_backward);
+        }
+    }
     Var* need_sync = nullptr;
     for (Var* v : outputs()) {
+        if (!manual_set_vnbb)
+            v->flags.set(NodeFlags::_needed_by_backward);
         if (v->num < 0)
             need_sync = v;
     }
     if (need_sync) {
         exe.run_sync(vector<Var*>({need_sync}), false);
+        CHECK(need_sync->num >= 0) << need_sync << "'s shape is error";
     }
 }

 void Op::compile_optimize(string& src) {}

@@ -84,7 +98,7 @@ void Op::graph_optimize() {}

 string Op::name_ex() const {
     string a=name();
-    if (ns!=ns_void) {
+    if (ns.data) {
         a += '.';
         a += ns.to_cstring();
     }

@@ -266,7 +280,7 @@ void Op::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {

 std::ostream& operator<<(std::ostream& os, const Op* op) {
     if (!op) return os << "Op(0)";
-    os << "Op(" << (void*)op
+    os << "Op(" << op->id
        << ':' << op->forward_liveness
        << ':' << op->backward_liveness
        << ':' << op->pending_liveness

@@ -46,7 +46,6 @@ ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims)
         get_op_info("cub_arg_reduce").get_constructor<std::vector<VarPtr>, Var*, Var*, NanoString, bool>()
         : nullptr;
     if (cub_arg_reduce) {
-        if (x->num<0) exe.run_sync(vector<Var*>({x}), true);
         int dims = x->shape.size();
         vector<int64> axes;
         axes.reserve(dims);

@@ -88,6 +87,8 @@ ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims)
     #endif
     y = create_output(nullptr, ns_int32);
     y_key = create_output(nullptr, x->dtype());
+    flags.set(NodeFlags::_manual_set_vnbb);
+    y->flags.set(NodeFlags::_needed_by_backward);
 }
 VarPtr ArgReduceOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) {
     // Do not have grad to extras input

@@ -45,7 +45,6 @@ ArgsortOp::ArgsortOp(Var* x, int dim, bool descending, NanoString dtype)
             .get_constructor<std::vector<VarPtr>, Var*, Var*, Var*, bool, NanoString>();
     }
     if (cub_argsort) {
-        if (x->num<0) exe.run_sync(vector<Var*>({x}), true);
         int dims = x->shape.size();
         vector<int64> axes;
         axes.reserve(dims);

@@ -81,6 +80,8 @@ ArgsortOp::ArgsortOp(Var* x, int dim, bool descending, NanoString dtype)
     #endif
     y = create_output(nullptr, dtype);
     y_key = create_output(nullptr, x->dtype());
+    flags.set(NodeFlags::_manual_set_vnbb);
+    y->flags.set(NodeFlags::_needed_by_backward);
 }

 VarPtr ArgsortOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) {

@@ -425,6 +425,18 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
     ns = op;
     ASSERT(ns.is_binary());
     z = create_output(nullptr, binary_dtype_infer(op, x->ns, y->ns));
+    bool bin = ns.get(NanoString::_no_need_back_in);
+    bool bout = ns.get(NanoString::_no_need_back_out);
+    if (bin || bout) {
+        flags.set(NodeFlags::_manual_set_vnbb);
+        if (!bin) {
+            x->flags.set(NodeFlags::_needed_by_backward);
+            y->flags.set(NodeFlags::_needed_by_backward);
+        }
+        if (!bout) {
+            z->flags.set(NodeFlags::_needed_by_backward);
+        }
+    }
 }

 VarPtr dirty_clone_broadcast(Var* v) {

@@ -29,6 +29,7 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
     }
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_manual_set_vnbb);
     set_type(OpType::broadcast);
     z = create_output(NanoVector(), x->dtype());
     bcast_mask = 0;

@@ -16,12 +16,12 @@ namespace jittor {
 CandidateOp::CandidateOp(Var* x, string&& fail_cond, NanoString dtype) : x(x), fail_cond(move(fail_cond)) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
-    flags.set(NodeFlags::_vary_shape);
+    flags.set(NodeFlags::_manual_set_vnbb);
     y = create_output(nullptr, dtype);
 }

 void CandidateOp::infer_shape() {
-    y->set_shape({-std::abs(x->shape[0])});
+    y->set_shape({-x->shape[0]});
 }

 void CandidateOp::jit_prepare(JK& jk) {

@@ -19,6 +19,7 @@ static auto make_clone = get_op_info("clone")
 CloneOp::CloneOp(Var* x) : x(x) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_manual_set_vnbb);
     y = create_output(nullptr, x->dtype());
     if (x->name.ptr)
         y->name = x->name;

@@ -37,7 +37,6 @@ CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
     _outputs.push_back(create_output(shape, dtype));

     if (_outputs[0]->num < 0) {
-        flags.set(NodeFlags::_vary_shape);
         check_vary_shape(_outputs[0]->shape);
     }
 }

@@ -58,7 +57,6 @@ CodeOp::CodeOp(
     for (int i=0; i<shapes.size(); i++) {
         _outputs[i] = create_output(shapes[i], dtypes[i]);
         if (_outputs[i]->num < 0) {
-            flags.set(NodeFlags::_vary_shape);
             check_vary_shape(_outputs[i]->shape);
         }
     }

@@ -20,6 +20,7 @@ namespace jittor {
 CopyOp::CopyOp(Var* x) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_manual_set_vnbb);
     auto y = create_output(nullptr, x->dtype());
     if (x->name.ptr)
         y->name = x->name;

@@ -33,6 +33,7 @@ FuseTransposeOp::FuseTransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(tp);
+    flags.set(NodeFlags::_manual_set_vnbb);
     int i=0;
     for (; i<axes.size(); i++)
         if (i!=axes[i]) break;

@@ -38,6 +38,10 @@ GetitemOp::GetitemOp(Var* x, VarSlices&& slices)
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     flags.set(NodeFlags::_has_gopt);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    for (int i=0; i<vs.n; i++)
+        if (vs.slices[i].is_var())
+            vs.slices[i].var->flags.set(NodeFlags::_needed_by_backward);
     create_output(nullptr, x->dtype());
 }

@@ -48,6 +52,10 @@ GetitemOp::GetitemOp(Var* x, VarSlices&& slices, int _)
     flags.set(NodeFlags::_has_gopt);
     flags.set(NodeFlags::_custom_flag);
     flags.set(NodeFlags::_grads);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    for (int i=0; i<vs.n; i++)
+        if (vs.slices[i].is_var())
+            vs.slices[i].var->flags.set(NodeFlags::_needed_by_backward);
     create_output(nullptr, x->dtype());
     auto out2 = create_output(nullptr, x->dtype());
     out2->share_with(x);

@@ -421,7 +429,6 @@ void GetitemOp::grads(Var** dout, VarPtr* dins) {
     VarPtr y = dout[0];
     if (!x) {
         auto in = inputs().front();
-        if (in->num<0) exe.run_sync(vector<Var*>({in}), true);
         // ns.data represents this is the last split var
         if (ns.data)
             x = make_empty(in->shape, in->dtype());

@@ -31,6 +31,7 @@ IndexOp::IndexOp(Var* a, int64 dim, NanoString dtype) : dim(dim) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::element);
+    flags.set(NodeFlags::_manual_set_vnbb);
     x.reset(new Var*[1]);
     x[0] = create_output(nullptr, dtype);
 }

@@ -38,6 +39,7 @@ IndexOp::IndexOp(Var* a, int64 dim, NanoString dtype) : dim(dim) {
 IndexOp::IndexOp(Var* a, NanoString dtype) : dim(a->shape.size()) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_manual_set_vnbb);
     set_type(OpType::element);
     x.reset(new Var*[dim]);
     for (int i=0; i<dim; i++)

@@ -253,6 +253,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::reduce);
+    if (op.get(NanoString::_no_need_back_in))
+        flags.set(NodeFlags::_manual_set_vnbb);
     ns = op;
     ASSERT(ns.is_binary());
     auto xdim = x->shape.size();

@@ -279,6 +281,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::reduce);
+    if (op.get(NanoString::_no_need_back_in))
+        flags.set(NodeFlags::_manual_set_vnbb);
     ns = op;
     ASSERT(ns.is_binary());
     reduce_mask = dims_mask;

@@ -319,12 +323,6 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
         return make_binary(b, v, ns_divide);
     }
     if (ns == ns_mean) {
-        if (v->num < 0) {
-            // TODO: Dynamic shape of mean grad was not supported yet
-            LOGw << "Dynamic shape of mean grad cause synchronize.";
-            exe.run_sync({v}, 0);
-            ASSERT(v->num>=0);
-        }
         VarPtr a = make_broadcast_to(dout, v, reduce_mask, keepdims_mask);
         VarPtr n = make_number(1.0f*out->num / v->num, a);
         return make_binary(a, n, ns_multiply);

@@ -27,6 +27,7 @@ ReindexOp::ReindexOp(Var* x, NanoVector shape, vector<string>&& indexes, float64
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::broadcast);
+    flags.set(NodeFlags::_manual_set_vnbb);
     y = create_output(nullptr, x->dtype());
 }

@@ -28,6 +28,8 @@ ReindexReduceOp::ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::reduce);
+    if (op.get(NanoString::_no_need_back_in))
+        flags.set(NodeFlags::_manual_set_vnbb);
     ns = op;
     ASSERT(ns.is_binary() && ns!=ns_mean);
     x = create_output(nullptr, y->dtype());

@@ -20,6 +20,7 @@ static auto make_reshape = get_op_info("reshape")
 ReshapeOp::ReshapeOp(Var* x, NanoVector shape) : x(x), shape(shape) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_manual_set_vnbb);
     y = create_output(nullptr, x->dtype());
     ASSERT(shape.size() > 0) << "input target shape of reshape can't be empty.";
 }

@@ -16,6 +16,7 @@ namespace jittor {
 SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right) : x(x), left(left), right(right) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_manual_set_vnbb);
     set_type(OpType::element);
     y = create_output(nullptr, x->dtype());
 }

@@ -42,6 +42,12 @@ SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op)
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     flags.set(NodeFlags::_has_gopt);
+    if (op.get(NanoString::_no_need_back_in)) {
+        flags.set(NodeFlags::_manual_set_vnbb);
+        for (int i=0; i<vs.n; i++)
+            if (vs.slices[i].is_var())
+                vs.slices[i].var->flags.set(NodeFlags::_needed_by_backward);
+    }
     ASSERT(op == ns_void || op.is_binary());
     create_output(nullptr, x->dtype());
     if (flags.get(NodeFlags::_custom_flag)) {

@@ -16,6 +16,7 @@ namespace jittor {
 TapeOp::TapeOp(Var* x) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_manual_set_vnbb);
     create_output(nullptr, x->dtype());
 }

@@ -53,6 +54,7 @@ Tapes::Tapes(
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     flags.set(NodeFlags::_grads);
+    flags.set(NodeFlags::_manual_set_vnbb);
     callback = move(grad_callback);


@@ -13,13 +13,26 @@ namespace jittor {
 #ifndef JIT
 static auto make_ternary = get_op_info("ternary")
     .get_constructor<VarPtr, Var*, Var*, Var*>();
+static auto make_broadcast = get_op_info("broadcast_to")
+    .get_constructor<VarPtr, Var*, Var*, NanoVector>();
 static auto make_number = get_op_info("number")
     .get_constructor<VarPtr, float, Var*>();

 TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) {
+    bool bx = cond->shape.size() > x->shape.size() || cond->num > x->num;
+    bool by = cond->shape.size() > y->shape.size() || cond->num > y->num;
+    if (bx || by) {
+        VarPtr xx, yy;
+        if (bx) xx = make_broadcast(x, cond, NanoVector()), x = xx;
+        if (by) yy = make_broadcast(y, cond, NanoVector()), y = yy;
+        forward(make_ternary(cond, x, y));
+        return;
+    }
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::element);
+    flags.set(NodeFlags::_manual_set_vnbb);
+    cond->flags.set(NodeFlags::_needed_by_backward);
     z = create_output(nullptr, dtype_infer(x->ns, y->ns));
 }

@@ -49,6 +49,7 @@ TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
     }
 #endif
     y = create_output(nullptr, x->dtype());
+    flags.set(NodeFlags::_manual_set_vnbb);
 }

 void TransposeOp::infer_shape() {

@@ -544,6 +544,17 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
     } else
         dtype = unary_dtype_infer(ns, x->ns);
     y = create_output(nullptr, dtype);
+    bool bin = ns.get(NanoString::_no_need_back_in);
+    bool bout = ns.get(NanoString::_no_need_back_out);
+    if (bin || bout) {
+        flags.set(NodeFlags::_manual_set_vnbb);
+        if (!bin) {
+            x->flags.set(NodeFlags::_needed_by_backward);
+        }
+        if (!bout) {
+            y->flags.set(NodeFlags::_needed_by_backward);
+        }
+    }
 }

 VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {

@@ -21,7 +21,7 @@ namespace jittor {
 WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) {
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
-    flags.set(NodeFlags::_vary_shape);
+    flags.set(NodeFlags::_manual_set_vnbb);
     auto ndim = cond->shape.size();
 #ifdef HAS_CUDA
     if (use_cuda) {

@@ -48,8 +48,7 @@ WhereOp::WhereOp(Var* cond, Var* x, Var* y) {

 void WhereOp::infer_shape() {
     auto ndim = cond->shape.size();
-    auto num = cond->num;
-    if (num>0) num = -num;
+    auto num = -cond->num;
     for (uint i=0; i<ndim; i++)
         outs[i]->set_shape({num});
 }

@@ -124,9 +124,6 @@ static void getitem_inplace(GetitemOp* op) {
     // return if out is all ready inplaced
     if (ou->allocator)
         return;
-    // return if input or output's shape is variable
-    if (in->num <= 0 || ou->num <= 0)
-        return;

     VarSlices vs = op->vs;
     auto in_shape = in->shape;

@@ -17,11 +17,11 @@ SEH_HOOK;
 // Those function is generated by python
 EXTERN_LIB void pyjt_def_all(PyObject* m);

-vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets) {
+vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets, bool retain_graph) {
     vector<Var*> vs;
     vs.reserve(targets.size());
     for (auto* v : targets) vs.push_back(v->var);
-    auto grads = grad(loss->var, vs);
+    auto grads = grad(loss->var, vs, retain_graph);
     vector<VarHolder*> grads_hold;
     grads_hold.reserve(targets.size());
     for (auto& grad : grads)

@@ -604,7 +604,7 @@ void system_with_check(const char* cmd, const char* cwd) {
     auto ret = system_popen(cmd, cwd);
     CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd <<
         "\nreturn ">> ret >> ". This might be an overcommit issue or out of memory."
-        << "Try : sudo sysctl vm.overcommit_memory=1";
+        << "Try : sudo sysctl vm.overcommit_memory=1, or set enviroment variable `export DISABLE_MULTIPROCESSING=1`";
     CHECKop(ret,==,0) << "Run cmd failed:" << cmd;
 }

@@ -22,6 +22,7 @@ DEFINE_FLAG(bool, no_grad, 0,
     "No grad for all jittor Var creation");
 DEFINE_FLAG(bool, no_fuse, 0,
     "No fusion optimization for all jittor Var creation");
+// TODO: fuse multiple flags
 DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");

 DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");

@@ -91,13 +92,14 @@ bool Var::alloc(Allocator* allocator) {


 std::ostream& operator<<(std::ostream& os, const Var& var) {
-    os << "Var" << '(' << (void*)&var
+    os << "Var" << '(' << var.id
        << ':' << var.forward_liveness
        << ':' << var.backward_liveness
        << ':' << var.pending_liveness
        << ":i" << var._inputs.size()
        << ":o" << var._outputs.size()
        << ":s" << var.is_finished()
+       << ":n" << var.flags.get(NodeFlags::_needed_by_backward)
        << ','
        << var.dtype().to_cstring() << ',' << var.name << ',' << var.mem_ptr
        << ')' << var.shape;
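Note: the printed format is therefore Var(id:f:b:p:iN:oN:sB:nB,dtype,name,mem_ptr)shape, which the tests below match on. A decoding sketch (the printed string is illustrative):

    import jittor as jt

    a = jt.rand(3).name("aa")
    for info in jt.dump_all_graphs().nodes_info:
        if ",aa," in info:
            print(info)   # e.g. Var(42:1:1:0:i1:o0:s0:n0,float32,aa,0)[3,]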
@@ -22,6 +22,7 @@ namespace jittor {
 DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance.");

 list<VarHolder*> hold_vars;
+list<VarHolder*>::iterator sync_ptr = hold_vars.end();

 void add_hold_vars(VarHolder* self) {
     hold_vars.push_front(self);

@@ -79,6 +80,8 @@ VarHolder::VarHolder(PyObject* obj, NanoString dtype) {

 VarHolder::~VarHolder() {
     if (PREDICT_BRANCH_NOT_TAKEN(!var)) return;
+    if (iter == sync_ptr)
+        sync_ptr = std::next(sync_ptr);
     hold_vars.erase(iter);
     var->release_both_liveness();
 }

@@ -100,7 +103,6 @@ void VarHolder::operator=(VarPtr&& v) {
 }

 string VarHolder::to_string() {
-    if (var->num<0) sync();
     return var->to_string();
 }

@@ -131,8 +133,8 @@ VarHolder* VarHolder::_update(VarHolder* v) {

 EXTERN_LIB Executor exe;

-void VarHolder::sync(bool device_sync) {
-    jittor::sync({this}, device_sync);
+void VarHolder::sync(bool device_sync, bool weak_sync) {
+    jittor::sync({this}, device_sync, weak_sync);
 }

 ArrayArgs VarHolder::fetch_sync() {

@@ -178,12 +180,12 @@ void sync_all(bool device_sync) {
     graph_check();
 }

-void sync(const vector<VarHolder*>& vh, bool device_sync) {
+void sync(const vector<VarHolder*>& vh, bool device_sync, bool weak_sync) {
     vector<Var*> vars;
     vars.reserve(vh.size());
     for (auto v : vh) vars.push_back(v->var);
     graph_check();
-    exe.run_sync(vars, device_sync); //need sync at last
+    exe.run_sync(vars, device_sync, weak_sync); //need sync at last
     graph_check();
 }
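Note: with weak_sync=True (the new default), a targeted sync also sweeps up earlier-created, output-free held variables (see top_weak_sync above) so they can be finished and their buffers released in the same pass. Sketch:

    import jittor as jt

    a = jt.rand(10)
    b = a + 1
    b.sync()              # weak_sync defaults to True
    b.sync(False, False)  # (device_sync, weak_sync): the old strictly-targeted behavior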
@@ -226,4 +228,17 @@ your code as below::
     return 0;
 }

+static auto make_ternary = get_op_info("ternary")
+    .get_constructor<VarPtr, Var*, Var*, Var*>();
+
+extern int no_grad;
+
+VarHolder* ternary_out_hint(VarHolder* cond, VarHolder* x, VarHolder* y) {
+    if (!no_grad)
+        cond->var->flags.set(NodeFlags::_out_hint);
+    return new VarHolder(make_ternary(cond->var, x->var, y->var));
+}
+
 } // jittor
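Note: usage of the new binding together with Var.out_hint, mirroring the tests added below:

    import jittor as jt

    a = jt.rand(10)
    b = jt.rand(10)
    cond = (a < b).out_hint()             # mark the bool mask as worth keeping
    c = jt.ternary_out_hint(cond, a, b)   # backward uses cond instead of re-deriving it
    da, db = jt.grad(c.sum(), [a, b])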
@@ -31,6 +31,7 @@ struct ItemData {
 typedef struct _object PyObject;

 EXTERN_LIB list<VarHolder*> hold_vars;
+EXTERN_LIB list<VarHolder*>::iterator sync_ptr;

 // @pyjt(Var)
 // @attrs(heaptype)

@@ -47,7 +48,7 @@ struct VarHolder {
     ~VarHolder();
     string to_string();
     // @pyjt(sync)
-    void sync(bool device_sync = false);
+    void sync(bool device_sync = false, bool weak_sync = true);
     // @pyjt(fetch_sync,numpy)
     ArrayArgs fetch_sync();

@@ -108,7 +109,6 @@ struct VarHolder {
     */
     // @pyjt(numel)
     inline int64 numel() {
-        if (var->num<0) sync();
         return var->num;
     }

@@ -155,12 +155,21 @@ struct VarHolder {
         return var->flags.get(NodeFlags::_stop_fuse);
     }

+    /**
+     * output hint for training optimization
+     */
+    // @pyjt(out_hint)
+    // @attrs(return_self)
+    inline VarHolder* out_hint() {
+        var->flags.set(NodeFlags::_out_hint);
+        return this;
+    }
+
     /**
      * return the shape of the Var.
      */
     // @pyjt(__get__shape)
     inline NanoVector shape() {
         if (var->num<0) sync();
         return var->shape;
     }

@@ -324,7 +333,7 @@ struct VarHolder {
 };

 // @pyjt(sync)
-void sync(const vector<VarHolder*>& vh=vector<VarHolder*>(), bool device_sync=false);
+void sync(const vector<VarHolder*>& vh=vector<VarHolder*>(), bool device_sync=false, bool weak_sync=true);
 // @pyjt(fetch_sync)
 vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh);

@@ -347,4 +356,7 @@ inline vector<VarHolder*> make_vh_vector(vector<VarPtr>&& vps) {
     return a;
 }

+// @pyjt(ternary_out_hint)
+VarHolder* ternary_out_hint(VarHolder* cond, VarHolder* x, VarHolder* y);
+
 } // jittor

@@ -11,7 +11,7 @@ import jittor as jt
 import numpy as np

 class TestClone(unittest.TestCase):
-    def test(self):
+    def test_mid_stop_grad(self):
         jt.clean()
         b = a = jt.array(1.0)
         for i in range(10):

@@ -19,8 +19,11 @@ class TestClone(unittest.TestCase):
             if i==5: c=b
         b.sync()
         assert jt.number_of_lived_vars()==11
+        c.name("c")
         c.stop_grad()
-        assert jt.number_of_lived_vars()==3
+        for n in jt.dump_all_graphs().nodes_info:
+            print(n)
+        assert jt.number_of_lived_vars()==3, jt.number_of_lived_vars()

     def test2(self):
         a = jt.array([1,2])

@@ -17,6 +17,7 @@ def expect_error(func):
     raise Exception("Expect an error, but nothing catched.")

 class TestCore(unittest.TestCase):
+
     def test_number_of_hold_vars(self):
         assert jt.random([1,2,3]).peek() == "float32[1,2,3,]"
         assert jt.core.number_of_hold_vars() == 0

@@ -76,6 +77,7 @@ class TestCore(unittest.TestCase):

     def test_var_holder(self):
         jt.clean()
+        self.assertEqual(jt.number_of_lived_vars(), 0)
         expect_error(lambda: jt.matmul(1,1))
         expect_error(lambda: jt.matmul([1],[1]))
         expect_error(lambda: jt.matmul([[1]],[1]))

@@ -119,5 +121,103 @@ class TestCore(unittest.TestCase):
         assert a._parameters['a'] is a.a
         assert a._parameters['b'] is a.b

+    def test_copy_memopt(self):
+        # exe: post run
+        # remove pending done
+        # add hold pending done
+        # pending release mem done
+        a = jt.rand(10)
+        b = a.copy().copy().copy()
+        a.name("aa")
+        b.name("bb")
+
+        cnt = 0
+        graphs = jt.dump_all_graphs()
+        for x in graphs.nodes_info:
+            if "Var" not in x: continue
+            print(x)
+            if ",aa," in x:
+                assert ":2:i" in x, x
+            elif ",bb," in x:
+                assert ":1:i" in x
+            else:
+                assert ":1:i" in x
+
+        b.sync()
+        cnt = 0
+        graphs = jt.dump_all_graphs()
+        for x in graphs.nodes_info:
+            # print(x)
+            if "Var" in x and ",0)" in x:
+                cnt += 1
+        assert cnt == 2
+
+    def test_fuse_memopt(self):
+        def check():
+            a = jt.rand(10)
+            b = (a.copy().name("copy_out1") + 1).sqr() + a.copy().name("copy_out2")
+            b.sync()
+            for n in jt.dump_all_graphs().nodes_info:
+                if "Var" not in n: continue
+                # print(n)
+
+                if "copy_out1" in n:
+                    # copy out1 is not free
+                    assert ",0)" not in n
+                if "copy_out2" in n:
+                    # copy out2 is freeed
+                    assert ",0)" in n
+            da = jt.grad(b, a)
+            da.sync()
+        check()
+        jt.gc()
+        assert jt.liveness_info()['lived_vars'] == 0
+
+    def test_out_hint1(self):
+        a = jt.rand(10)
+        b = jt.rand(10)
+        c = jt.ternary_out_hint((a<b).out_hint(), a, b).clone()
+        c.sync()
+        da, db = jt.grad(c, [a, b])
+        jt.sync_all()
+        for n in jt.dump_all_graphs().nodes_info:
+            if "Var" in n and "bool" in n:
+                print(n)
+                assert ",0)" not in n
+        jt.ternary_out_hint((a<b).out_hint(), a, 0).sync()
+
+    def test_out_hint2(self):
+        a = jt.rand(10)
+        b = jt.rand(10)
+        c = jt.ternary(a<b, a, b).clone()
+        # c.sync()
+        da, db = jt.grad(c, [a, b])
+        jt.sync_all()
+        for n in jt.dump_all_graphs().nodes_info:
+            if "Var" in n and "bool" in n:
+                print(n)
+                assert ",0)" not in n

+    def test_relu_memopt(self):
+        x = a = jt.rand(10,10)
+        for i in range(10):
+            # a = jt.nn.relu(a)
+            a = jt.ternary_out_hint((a>0.0).name("b"+str(i)), a, 0.0)
+            a = jt.matmul(a.name("m1"),jt.rand(10,10).name("m2")).name("m3-"+str(i))
+        da = jt.grad(a, x, True)
+        # jt.clean_graph()
+        da.sync()
+        cnt1 = 0
+        cnt2 = 0
+        for n in jt.dump_all_graphs().nodes_info:
+            if "Var" in n and ",0)" not in n:
+                cnt1 +=1
+                if "bool" in n:
+                    cnt2 += 1
+        print(cnt1, cnt2)
+        assert cnt2 == 10
+        assert cnt1 <= 33, cnt1
+
 if __name__ == "__main__":
     unittest.main()

@@ -55,6 +55,7 @@ class TestExample(unittest.TestCase):

         model = Model(input_size=1)
         ps = model.parameters()
+        for p in reversed(ps): p.sync(0,0)

         for i,(x,y) in enumerate(get_data(n)):
             pred_y = model(x).name("pred_y")

@@ -63,6 +63,7 @@ class TestExample(unittest.TestCase):

         model = Model(input_size=1)
         ps = model.parameters()
+        for p in reversed(ps): p.sync(0,0)
         opt = Optimizer(ps, lr)
         all_loss = 0

@@ -258,6 +258,7 @@ class TestFunction(unittest.TestCase):
             g = jt.grad(c+d*3, [a, b])
         test()
         jt.clean()
+        jt.dump_all_graphs()
         self.assertEqual(jt.liveness_info()["lived_vars"], 0)

     @unittest.skipIf(True, "skip memleak test")

@@ -78,9 +78,11 @@ class TestFusedOp(unittest.TestCase):
     def test_add(self):
         jt.clean()
         def check(hv, lv, lo):
-            self.assertEqual(jt.number_of_hold_vars(), hv)
-            self.assertEqual(jt.number_of_lived_vars(), lv)
-            self.assertEqual(jt.number_of_lived_ops(), lo)
+            self.assertEqual((
+                jt.number_of_hold_vars(),
+                jt.number_of_lived_vars(),
+                jt.number_of_lived_ops()),
+                (hv, lv, lo))
         for i in range(8):
             check(0,0,0)
             a = jt.array(1.0).name('a').stop_fuse()

@@ -88,7 +90,14 @@ class TestFusedOp(unittest.TestCase):
             c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c')
             check(3,5,5)
             graph = jt.dump_all_graphs()
+            # for n in graph.nodes_info:
+            #     print(n)
             self.assertEqual(c.data, 3)
+            graph2 = jt.dump_all_graphs()
+            print("check", i)
+            for n in graph2.nodes_info:
+                print(n)
+            print(jt.liveness_info())
             check(3,5,2)
             graph = jt.dump_all_graphs()
             for node in graph.nodes_info:

@@ -39,7 +39,6 @@ class TestIndexOp(unittest.TestCase):
     def test_vary_shape_dep(self):
         a, = jt.where([1,0,1])
         b, = a.index_var()
-        assert a.uncertain_shape==[-3] and b.uncertain_shape==[-3]
         assert (b.data==[0,1]).all()

     def test_vary_shape_dep2(self):

@@ -48,7 +47,6 @@ class TestIndexOp(unittest.TestCase):
         index0 = index0.broadcast([1,3], dims=[1]) # [[1,1,1],[2,2,2]]
         index1 = index0.index_var(1) # [[0,1,2],[0,1,2]]
         b = a.reindex_var([index0, index1])
-        assert b.uncertain_shape==[-3,3]
         assert (b.data==[[4,5,6],[7,8,9]]).all()
         assert (index0.data==[[1,1,1],[2,2,2]]).all()
         assert (index1.data==[[0,1,2],[0,1,2]]).all()

@@ -130,6 +130,7 @@ class TestMatmul(unittest.TestCase):
         np.random.seed(0)
         jt.set_seed(3)
         model = Model()
+        for p in reversed(model.parameters()): p.sync(0,0)
         SGD = jt.nn.SGD(model.parameters(), 0.05, 0.9, 0)
         n = 1000
         batch_size = 50

@@ -141,6 +141,13 @@ class TestNode(unittest.TestCase):
         # noded opt: build(0.44),execute(0.11)
         # for i in range(20):
         #     run()
+
+        # version 1.3.2.6 retest(laptop)
+        # mode1:
+        # origin 0.296 exec(0.11)
+        # int32flag 0.298 exec(0.11)
+        # add order 0.299 exec(0.11)
+        # rm p1 rule 0.299 exec(0.11)
         for i in range(20):
             run()
         import gc

@@ -75,10 +75,7 @@ def conv_transpose_naive(x, w):


 def is_fused(x):
-    x.name('_x')
-    graph = jt.dump_all_graphs()
-    node_a = [ node for node in graph.nodes_info if ",_x," in node ]
-    return 's0' in node_a[0]
+    return 's0' in x.debug_msg()

 def check_fused(dim):
     jt.clean()

@@ -91,7 +91,10 @@ class TestResnetFp32(unittest.TestCase):
                 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
                     .format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev))
                 # prev = time.time()
+            # async version
             jt.fetch(epoch_id, batch_id, loss, output, target, callback)
+            # sync version
+            # callback(epoch_id, batch_id, loss.numpy(), output.numpy(), target.numpy())

             # log_conv = find_log_with_re(logs,
             #   "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")

@@ -66,8 +66,9 @@ def test_ring_buffer():
     assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()

     test_send_recv(np.random.rand(10,10))
-    n_byte += 1 + 16 + 2 + 10*10*8
-    assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
+    n_byte += 1 + 16 + 4 + 10*10*8
+    assert n_byte == buffer.total_pop() and n_byte == buffer.total_push(), \
+        (n_byte, buffer.total_pop(), n_byte, buffer.total_push())
     test_send_recv(test_ring_buffer)

     test_send_recv(jt.array(np.random.rand(10,10)))

@@ -381,6 +381,19 @@ class TestSetitem(unittest.TestCase):
         for i in range(n):
             np.testing.assert_allclose(g.data[i*m:(i+1)*m], yy[i].data)

+    def test_dfs_memopt(self):
+        with jt.flag_scope(profile_memory_enable=1):
+            n = 1024
+            b = []
+            for i in range(n):
+                a = jt.rand(n).copy().copy()
+                a = a.sum()
+                # a.sync()
+                b.append(a)
+            jt.sync_all()
+            jt.get_max_memory_treemap()
+
+

@@ -15,7 +15,7 @@ class TestWhereOp(unittest.TestCase):
     def test(self):
         assert (self.where([0,1,0,1])[0].data == [1,3]).all()
         a, = self.where([0,1,0,1])
-        assert a.uncertain_shape==[-4]
+        assert a.uncertain_shape==[2]
         a.data
         assert a.uncertain_shape==[2]
         a,b = self.where([[0,0,1],[1,0,0]])
Binary file not shown.