fix contrib

This commit is contained in:
Dun Liang 2020-06-23 22:48:42 +08:00
parent 5de2aec717
commit 993f73cb4e
8 changed files with 108 additions and 34 deletions

View File

@ -50,7 +50,7 @@ def slice_var_index(x, slices):
slices = (slices,)
if isinstance(slices[0], jt.Var):
if len(slices) == 1 and slices[0].dtype == "bool":
return (slices[0].where(),)
return slice_var_index(x, tuple(slices[0].where()))
bc = []
ml = -1
for idx, s in enumerate(slices):

View File

@ -192,9 +192,9 @@ class BatchNorm(Module):
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
self.running_mean.update(self.running_mean +
(xmean.sum([0,2,3]) - self.running_mean) * self.momentum)
(xmean.reshape((-1,)) - self.running_mean) * self.momentum)
self.running_var.update(self.running_var +
(xvar.sum([0,2,3])-self.running_var)*self.momentum)
(xvar.reshape((-1,))-self.running_var)*self.momentum)
else:
running_mean = self.running_mean.broadcast(x, [0,2,3])
running_var = self.running_var.broadcast(x, [0,2,3])

View File

@ -25,7 +25,9 @@ class TestStopFuse(unittest.TestCase):
jt.sync(dbs+[a])
for a in report[1:]:
assert len(a[0].split("opkey")) < 50
# origin is 50
# after update queue, increase to 102
assert len(a[0].split("opkey")) < 110, len(a[0].split("opkey"))
def test_stop_fuse2(self):
with jt.profile_scope() as report:
@ -43,7 +45,9 @@ class TestStopFuse(unittest.TestCase):
jt.sync(dbs+[a])
for a in report[1:]:
assert len(a[0].split("opkey")) < 8
# origin is 8
# after update queue, increase to 12
assert len(a[0].split("opkey")) < 16, len(a[0].split("opkey"))
if __name__ == "__main__":
unittest.main()

View File

@ -105,7 +105,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
// var_fused represents:
// 0: can fused
// 1: cannot fused
// 2: can shared
// 2: weak shared(may turn into 1 or 3 by shared operator cutting)
// 3: strong shared(force shared)
vector<int> roots, next(op_num, -1);
vector<int> deps(op_num, 0);
roots.reserve(op_num);
@ -176,6 +177,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
sharegraph_q.reserve(16);
vector<int> shared_id(op_num, -1);
// for fused op in reversed order
for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[queue.size()-rid-1];
auto& queue = subgraph;
@ -193,10 +195,13 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
if (fopid == root)
deps[i]++;
else if (shared_id[opid] != root) {
auto& vf = var_fused[v->custom_data];
// var_fused = 1 cannot share input op
// TODO: check this input op's output var all can be shared
if (var_fused[v->custom_data] == 1)
if (vf == 1)
continue;
// if weak share, turn into strong share
if (vf == 2) vf = 3;
// new shared op
deps[opid] = 0;
shared_id[opid] = root;
@ -216,6 +221,15 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
int vi = v->custom_data;
if (var_fused[vi] == 1)
continue;
// if weak share, cut off
if (var_fused[vi] == 2) {
if (sharegraph.size() - sn < 32)
var_fused[vi] = 3;
else {
var_fused[vi] = 1;
continue;
}
}
Op* opi = v->input();
int opid = opi->custom_data;
int& dep = deps[opid];

View File

@ -174,10 +174,20 @@ void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vec
var_fused[i]=1;
}
}
if (vf==0) var_fused[i]=1;
if (var_fused[i] && vf &&
(iop->type()==OpType::broadcast || all_reduce || v->flags.get(NodeFlags::_force_fuse)))
var_fused[i]=2;
if (vf==0)
// cannot fused
var_fused[i]=1;
else if (var_fused[i]) {
if (iop->type()==OpType::broadcast ||
all_reduce ||
v->flags.get(NodeFlags::_force_fuse))
// strong fused
var_fused[i] = 3;
else
// weak fused
var_fused[i] = 2;
// var_fused[i] = 3;
}
}
// output vars can not be fused
for (int i=0; i<start_var_num; i++)

View File

@ -41,7 +41,7 @@ std::ostream& operator<<(std::ostream& os, const FloatOutput& o) {
return os << o.suffix;
}
void display_memory_info(const char* fileline) {
void display_memory_info(const char* fileline, bool dump_var) {
int p = 3;
Log log(fileline, 'i', 0);
log << "\n=== display_memory_info ===\n";
@ -52,28 +52,28 @@ void display_memory_info(const char* fileline) {
log << "hold_vars:" << VarHolder::hold_vars.size()
<< "lived_vars:" << Var::number_of_lived_vars
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
log << "update queue:" << update_queue.map.size()
log << "update queue:" << update_queue.queue.size()
>> '/' >> update_queue.map.size() >> '\n';
#ifdef NODE_MEMCHECK
// get the oldest var
vector<Node*> queue;
auto t = ++Node::tflag_count;
for (auto& vh : VarHolder::hold_vars)
if (vh->var->tflag != t) {
vh->var->tflag = t;
queue.push_back(vh->var);
}
bfs_both(queue, [](Node*){return true;});
vector<pair<int64, Node*>> nodes;
nodes.reserve(queue.size());
for (auto* node : queue)
nodes.push_back({node->__id(), node});
std::sort(nodes.begin(), nodes.end());
log << "list of the oldest nodes:\n";
for (int i=0; i<10 && i<nodes.size(); i++) {
log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
}
// vector<Node*> queue;
// auto t = ++Node::tflag_count;
// for (auto& vh : VarHolder::hold_vars)
// if (vh->var->tflag != t) {
// vh->var->tflag = t;
// queue.push_back(vh->var);
// }
// bfs_both(queue, [](Node*){return true;});
// vector<pair<int64, Node*>> nodes;
// nodes.reserve(queue.size());
// for (auto* node : queue)
// nodes.push_back({node->__id(), node});
// std::sort(nodes.begin(), nodes.end());
// log << "list of the oldest nodes:\n";
// for (int i=0; i<10 && i<nodes.size(); i++) {
// log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
// }
#endif
if (use_stat_allocator) {
@ -81,10 +81,15 @@ void display_memory_info(const char* fileline) {
log << "total alloc:" << FloatOutput{(double)(stat_allocator_total_alloc_byte
- stat_allocator_total_free_byte), " KMG", 1024, "B"};
log << "total alloc call:" << FloatOutput{(double)(stat_allocator_total_alloc_call
- stat_allocator_total_free_call), " KMG", 1000, ""} >> '\n';
- stat_allocator_total_free_call), " KMG", 1000, ""}
>> '(' >> stat_allocator_total_alloc_call >> '/' >>
stat_allocator_total_free_call >> ")\n";
}
int64 all_total = 0, gpu_total = 0, cpu_total = 0;
for (auto& a : SFRLAllocator::sfrl_allocators) {
auto total = a->used_memory + a->unused_memory;
all_total += total;
a->is_cuda() ? gpu_total += total : cpu_total += total;
log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
<< "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
@ -92,6 +97,47 @@ void display_memory_info(const char* fileline) {
>> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
<< "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
}
log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"}
<< "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"}
<< "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n';
if (dump_var) {
vector<Node*> queue;
unordered_set<Node*> visited;
for (auto& vh : VarHolder::hold_vars)
if (!visited.count(vh->var)) {
queue.push_back(vh->var);
visited.insert(vh->var);
}
int64 cum = 0;
for (int i=0; i<queue.size(); i++) {
for (auto* n : queue[i]->inputs())
if (!visited.count(n)) {
queue.push_back(n);
visited.insert(n);
}
for (auto* n : queue[i]->outputs())
if (!visited.count(n)) {
queue.push_back(n);
visited.insert(n);
}
if (queue[i]->is_var()) {
auto v = (Var*)queue[i];
if (v->size>=0 && v->mem_ptr) {
cum += v->size;
log << FloatOutput{(double)v->size, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> v->size*100.0 / all_total >> "%)"
<< FloatOutput{(double)cum, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> cum*100.0 / all_total >> "%)"
<< v >> "\n";
if (v->size == 100*64*112*112*4) {
for (auto op : v->outputs())
log << "\t" << op << '\n';
}
}
}
}
}
log >> "===========================\n";
log.end();
}

View File

@ -9,7 +9,7 @@
namespace jittor {
// @pyjt(display_memory_info)
void display_memory_info(const char* fileline="");
void display_memory_info(const char* fileline="", bool dump_var=false);
// @pyjt(MemInfo)
struct MemInfo {

View File

@ -101,7 +101,7 @@ void UpdateQueue::auto_flush() {
vector<Var*> vars;
vars.reserve(queue.size());
for (auto& l : queue) {
while (l.size() && l.size() >= update_queue_auto_flush_depth) {
while (l.size() && l.size() >= update_queue_auto_flush_delay) {
auto iter = l.end(); iter--;
auto v = iter->v;
vars.push_back(v);
@ -128,7 +128,7 @@ void UpdateQueue::push(Var* v, Var* prev) {
queue.emplace_front();
owner = queue.begin();
}
if (owner->size() >= update_queue_auto_flush_depth)
if (owner->size() >= update_queue_auto_flush_delay)
auto_flush();
owner->emplace_front(UpdateQueue::Item{owner, v});
map[v] = owner->begin();