mirror of https://github.com/Jittor/Jittor
add node_order control execute order
This commit is contained in:
parent 3e6fb4cad8
commit 88bb84255f
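For orientation, a minimal sketch of how the flag added by this commit is meant to be driven from Python. This is hypothetical user code, not taken from the diff; the flag name jt.flags.node_order and the scheduling behaviour described in the comments are inferred from the changes below:

    import jittor as jt

    # Nodes created while node_order is 1 are tagged "order low", so the
    # executor's priority queue schedules them ahead of ordinary nodes
    # (which are keyed by creation id); the optimizers below toggle the
    # flag around their parameter updates in exactly this way.
    jt.flags.node_order = 1
    w = jt.rand(10, 10)                      # invented computation
    y = jt.matmul(jt.rand(3, 10), w).sum()
    y.sync()
    jt.flags.node_order = 0                  # restore the default ordering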
@ -1526,7 +1526,8 @@ def jittor_exit():
elif hooks.exception is not None:
pass
else:
core.sync_all(True)
pass
# core.sync_all(True)
core.cleanup()
atexit.register(jittor_exit)

@ -1146,7 +1146,10 @@ if os.name == 'nt':
return cmd

if ' -O' not in cc_flags:
opt_flags += " -O2 "
if os.environ.get("debug", "0") == "1":
opt_flags += " -O0 "
else:
opt_flags += " -O2 "
kernel_opt_flags += " -Ofast "
lto_flags = ""
if os.environ.get("enable_lto") == "1":
@ -98,16 +98,39 @@ class Optimizer(object):
def zero_grad(self):
self.__zero_grad = True

def pre_step(self, loss, retain_graph=False):
""" things that should be done before step, such as computing gradients, MPI sync, and so on.
def backward(self, loss, retain_graph=False):
'''
optimizer.backward(loss) is used to accumulate gradients over multiple steps;
it can be used as follows:

Example::
Origin source code ::

class MyOptimizer(Optimizer):
def step(self, loss):
self.post_step(loss)
...
"""
n_iter = 10000
batch_size = 100
...
for i in range(n_iter):
...
loss = calc_loss()
optimizer.step(loss)

Accumulation version ::

n_iter = 10000
batch_size = 100
accumulation_steps = 10
n_iter *= accumulation_steps
batch_size //= accumulation_steps
...
for i in range(n_iter):
...
loss = calc_loss()
# if loss is a mean across the batch, divide it by accumulation_steps
optimizer.backward(loss / accumulation_steps)
if (i+1) % accumulation_steps == 0:
optimizer.step()

'''
# clean prev grads
params = []
params_has_grad = []
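The docstring above uses placeholder helpers (calc_loss and so on); a self-contained sketch of the same accumulation pattern follows, with an invented toy model and loss — only optimizer.backward / optimizer.step come from the code above:

    import jittor as jt

    model = jt.nn.Linear(10, 1)                        # toy model
    optimizer = jt.optim.SGD(model.parameters(), lr=0.1)

    accumulation_steps = 10
    for i in range(100):
        x = jt.rand(4, 10)                             # toy batch
        loss = model(x).sqr().mean()
        # loss is a mean over the batch, so divide it by accumulation_steps
        optimizer.backward(loss / accumulation_steps)
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()                           # apply accumulated grads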
@ -117,6 +140,9 @@ class Optimizer(object):
if not p.is_stop_grad():
params_has_grad.append(p)

# sync prev params
jt.sync(params_has_grad)

# get gradient
grads = jt.grad(loss, params_has_grad, retain_graph)
@ -153,50 +179,44 @@ class Optimizer(object):
pid += 1
self.__zero_grad = False

def backward(self, loss, retain_graph=False):
'''
optimizer.backward(loss) is used to accumulate gradients over multiple steps;
it can be used as follows:
def pre_step(self, loss, retain_graph=False):
""" things that should be done before step, such as computing gradients, MPI sync, and so on.

Origin source code ::
Example::

n_iter = 10000
batch_size = 100
...
for i in range(n_iter):
...
loss = calc_loss()
optimizer.step(loss)
class MyOptimizer(Optimizer):
def step(self, loss):
self.pre_step(loss)
...
self.post_step(loss)
"""
if loss is not None:
self.backward(loss, retain_graph)
jt.flags.node_order = 1

Accumulation version ::
def post_step(self):
""" things that should be done after step, such as zeroing grads, and so on.

n_iter = 10000
batch_size = 100
accumulation_steps = 10
n_iter *= accumulation_steps
batch_size //= accumulation_steps
...
for i in range(n_iter):
...
loss = calc_loss()
# if loss is a mean across the batch, divide it by accumulation_steps
optimizer.backward(loss / accumulation_steps)
if (i+1) % accumulation_steps == 0:
optimizer.step()
Example::

class MyOptimizer(Optimizer):
def step(self, loss):
self.pre_step(loss)
...
self.post_step(loss)
"""
jt.flags.node_order = 0
self.zero_grad()

'''
self.pre_step(loss, retain_graph)

def step(self, loss=None, retain_graph=False):
if loss is not None:
self.pre_step(loss, retain_graph)
self.pre_step(loss, retain_graph)
for pg in self.param_groups:
lr = pg.get("lr", self.lr)
for p, g in zip(pg["params"], pg["grads"]):
if p.is_stop_grad(): continue
p.update(p - g * lr)
self.zero_grad()
self.post_step()

def _build_grad_map(self):
_grad_map = {}
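As a concrete version of the MyOptimizer skeleton in the docstrings above, here is a hedged sketch of a custom optimizer built on pre_step/post_step; it simply repeats the plain gradient-descent loop of Optimizer.step and is meant as a template, not as part of this commit:

    from jittor.optim import Optimizer

    class MyOptimizer(Optimizer):
        def step(self, loss=None, retain_graph=False):
            # computes grads, syncs, and raises node_order when loss is given
            self.pre_step(loss, retain_graph)
            for pg in self.param_groups:
                lr = pg.get("lr", self.lr)
                for p, g in zip(pg["params"], pg["grads"]):
                    if p.is_stop_grad():
                        continue
                    p.update(p - g * lr)
            # restores node_order and zeroes the accumulated grads
            self.post_step()

Usage is then the same as for the built-in optimizers: opt = MyOptimizer(model.parameters(), lr=0.1); opt.step(loss).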
@ -255,9 +275,9 @@ class SGD(Optimizer):
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)

def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
def step(self, loss=None, retain_graph=False):
self.pre_step(loss, retain_graph=False)
jt.flags.node_order = 1
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)

@ -275,7 +295,7 @@ class SGD(Optimizer):
p.update(p - (dp + momentum * v) * lr)
else:
p.update(p - v * lr)
self.zero_grad()
self.post_step()

class RMSprop(Optimizer):
""" RMSprop Optimizer.

@ -306,9 +326,8 @@ class RMSprop(Optimizer):
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)

def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
def step(self, loss=None, retain_graph=False):
self.pre_step(loss, retain_graph)
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)

@ -318,7 +337,7 @@ class RMSprop(Optimizer):
if p.is_stop_grad(): continue
v.update(alpha * v + (1-alpha) * g * g)
p.update(p - lr * g / (jt.sqrt(v) + eps))
self.zero_grad()
self.post_step()

class Adam(Optimizer):
""" Adam Optimizer.

@ -351,10 +370,10 @@ class Adam(Optimizer):
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)

def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
def step(self, loss=None, retain_graph=False):
self.pre_step(loss, retain_graph)
n = float(self.n_step)
jt.flags.node_order = 1
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)

@ -368,7 +387,7 @@ class Adam(Optimizer):
v.update(b1 * v + (1-b1) * g * g)
step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
p.update(p - m * step_size / (jt.sqrt(v) + eps))
self.zero_grad()
self.post_step()


class AdamW(Optimizer):

@ -402,9 +421,8 @@ class AdamW(Optimizer):
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)

def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
def step(self, loss=None, retain_graph=False):
self.pre_step(loss, retain_graph)
n = float(self.n_step)
for pg in self.param_groups:
# get arguments from each param_groups

@ -422,7 +440,7 @@ class AdamW(Optimizer):
denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps
step_size = lr / bias_correction1
p.update(p - step_size * m / denom)
self.zero_grad()
self.post_step()


class LRScheduler:
@ -9,6 +9,7 @@
// ***************************************************************
#include <algorithm>
#include <functional>
#include <queue>
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include "helper_cuda.h"

@ -209,11 +210,13 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
// get all nodes need to be executed
int need_opt = 0;
auto t = ++Node::tflag_count;
int64 max_id = 0;
for (Var* v : vars)
if (!v->is_finished() && v->tflag != t) {
v->tflag = t;
start_var_num++;
bfs_q.push_back(v);
max_id = std::max(max_id, v->id);
}
for (int i=0; i<bfs_q.size(); i++) {
auto node = bfs_q[i];

@ -225,12 +228,14 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
bfs_q.push_back(i.node);
}
// this var has been fetched
if (node->flags.get(NodeFlags::_fetch)) {
if (weak_sync || node->flags.get(NodeFlags::_fetch)) {
for (auto& n : node->_outputs) {
// if not in queue and is fetch op
if (n.node->tflag != t &&
n.node->pending_liveness &&
!n.node->is_finished() &&
n.node->flags.get(NodeFlags::_fetch)) {
(n.node->id <= max_id ||
n.node->flags.get(NodeFlags::_fetch))) {
n.node->tflag = t;
need_opt += n.node->flags.get(NodeFlags::_has_gopt);
bfs_q.push_back(n.node);

@ -330,7 +335,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
{
// queue.clear();
#ifndef JT_bfs_executor
map<int64, int> p_queue;
std::priority_queue<pair<int64,int64>> p_queue;
#endif
for (int root : roots) {
for (int i=root; i>=0; i=next[i]) {

@ -349,7 +354,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
queue.push_back(root);
#else
if (deps[root] == 0)
p_queue[ops[root]->id] = root;
p_queue.emplace(-ops[root]->order(), root);
#endif
}
#ifdef JT_bfs_executor

@ -361,8 +366,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
#ifdef JT_bfs_executor
int op_id = queue[s];
#else
int op_id = p_queue.begin()->second;
p_queue.erase(p_queue.begin());
int op_id = p_queue.top().second;
p_queue.pop();
queue.push_back(op_id);
#endif
for (int i=op_id; i>=0; i=next[i]) {

@ -382,7 +387,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
queue.push_back(op2_id);
#else
if (deps[op2_id] == 0)
p_queue[op2->id] = op2_id;
p_queue.emplace(-op2->order(), op2_id);
#endif
}
}
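The executor hunks above replace the ordered map of ready ops with a std::priority_queue keyed on -op->order(). A rough Python analogue of that scheduling rule (not the real executor; order() itself is defined in the node.h hunk further down):

    import heapq

    NODE_ORDER_HIGH = 1 << 60        # mirrors the 1ll<<60 constant in Node::order()

    def order(node):
        # "order low" nodes run first, "order high" nodes run last,
        # everything else runs in creation-id order.
        if node.get("low"):
            return 0
        if node.get("high"):
            return NODE_ORDER_HIGH
        return node["id"]

    def schedule(ready_ops):
        # The C++ side pushes (-order, id) into a max-heap; popping the
        # largest pair there equals popping the smallest order() here.
        heap = [(order(n), n["id"]) for n in ready_ops]
        heapq.heapify(heap)
        while heap:
            yield heapq.heappop(heap)[1]

For example, list(schedule([{"id": 7}, {"id": 3, "low": True}, {"id": 5}])) yields [3, 5, 7].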
@ -14,7 +14,6 @@ namespace jittor {

DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check.");

EXTERN_LIB unordered_map<void*, int64> lived_nodes;

template <typename T>
string ss_convert(T x) {

@ -29,7 +29,6 @@
#include "mem/allocator/stat_allocator.h"
#include "mem/allocator/temp_allocator.h"
#include "mem/mem_info.h"
#include "update_queue.h"
#include "executor.h"

namespace jittor {

@ -67,8 +66,6 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
log << "hold_vars:" << hold_vars.size()
<< "lived_vars:" << Var::number_of_lived_vars
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
log << "update queue:" << update_queue.queue.size()
>> '/' >> update_queue.map.size() >> '\n';

#ifdef NODE_MEMCHECK
// get the oldest var

@ -13,9 +13,14 @@
namespace jittor {

EXTERN_LIB unordered_map<void*, int64> lived_nodes;
EXTERN_LIB unordered_map<int64, Node*> lived_nodes_id;
EXTERN_LIB int64 total_node;
EXTERN_LIB int64 nt;
EXTERN_LIB vector<Node*> free_buffer;
EXTERN_LIB uint8 node_order;

inline static Node* get_node(int64 id)
{ return lived_nodes_id.count(id) ? lived_nodes_id[id] : nullptr; }

struct NodeFlags {
typedef uint32 nf_t;

@ -29,12 +34,14 @@ struct NodeFlags {
_stop_grad=2,
// bit3: is fetch
_fetch=3,
_n=4,
// bit4: node order low
_node_order_low=4,
_node_order_high=5,
_n=6,

// var related flags
_force_fuse=_n+0,
_stop_fuse=_n+1,
_in_update_queue=_n+2,
_needed_by_backward=_n+3,
_out_hint=_n+4,
@ -129,19 +136,28 @@ struct Node {
list<input_t> _inputs;
list<output_t> _outputs;

#ifdef NODE_MEMCHECK
inline Node() {
lived_nodes[(void*)this] = id = ++total_node;
int64 order() {
if (flags.get(NodeFlags::_node_order_low)) return 0;
if (flags.get(NodeFlags::_node_order_high)) return 1ll<<60;
return id;
}

inline virtual ~Node() {
inline Node() {
id = ++total_node;
#ifdef NODE_MEMCHECK
lived_nodes_id[id] = this;
lived_nodes[(void*)this] = id;
#endif
flags.set(NodeFlags::_node_order_low, node_order, 2);
}
inline virtual ~Node() {
#ifdef NODE_MEMCHECK
lived_nodes_id.erase(id);
lived_nodes.erase((void*)this);
#endif
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);
}
#else
inline Node() { id = ++total_node; };
inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);};
#endif

inline Var* var() { return (Var*)this; }
inline Op* op() { return (Op*)this; }
inline Node* node() { return this; }

@ -155,13 +171,6 @@ struct Node {
#endif
}
void memcheck_all_exist() const;
inline int64 __id() const {
#ifdef NODE_MEMCHECK
return lived_nodes.at((void*)this);
#else
return 0;
#endif
}
// release from counter and memory checker
void __release();
#define CHECK_NODE_EXIST(node) \

@ -298,9 +298,6 @@ std::ostream& operator<<(std::ostream& os, const Op* op) {
os << "->" << (void*)v;
}
os << ')';
#ifdef NODE_MEMCHECK
os << '<' << op->__id() << '>';
#endif
if (trace_py_var) {
os << '{';
print_node_trace(op, os);

@ -113,8 +113,6 @@ VarPtr BroadcastToOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index==1) return nullptr;
if (bcast_mask==0) return dout;
VarPtr dv = make_reduce(dout, ns_add, bcast_mask, keepdims_mask);
if (dv->shape.size() != v->shape.size())
dv->shape = v->shape;
return dv;
}

@ -150,7 +148,7 @@ void BroadcastToOp::infer_shape() {
}
auto mask = ((xshape==1 && (yshape!=1 || !bx))&1) << i;
bcast_mask |= mask;
keepdims_mask |= mask;
if (bx) keepdims_mask |= mask;
int64 zs;
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
zs = xshape * yshape;

@ -39,6 +39,9 @@ CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
if (_outputs[0]->num < 0) {
check_vary_shape(_outputs[0]->shape);
}
if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0) {
flags.set(NodeFlags::_manual_set_vnbb);
}
}


@ -60,6 +63,8 @@ CodeOp::CodeOp(
check_vary_shape(_outputs[i]->shape);
}
}
if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0)
flags.set(NodeFlags::_manual_set_vnbb);
}

CodeOp::CodeOp(

@ -81,6 +86,8 @@ CodeOp::CodeOp(
TODO: vary shape not allowed in direct output
*/
}
if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0)
flags.set(NodeFlags::_manual_set_vnbb);
}


@ -245,8 +245,7 @@ void TraceData::record_node(Node* node, bool record_stack) {
}
} else {
}
if (node->__id())
data.attrs["__id"] = S(node->__id());
data.attrs["__id"] = S(node->id);
data.attrs["is_var"] = node->is_var() ? "1" : "0";
data.attrs["name"] = "unname";
node_data[data.id] = move(data);
@ -1,152 +0,0 @@
// ***************************************************************
// Copyright (c) 2022 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "update_queue.h"
#include "executor.h"
#include "node.h"
#include "var.h"
#include "var_holder.h"

namespace jittor {

/*

The update queue is designed to batch update parameters asynchronously.
It maintains several queues internally.
Each updated parameter corresponds to a queue,
and the elements in the queue represent several updates of this parameter.
When a parameter is updated,
jittor internally updates the previous parameter several times
instead of the current parameter.

The figure below shows an async update process:

first iter

\ iter 0
param
a 0
b 0
c 0
d 0

second iter

\ iter 0 1
params
a 0 1
b 0 1
c 0 1
d 0 1

third iter begin (the pending updates from iter 0 are executed)

\ iter 0 1 2
params
a [0]1 2
b [0]1
c [0]1
d [0]1

third iter end

\ iter 0 1 2
params
a 1 2
b 1 2
c 1 2
d 1 2

update_queue_auto_flush_delay: how many iterations an update is delayed before it is flushed.

The update queue was introduced mainly to keep the unified computation graph from growing
without bound (the number of lived vars keeps increasing). Before the update queue, running
the graph was the optimizer's job: when optim.step was called, the not-yet-executed part of
the graph was run automatically, already-executed nodes were reclaimed, and the graph size
stayed roughly constant across iterations.

But if the user never calls optim.step, the graph keeps growing, for example in these two cases:

* When training a GAN, only the generator is run through SGD and the discriminator is not;
the discriminator's batch norm parameters keep being updated but are never executed, so the
graph keeps growing.
* The user forgets to set model.eval during inference; with no SGD to flush the parameters,
the batch norm parameters keep being updated, and again the graph keeps growing.

These details are too hard for users to reason about (LD: even I get confused sometimes).
A blunt fix is jt.sync_all, which forces the whole graph to run, but that uses too much
memory because the topological order sync_all runs in is not optimal.

So that users do not have to care about these details, parameter updates go through
var.update(new_var); this interface hands the update over to the update queue, and the size
of the underlying computation graph no longer has to be worried about.

*/

DEFINE_FLAG(int, update_queue_auto_flush_delay, 2, "when size of a update queue is great than this value, update queue trigger auto flush(default 2).");

UpdateQueue update_queue;

void UpdateQueue::auto_flush() {
vector<Var*> vars;
vars.reserve(queue.size());
for (auto& l : queue) {
while (l.size() && l.size() >= update_queue_auto_flush_delay) {
auto iter = l.end(); iter--;
auto v = iter->v;
vars.push_back(v);
map.erase(v);
v->flags.set(NodeFlags::_in_update_queue, 0);
l.pop_back();
}
}
LOGvv << "auto flush var size" << vars.size();
exe.run_sync(move(vars), false);
}

void UpdateQueue::push(Var* v, Var* prev) {
if (v->flags.get(NodeFlags::_in_update_queue))
return;
list<list<Item>>::iterator owner;

if (prev->flags.get(NodeFlags::_in_update_queue)) {
auto iter = map.find(prev);
ASSERT(iter != map.end());
owner = iter->second->owner;
} else {
queue.emplace_front();
owner = queue.begin();
}
if (owner->size() >= update_queue_auto_flush_delay) {
auto_flush();
}
v->flags.set(NodeFlags::_in_update_queue);
owner->emplace_front(UpdateQueue::Item{owner, v});
map[v] = owner->begin();
// if total size of update queue is too big,
// force sync all
if (map.size() > 100000)
sync_all();
}

void UpdateQueue::pop(Var* v) {
auto iter = map.find(v);
iter->second->owner->erase(iter->second);
if (iter->second->owner->size() == 0)
queue.erase(iter->second->owner);
map.erase(iter);
}

} // jittor
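The removed file above documents the retired update queue. For reference, a toy Python model of its auto-flush policy as described in that comment (per-parameter queues of pending updates, flushed in a batch once a queue reaches update_queue_auto_flush_delay); this is an illustration, not the real implementation:

    from collections import defaultdict, deque

    AUTO_FLUSH_DELAY = 2                      # default flush delay

    class ToyUpdateQueue:
        def __init__(self, run_sync):
            self.pending = defaultdict(deque) # parameter -> pending updates
            self.run_sync = run_sync          # callback that executes a batch

        def _auto_flush(self):
            stale = []
            for q in self.pending.values():
                while q and len(q) >= AUTO_FLUSH_DELAY:
                    stale.append(q.pop())     # oldest updates first
            self.run_sync(stale)              # run them in one batch

        def push(self, param, update):
            q = self.pending[param]
            if len(q) >= AUTO_FLUSH_DELAY:
                self._auto_flush()
            q.appendleft(update)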
@ -1,28 +0,0 @@
// ***************************************************************
// Copyright (c) 2022 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"

namespace jittor {

struct UpdateQueue {
struct Item {
list<list<Item>>::iterator owner;
Var* v;
};
list<list<Item>> queue;
unordered_map<Var*, list<Item>::iterator> map;

void push(Var* v, Var* prev);
void pop(Var* v);
void auto_flush();
};

EXTERN_LIB UpdateQueue update_queue;

} // jittor

@ -69,7 +69,7 @@ constexpr char yellow[] = "\x1b[1;33m";


inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
if (level == 'i') {
if (level == 'i' || level == 'I') {
if (verbose == 0) color_begin = "\x1b[1;32m"; else
if (verbose < 10) color_begin = "\x1b[1;32m"; else
if (verbose < 100) color_begin = "\x1b[1;32m"; else

@ -90,7 +90,7 @@ constexpr char red[] = "\033[38;5;1m";
constexpr char yellow[] = "\033[38;5;3m";

inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
if (level == 'i') {
if (level == 'i' || level == 'I') {
if (verbose == 0) color_begin = "\033[38;5;2m"; else
if (verbose < 10) color_begin = "\033[38;5;250m"; else
if (verbose < 100) color_begin = "\033[38;5;244m"; else

@ -280,9 +280,10 @@ bool check_vlog(const char* fileline, int verbose);
#define LOGrrrr LOGvvvv >> jittor::red
#define LOGyyyy LOGvvvv >> jittor::yellow

#define LOGir LOGi >> jittor::red
#define LOGig LOGi >> jittor::green
#define LOGiy LOGi >> jittor::yellow
#define LOGI jittor::LogVoidify() && jittor::Log(__FILELINE__, 'I', 0)
#define LOGir LOGI >> jittor::red
#define LOGig LOGI >> jittor::green
#define LOGiy LOGI >> jittor::yellow

void system_with_check(const char* cmd, const char* cwd=nullptr);
@ -10,7 +10,6 @@
#include "op.h"
#include "mem/allocator.h"
#include "pybind/py_var_tracer.h"
#include "update_queue.h"

namespace jittor {


@ -22,6 +21,7 @@ DEFINE_FLAG(bool, no_grad, 0,
"No grad for all jittor Var creation");
DEFINE_FLAG(bool, no_fuse, 0,
"No fusion optimization for all jittor Var creation");
DEFINE_FLAG(uint8, node_order, 0, "id prior");
// TODO: fuse multiple flags
DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");

@ -101,11 +101,8 @@ std::ostream& operator<<(std::ostream& os, const Var& var) {
<< ":s" << var.is_finished()
<< ":n" << var.flags.get(NodeFlags::_needed_by_backward)
<< ','
<< var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr
<< var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr << std::dec
<< ')' << var.shape;
#ifdef NODE_MEMCHECK
os << '<' << var.__id() << '>';
#endif
if (trace_py_var) {
os << '{';
print_node_trace(&var, os);
@ -13,7 +13,6 @@
#include "var.h"
#include "executor.h"
#include "graph.h"
#include "update_queue.h"
#include "mem/allocator/cuda_dual_allocator.h"
#include "ops/op_register.h"


@ -27,7 +26,7 @@ list<VarHolder*>::iterator sync_ptr = hold_vars.end();
void add_hold_vars(VarHolder* self) {
hold_vars.push_front(self);
self->iter = hold_vars.begin();
if (lazy_execution) return;
if (lazy_execution && Op::number_of_lived_ops < 100000) return;
auto v = self->var;
for (int i=0; i<5; i++) {
auto op = v->input();

@ -115,19 +114,15 @@ VarHolder* VarHolder::assign(VarHolder* v) {
}

VarHolder* VarHolder::update(VarHolder* v) {
auto dv = jittor::detach(v->var);
update_queue.push(dv.ptr, var);
*this = move(dv);
return this;
v->var->flags.set(NodeFlags::_out_hint);
return assign(v);
}

VarHolder* VarHolder::_update(VarHolder* v) {
auto dv = jittor::detach(v->var);
if (var->flags.get(NodeFlags::_in_update_queue))
update_queue.push(dv.ptr, var);
v->var->own_both_liveness();
var->release_both_liveness();
var = dv.ptr;
dv.ptr = nullptr;
var = v->var;
var->flags.set(NodeFlags::_out_hint);
return this;
}

@ -204,7 +204,8 @@ struct VarHolder {
inline VarHolder* start_grad() {
if (!var->dtype().is_float())
LOGw << "cannot enable grad of a non-float value:" << var;
_update(this);
auto dvar = jittor::detach(var);
std::swap(dvar.ptr, var);
return this;
}
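With the update queue removed, VarHolder::update above simply marks the new value with _out_hint and assigns it; from Python this is the in-place write the optimizer loops rely on. A tiny hedged illustration (the tensors are invented):

    import jittor as jt

    p = jt.rand(3, 3)           # stands in for a parameter
    g = jt.ones((3, 3))         # stands in for its gradient
    lr = 0.1
    # update() replaces p's value in place and hints the executor to
    # materialize it, just like p.update(p - g * lr) in the optimizers.
    p.update(p - lr * g)
    p.sync()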
@ -40,6 +40,13 @@ class TestACL(unittest.TestCase):
y = jt.float32(x)
np.testing.assert_allclose(x, y.numpy())

@jt.flag_scope(use_acl=1)
def test_rand(self):
a = jt.rand(10)
b = a*10
b.sync()
print(b)

def test_meminfo(self):
jt.display_memory_info()

@ -194,14 +194,11 @@ class TestArray(unittest.TestCase):
np.testing.assert_allclose(a, c)

def test_scalar_fuse_unary(self):
c = jt.ones(10)
jt.sync_all()
with jt.profile_scope() as rep:
a = jt.array([1])
b = -a
a = a.clone()
b = b.clone()
jt.sync([a, b])
assert a.data == 1
assert b.data == -1
b = c-1
assert b.data[1] == 0
assert len(rep) == 2

@unittest.skipIf(not jt.has_cuda, "Cuda not found")

@ -218,6 +218,32 @@ class TestCore(unittest.TestCase):
assert cnt2 == 10
assert cnt1 <= 33, cnt1

def test_node_order(self):
a = jt.nn.Sequential()
for i in range(10):
a.append(jt.nn.Linear(10,10, bias=False))
sgd = jt.optim.SGD(a.parameters(), 0.1)
jt.sync_all()
with jt.log_capture_scope(log_silent=1,
log_vprefix="exe=100") as logs:
x = jt.rand(3,10)
y = a(x)
sgd.step(y*y)
jt.sync_all()
orders = []
for l in logs:
msg = l["msg"]
if "Finished" in msg:
# print(msg)
if "weight" in msg:
assert msg.count("Var") >= 2
order = int(msg.split('fused ')[1].split("/")[0])
# print(order)
orders.append(order)
assert len(orders) == 10, orders
for i in range(10):
assert orders[i] <= 14+i*3


if __name__ == "__main__":
unittest.main()

@ -85,14 +85,14 @@ class TestFusedOp(unittest.TestCase):
(hv, lv, lo))
for i in range(8):
check(0,0,0)
a = jt.array(1.0).name('a').stop_fuse()
b = (a+jt.array(1.0).name('t1').stop_fuse()).name('b')
c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c')
a = jt.array([1.0,1.0]).name('a').stop_fuse()
b = (a+jt.array([1.0,1.0]).name('t1').stop_fuse()).name('b')
c = (b+jt.array([1.0,1.0]).name('t2').stop_fuse()).name('c')
check(3,5,5)
graph = jt.dump_all_graphs()
# for n in graph.nodes_info:
# print(n)
self.assertEqual(c.data, 3)
np.testing.assert_allclose(c.data, [3,3])
graph2 = jt.dump_all_graphs()
print("check", i)
for n in graph2.nodes_info:

@ -83,6 +83,14 @@ class TestTransposeOp(unittest.TestCase):
b = a.transpose((1,0))
b.sync()

@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cutt_bug(self):
a = jt.rand(640000,4,3)
b = a.transpose(0,2,1)
b.sync(True)
print(a.shape, b.shape)


class TestFuseTransposeOp(unittest.TestCase):