Polish the memory leak problem: add leak tracing to display_memory_info and an argmax memory-leak regression test

This commit is contained in:
Dun Liang 2022-09-16 00:58:35 +08:00
parent d2b5c281b2
commit c0ed98cbd6
6 changed files with 61 additions and 28 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.5.10'
__version__ = '1.3.5.11'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -1329,6 +1329,9 @@ if use_data_gz:
f.write(md5)
files.append(data_o_path)
files = [f for f in files if "__data__" not in f]
else:
files = [f for f in files
if "__data__" not in f or "src" in f.split("__data__")[1]]
cc_flags += f" -l\"jit_utils_core{lib_suffix}\" "
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)

View File

@ -67,26 +67,39 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
<< "lived_vars:" << Var::number_of_lived_vars
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
#ifdef NODE_MEMCHECK
// get the oldest var
// vector<Node*> queue;
// auto t = ++Node::tflag_count;
// for (auto& vh : hold_vars)
// if (vh->var->tflag != t) {
// vh->var->tflag = t;
// queue.push_back(vh->var);
// }
// bfs_both(queue, [](Node*){return true;});
// vector<pair<int64, Node*>> nodes;
// nodes.reserve(queue.size());
// for (auto* node : queue)
// nodes.push_back({node->__id(), node});
// std::sort(nodes.begin(), nodes.end());
// log << "list of the oldest nodes:\n";
// for (int i=0; i<10 && i<nodes.size(); i++) {
// log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
// }
#endif
// Extra leak-tracing pass; only runs when trace_py_var is set.
if (trace_py_var) {
// Seed the traversal with the var of every entry in hold_vars. A fresh
// tflag value is used as a "visited" mark so each var is queued only once.
vector<Node*> queue;
auto t = ++Node::tflag_count;
for (auto& vh : hold_vars)
if (vh->var->tflag != t) {
vh->var->tflag = t;
queue.push_back(vh->var);
}
// Extend the queue with bfs_both; the callback accepts every node.
// NOTE(review): presumably this follows both input and output edges --
// confirm against the definition of bfs_both.
bfs_both(queue, [](Node*){return true;});
// cnt is function-local static, so it persists across calls:
// node id -> number of calls in which that id was found in the queue.
// Ids that are not found again are dropped further below.
static unordered_map<int64, int> cnt;
// Snapshot of the counters as they were before this call.
auto cnt_bk = cnt;
// stat: counter value -> how many queued nodes reached that value in this call.
map<int,int> stat;
for (auto* node : queue) {
auto &x = cnt[node->id];
x++;
// Log a var exactly once, when its counter reaches 3.
if (x == 3 && node->is_var()) {
LOGe << node;
}
stat[x]++;
}
// A counter that did not change means the id was not in the queue this
// time; erase it so cnt only keeps ids seen in the current traversal.
for (auto x : cnt_bk) {
if (x.second == cnt[x.first]) {
cnt.erase(x.first);
}
}
LOGe << stat << lived_nodes_id.size();
// Log every lived_nodes_id entry whose key was not reached by the traversal above.
// NOTE(review): assumes the key of lived_nodes_id is the same id as node->id -- confirm.
for (auto nid : lived_nodes_id) {
if (!cnt.count(nid.first)) {
LOGe << nid;
}
}
}
if (use_stat_allocator) {
log << "stat:" << use_stat_allocator;

View File

@ -82,4 +82,8 @@ class TestAdamw(unittest.TestCase):
lt = float(loss_torch.detach().numpy())
lj = float(loss_jittor.data)
# print(abs(lt - lj))
assert abs(lt - lj) < 1e-5
assert abs(lt - lj) < 1e-5
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
unittest.main()

View File

@ -191,13 +191,26 @@ print(matmul_transpose(a, b))
x = m.state_dict()
m.load_state_dict(x)
# def test_res2net(self):
# import jittor.models
# net = jittor.models.res2net50(True)
# img = jt.random((2,3,224,224))
# out = net(img)
# print(out.shape, out.sum())
# assert out.shape == [2,1000]
def test_res2net(self):
    # Forward a random batch through res2net50 and check the output shape.
    import jittor.models
    model = jittor.models.res2net50(True)
    batch = jt.random((2,3,224,224))
    logits = model(batch)
    print(logits.shape, logits.sum())
    # Memory info is dumped twice in a row, exactly as in the original test.
    jt.display_memory_info()
    jt.display_memory_info()
    assert logits.shape == [2,1000]
def test_argmax_memleak(self):
    # Regression check: after using only the second output of argmax,
    # taking a gradient and deleting every variable, no op may stay alive.
    src = jt.random([10])
    _first, second = jt.argmax(src, 0)
    # Drop the first output before anything is synced.
    del _first
    second.sync()
    grad = jt.grad(second*10, src)
    grad.sync()
    del src, grad, second
    jt.display_memory_info()
    assert jt.liveness_info()["lived_ops"] == 0
if __name__ == "__main__":

Binary file not shown.