mirror of https://github.com/Jittor/Jittor
fix contrib
This commit is contained in:
parent 5de2aec717
commit 993f73cb4e
@@ -50,7 +50,7 @@ def slice_var_index(x, slices):
         slices = (slices,)
     if isinstance(slices[0], jt.Var):
         if len(slices) == 1 and slices[0].dtype == "bool":
-            return (slices[0].where(),)
+            return slice_var_index(x, tuple(slices[0].where()))
     bc = []
     ml = -1
     for idx, s in enumerate(slices):
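The fixed branch expands a boolean mask into one index Var per dimension (via where()) and recurses, instead of wrapping the whole tuple in a 1-tuple. A minimal sketch of why that matters, using NumPy's nonzero as a stand-in for jt.Var.where() (the stand-in and the array values are illustrative, not Jittor internals):

import numpy as np

x = np.arange(12).reshape(3, 4)
mask = x % 2 == 0            # boolean mask with the same shape as x

idx = np.nonzero(mask)       # tuple: one index array per dimension
assert len(idx) == x.ndim    # (row_indices, col_indices)

# Passing the whole tuple selects mask.sum() elements, which is what the
# fixed `slice_var_index(x, tuple(slices[0].where()))` path does.
print(x[idx])                # [ 0  2  4  6  8 10]

# Wrapping that tuple in another 1-tuple (the old `(slices[0].where(),)`
# behaviour) would instead be treated as indexing along axis 0 only.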
@@ -192,9 +192,9 @@ class BatchNorm(Module):
             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
             self.running_mean.update(self.running_mean +
-                (xmean.sum([0,2,3]) - self.running_mean) * self.momentum)
+                (xmean.reshape((-1,)) - self.running_mean) * self.momentum)
             self.running_var.update(self.running_var +
-                (xvar.sum([0,2,3])-self.running_var)*self.momentum)
+                (xvar.reshape((-1,))-self.running_var)*self.momentum)
         else:
             running_mean = self.running_mean.broadcast(x, [0,2,3])
             running_var = self.running_var.broadcast(x, [0,2,3])
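The replaced lines flatten the per-channel batch statistics with reshape((-1,)) so they match the 1-D running buffers, rather than reducing over dims [0,2,3] again. A minimal NumPy sketch of the exponential-moving-average update this code performs (shapes, names and the momentum value here are illustrative, not Jittor's internals):

import numpy as np

def update_running_stats(x, running_mean, running_var, momentum=0.1, eps=1e-5):
    # x: (N, C, H, W). Batch statistics are reduced over N, H, W,
    # leaving one value per channel, exactly the shape of the running buffers.
    xmean = x.mean(axis=(0, 2, 3))                       # (C,)
    xvar = (x ** 2).mean(axis=(0, 2, 3)) - xmean ** 2    # (C,)
    # Same form as running += (batch - running) * momentum in the diff above.
    running_mean += (xmean - running_mean) * momentum
    running_var += (xvar - running_var) * momentum
    # Normalize with the batch statistics during training.
    norm = (x - xmean.reshape(1, -1, 1, 1)) / np.sqrt(xvar.reshape(1, -1, 1, 1) + eps)
    return norm, running_mean, running_var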
@@ -25,7 +25,9 @@ class TestStopFuse(unittest.TestCase):
            jt.sync(dbs+[a])
 
        for a in report[1:]:
-            assert len(a[0].split("opkey")) < 50
+            # origin is 50
+            # after update queue, increase to 102
+            assert len(a[0].split("opkey")) < 110, len(a[0].split("opkey"))
 
    def test_stop_fuse2(self):
        with jt.profile_scope() as report:
@@ -43,7 +45,9 @@ class TestStopFuse(unittest.TestCase):
            jt.sync(dbs+[a])
 
        for a in report[1:]:
-            assert len(a[0].split("opkey")) < 8
+            # origin is 8
+            # after update queue, increase to 12
+            assert len(a[0].split("opkey")) < 16, len(a[0].split("opkey"))
 
 if __name__ == "__main__":
     unittest.main()
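Both assertions are loosened because, after the update-queue change further down in this commit, more pending updates can accumulate before a flush, so more ops may end up fused into one kernel. The counting idiom itself is simple; a hedged sketch (the report row format below is assumed for illustration, not taken from Jittor documentation):

# Splitting a profile row's key string on "opkey" yields
# (number of fused ops + 1) pieces, so the assertions bound how many
# ops were fused into a single kernel.
def fused_op_count(report_row_key: str) -> int:
    return len(report_row_key.split("opkey")) - 1

row = "fused_op:[opkey0:broadcast_to,opkey1:binary.add,opkey2:reduce.add]"
assert fused_op_count(row) == 3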
@@ -105,7 +105,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     // var_fused represents:
     // 0: can fused
     // 1: cannot fused
-    // 2: can shared
+    // 2: weak shared(may turn into 1 or 3 by shared operator cutting)
+    // 3: strong shared(force shared)
     vector<int> roots, next(op_num, -1);
     vector<int> deps(op_num, 0);
     roots.reserve(op_num);
@@ -176,6 +177,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     sharegraph_q.reserve(16);
     vector<int> shared_id(op_num, -1);
 
+    // for fused op in reversed order
     for (uint rid=0; rid<queue.size(); rid++) {
         int root = queue[queue.size()-rid-1];
         auto& queue = subgraph;
@@ -193,10 +195,13 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
                 if (fopid == root)
                     deps[i]++;
                 else if (shared_id[opid] != root) {
+                    auto& vf = var_fused[v->custom_data];
                     // var_fused = 1 cannot share input op
                     // TODO: check this input op's output var all can be shared
-                    if (var_fused[v->custom_data] == 1)
+                    if (vf == 1)
                         continue;
+                    // if weak share, turn into strong share
+                    if (vf == 2) vf = 3;
                     // new shared op
                     deps[opid] = 0;
                     shared_id[opid] = root;
@@ -216,6 +221,15 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
             int vi = v->custom_data;
             if (var_fused[vi] == 1)
                 continue;
+            // if weak share, cut off
+            if (var_fused[vi] == 2) {
+                if (sharegraph.size() - sn < 32)
+                    var_fused[vi] = 3;
+                else {
+                    var_fused[vi] = 1;
+                    continue;
+                }
+            }
             Op* opi = v->input();
             int opid = opi->custom_data;
             int& dep = deps[opid];
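Read together, the executor hunks introduce a budgeted form of operator sharing: a var marked 2 (weak shared) is promoted to 3 (strong shared) when its producer is pulled into the current root's share graph, but only while the graph stays within a small budget (the `sharegraph.size() - sn < 32` check); beyond that the weak share is cut back to 1 (cannot fuse). A minimal Python model of that cutting rule (the codes 1/2/3 and the 32-op budget come from the diff; the surrounding data structures are invented for illustration and are not the C++ implementation):

CANNOT, WEAK, STRONG = 1, 2, 3
SHARE_BUDGET = 32   # mirrors the `sharegraph.size() - sn < 32` check above

def cut_weak_shares(var_fused, candidate_vars, already_shared=0):
    """Promote or cut weakly shared vars while growing one root's share graph.

    var_fused: dict var_id -> 1/2/3, mutated in place like the C++ array.
    Returns the var_ids that remain shared into the current fused op.
    """
    shared = []
    grown = already_shared
    for vid in candidate_vars:
        state = var_fused[vid]
        if state == CANNOT:
            continue                      # never share through a non-fusable var
        if state == WEAK:
            if grown < SHARE_BUDGET:
                var_fused[vid] = STRONG   # weak share promoted to strong
            else:
                var_fused[vid] = CANNOT   # budget exhausted: cut this weak share
                continue
        shared.append(vid)
        grown += 1
    return shared

vf = {i: WEAK for i in range(40)}
kept = cut_weak_shares(vf, list(range(40)))
assert len(kept) == SHARE_BUDGET and vf[39] == CANNOT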
src/fuser.cc (18)
@@ -174,10 +174,20 @@ void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vec
                 var_fused[i]=1;
             }
         }
-        if (vf==0) var_fused[i]=1;
-        if (var_fused[i] && vf &&
-            (iop->type()==OpType::broadcast || all_reduce || v->flags.get(NodeFlags::_force_fuse)))
-            var_fused[i]=2;
+        if (vf==0)
+            // cannot fused
+            var_fused[i]=1;
+        else if (var_fused[i]) {
+            if (iop->type()==OpType::broadcast ||
+                all_reduce ||
+                v->flags.get(NodeFlags::_force_fuse))
+                // strong fused
+                var_fused[i] = 3;
+            else
+                // weak fused
+                var_fused[i] = 2;
+            // var_fused[i] = 3;
+        }
     }
     // output vars can not be fused
     for (int i=0; i<start_var_num; i++)
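The rewritten branch classifies each fusable var by its producer: broadcast ops, all-reduce patterns and vars flagged _force_fuse become strong shares (3), everything else fusable becomes a weak share (2) that the executor may later cut. A sketch of the decision table (Python, purely illustrative; it collapses the two C++ flags `vf` and `var_fused[i]` into a single `can_fuse` for brevity):

def classify_var(can_fuse, producer_is_broadcast, all_reduce, force_fuse):
    """Mirror of the count_fuse branch above: returns the var_fused code.

    1 = cannot fuse, 2 = weak share (may be cut later), 3 = strong share.
    """
    if not can_fuse:
        return 1
    if producer_is_broadcast or all_reduce or force_fuse:
        return 3
    return 2

assert classify_var(False, True,  False, False) == 1
assert classify_var(True,  True,  False, False) == 3
assert classify_var(True,  False, False, False) == 2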
@@ -41,7 +41,7 @@ std::ostream& operator<<(std::ostream& os, const FloatOutput& o) {
     return os << o.suffix;
 }
 
-void display_memory_info(const char* fileline) {
+void display_memory_info(const char* fileline, bool dump_var) {
     int p = 3;
     Log log(fileline, 'i', 0);
     log << "\n=== display_memory_info ===\n";
@@ -52,28 +52,28 @@ void display_memory_info(const char* fileline) {
     log << "hold_vars:" << VarHolder::hold_vars.size()
         << "lived_vars:" << Var::number_of_lived_vars
         << "lived_ops:" << Op::number_of_lived_ops >> '\n';
-    log << "update queue:" << update_queue.map.size()
+    log << "update queue:" << update_queue.queue.size()
+        >> '/' >> update_queue.map.size() >> '\n';
 
     #ifdef NODE_MEMCHECK
     // get the oldest var
-    vector<Node*> queue;
-    auto t = ++Node::tflag_count;
-    for (auto& vh : VarHolder::hold_vars)
-        if (vh->var->tflag != t) {
-            vh->var->tflag = t;
-            queue.push_back(vh->var);
-        }
-    bfs_both(queue, [](Node*){return true;});
-    vector<pair<int64, Node*>> nodes;
-    nodes.reserve(queue.size());
-    for (auto* node : queue)
-        nodes.push_back({node->__id(), node});
-    std::sort(nodes.begin(), nodes.end());
-    log << "list of the oldest nodes:\n";
-    for (int i=0; i<10 && i<nodes.size(); i++) {
-        log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
-    }
+    // vector<Node*> queue;
+    // auto t = ++Node::tflag_count;
+    // for (auto& vh : VarHolder::hold_vars)
+    //     if (vh->var->tflag != t) {
+    //         vh->var->tflag = t;
+    //         queue.push_back(vh->var);
+    //     }
+    // bfs_both(queue, [](Node*){return true;});
+    // vector<pair<int64, Node*>> nodes;
+    // nodes.reserve(queue.size());
+    // for (auto* node : queue)
+    //     nodes.push_back({node->__id(), node});
+    // std::sort(nodes.begin(), nodes.end());
+    // log << "list of the oldest nodes:\n";
+    // for (int i=0; i<10 && i<nodes.size(); i++) {
+    //     log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
+    // }
     #endif
 
     if (use_stat_allocator) {
@@ -81,10 +81,15 @@ void display_memory_info(const char* fileline) {
         log << "total alloc:" << FloatOutput{(double)(stat_allocator_total_alloc_byte
             - stat_allocator_total_free_byte), " KMG", 1024, "B"};
         log << "total alloc call:" << FloatOutput{(double)(stat_allocator_total_alloc_call
-            - stat_allocator_total_free_call), " KMG", 1000, ""} >> '\n';
+            - stat_allocator_total_free_call), " KMG", 1000, ""}
+            >> '(' >> stat_allocator_total_alloc_call >> '/' >>
+            stat_allocator_total_free_call >> ")\n";
     }
     int64 all_total = 0, gpu_total = 0, cpu_total = 0;
     for (auto& a : SFRLAllocator::sfrl_allocators) {
         auto total = a->used_memory + a->unused_memory;
         all_total += total;
         a->is_cuda() ? gpu_total += total : cpu_total += total;
         log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
             << "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
             >> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
@@ -92,6 +97,47 @@ void display_memory_info(const char* fileline) {
             >> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
             << "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
     }
+    log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"}
+        << "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"}
+        << "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n';
+
+    if (dump_var) {
+        vector<Node*> queue;
+        unordered_set<Node*> visited;
+        for (auto& vh : VarHolder::hold_vars)
+            if (!visited.count(vh->var)) {
+                queue.push_back(vh->var);
+                visited.insert(vh->var);
+            }
+        int64 cum = 0;
+        for (int i=0; i<queue.size(); i++) {
+            for (auto* n : queue[i]->inputs())
+                if (!visited.count(n)) {
+                    queue.push_back(n);
+                    visited.insert(n);
+                }
+            for (auto* n : queue[i]->outputs())
+                if (!visited.count(n)) {
+                    queue.push_back(n);
+                    visited.insert(n);
+                }
+            if (queue[i]->is_var()) {
+                auto v = (Var*)queue[i];
+                if (v->size>=0 && v->mem_ptr) {
+                    cum += v->size;
+                    log << FloatOutput{(double)v->size, " KMG", 1024, "B"}
+                        >> "(" >> std::setprecision(p) >> v->size*100.0 / all_total >> "%)"
+                        << FloatOutput{(double)cum, " KMG", 1024, "B"}
+                        >> "(" >> std::setprecision(p) >> cum*100.0 / all_total >> "%)"
+                        << v >> "\n";
+                    if (v->size == 100*64*112*112*4) {
+                        for (auto op : v->outputs())
+                            log << "\t" << op << '\n';
+                    }
+                }
+            }
+        }
+    }
     log >> "===========================\n";
     log.end();
 }
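The new dump_var branch walks the whole node graph reachable from the held vars, following both inputs and outputs, and for every var that currently owns memory logs its size, the running cumulative total, and both as a percentage of all allocated bytes (the hard-coded `100*64*112*112*4` check looks like a probe for one specific activation size). A small Python model of the traversal-and-accumulate part (the dict-based graph representation is invented for illustration, not Jittor's node structure):

def dump_vars(hold_vars, all_total):
    """BFS over inputs+outputs from held vars; report per-var and cumulative sizes."""
    queue, visited = [], set()
    for v in hold_vars:
        if id(v) not in visited:
            visited.add(id(v))
            queue.append(v)
    cum = 0
    rows = []
    i = 0
    while i < len(queue):
        node = queue[i]
        # Expand the frontier through both producers and consumers.
        for n in list(node.get("inputs", [])) + list(node.get("outputs", [])):
            if id(n) not in visited:
                visited.add(id(n))
                queue.append(n)
        size = node.get("size", 0)
        if node.get("is_var") and size > 0:
            cum += size
            rows.append((size, cum, 100.0 * size / all_total, 100.0 * cum / all_total))
        i += 1
    return rows

a = {"is_var": True, "size": 1024, "inputs": [], "outputs": []}
b = {"is_var": True, "size": 2048, "inputs": [a], "outputs": []}
print(dump_vars([b], all_total=4096))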
@@ -9,7 +9,7 @@
 namespace jittor {
 
 // @pyjt(display_memory_info)
-void display_memory_info(const char* fileline="");
+void display_memory_info(const char* fileline="", bool dump_var=false);
 
 // @pyjt(MemInfo)
 struct MemInfo {
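The // @pyjt(display_memory_info) annotation exports this function to the Python frontend. A usage sketch, assuming the binding is exposed as jt.display_memory_info with the default arguments declared above (whether the new dump_var flag is reachable from Python is an assumption here, not confirmed by this diff):

import jittor as jt

x = jt.random((64, 3, 224, 224))
y = (x * 2 + 1).sum()
y.sync()

# Logs hold_vars / lived_vars / lived_ops, the update-queue size and
# per-allocator usage, as implemented above.
jt.display_memory_info()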
@@ -101,7 +101,7 @@ void UpdateQueue::auto_flush() {
     vector<Var*> vars;
     vars.reserve(queue.size());
     for (auto& l : queue) {
-        while (l.size() && l.size() >= update_queue_auto_flush_depth) {
+        while (l.size() && l.size() >= update_queue_auto_flush_delay) {
             auto iter = l.end(); iter--;
             auto v = iter->v;
             vars.push_back(v);
@@ -128,7 +128,7 @@ void UpdateQueue::push(Var* v, Var* prev) {
         queue.emplace_front();
         owner = queue.begin();
     }
-    if (owner->size() >= update_queue_auto_flush_depth)
+    if (owner->size() >= update_queue_auto_flush_delay)
         auto_flush();
     owner->emplace_front(UpdateQueue::Item{owner, v});
     map[v] = owner->begin();
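The two renames make push() and auto_flush() consult the same threshold, update_queue_auto_flush_delay: each owner keeps a list of pending updates, and once a list reaches the threshold the oldest entries are flushed. A minimal Python model of that behaviour (pure illustration; the threshold value is arbitrary and the real queue stores iterators into intrusive lists, which is not reproduced here):

from collections import deque

AUTO_FLUSH_DELAY = 4  # stand-in for update_queue_auto_flush_delay

class UpdateQueueModel:
    def __init__(self):
        self.owners = {}  # var name -> deque of pending updates (newest first)

    def auto_flush(self):
        flushed = []
        for name, pending in self.owners.items():
            # Drain from the oldest end while the list is at/over the threshold,
            # mirroring `while (l.size() && l.size() >= update_queue_auto_flush_delay)`.
            while pending and len(pending) >= AUTO_FLUSH_DELAY:
                flushed.append((name, pending.pop()))
        return flushed

    def push(self, name, update):
        pending = self.owners.setdefault(name, deque())
        # Flush before inserting once the owner reaches the threshold,
        # as in the modified UpdateQueue::push above.
        if len(pending) >= AUTO_FLUSH_DELAY:
            self.auto_flush()
        pending.appendleft(update)

q = UpdateQueueModel()
for step in range(6):
    q.push("weight", f"update-{step}")
assert len(q.owners["weight"]) <= AUTO_FLUSH_DELAY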