This commit is contained in:
cxjyxx_me 2021-01-14 11:07:51 +08:00
parent a96ccab4bb
commit 8cd3dcb234
2 changed files with 74 additions and 34 deletions

View File

@ -722,7 +722,7 @@ def triu_(x,diagonal=0):
jt.Var.triu_ = triu_
def print_tree(now, max_memory_size, prefix1='', prefix2=''):
def print_tree(now, max_memory_size, prefix1, prefix2, build_by):
def format_size(s):
if (s < 1024):
s = str(s)
@ -739,15 +739,19 @@ def print_tree(now, max_memory_size, prefix1='', prefix2=''):
s = format(s/1024/1024/1024, '.2f')
return s + ' GB'
out = ''
tab = ' '
print(prefix1+now['name']+'('+now['type']+')')
print(prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%]')
for p in now['path']:
print(prefix2+p)
if (len(now['children']) > 0):
print(prefix2 + tab + '| ')
out += prefix1+now['name']+'('+now['type']+')\n'
out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%]\n'
if (build_by == 0):
for p in now['path']:
out += prefix2+p+'\n'
else:
print(prefix2)
out += prefix2+now['path'] + '\n'
if (len(now['children']) > 0):
out += prefix2 + tab + '| ' + '\n'
else:
out += prefix2 + '\n'
for i in range(len(now['children'])):
c = now['children'][i]
if i < len(now['children']) - 1:
@ -756,9 +760,10 @@ def print_tree(now, max_memory_size, prefix1='', prefix2=''):
else:
prefix1_ = prefix2 + tab + '└─'
prefix2_ = prefix2 + tab + ' '
print_tree(c, max_memory_size, prefix1_, prefix2_)
out += print_tree(c, max_memory_size, prefix1_, prefix2_, build_by)
return out
def get_max_memory_treemap():
def get_max_memory_treemap(build_by=0, do_print=True):
div1 = "[!@#div1!@#]"
div2 = "[!@#div2!@#]"
div3 = "[!@#div3!@#]"
@ -777,28 +782,52 @@ def get_max_memory_treemap():
s = {'path':s__[0], 'name':s__[1], 'type':s__[2]}
var['stack'].append(s)
vars.append(var)
tree = {'name':'root', "children":[], 'size':0, 'path':[], 'type':''}
if (build_by == 0): # build tree by name
tree = {'name':'root', "children":[], 'size':0, 'path':[], 'type':''}
def find_child(now, key):
for c in now['children']:
if (c['name'] == key):
return c
return None
for v in vars:
now = tree
now['size'] += v['size']
for s in v['stack']:
ch = find_child(now, s['name'])
if (ch is not None):
if (not s['path'] in ch['path']):
ch['path'].append(s['path'])
assert(ch['type']==s['type'])
now = ch
now['size'] += v['size']
else:
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':[s['path']], 'type':s['type']}
now['children'].append(now_)
now = now_
def find_child(now, key):
for c in now['children']:
if (c['name'] == key):
return c
return None
for v in vars:
now = tree
now['size'] += v['size']
for s in v['stack']:
ch = find_child(now, s['name'])
if (ch is not None):
if (not s['path'] in ch['path']):
ch['path'].append(s['path'])
assert(ch['type']==s['type'])
now = ch
now['size'] += v['size']
else:
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':[s['path']], 'type':s['type']}
now['children'].append(now_)
now = now_
elif (build_by == 1): # build tree by path
tree = {'name':'root', "children":[], 'size':0, 'path':'_root_', 'type':''}
def find_child(now, key):
for c in now['children']:
if (c['path'] == key):
return c
return None
for v in vars:
now = tree
now['size'] += v['size']
for s in v['stack']:
ch = find_child(now, s['path'])
if (ch is not None):
now = ch
now['size'] += v['size']
else:
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':s['path'], 'type':s['type']}
now['children'].append(now_)
now = now_
else:
assert(False)
def sort_tree(now):
def takeSize(elem):
return elem['size']
@ -806,5 +835,7 @@ def get_max_memory_treemap():
for c in now['children']:
sort_tree(c)
sort_tree(tree)
print_tree(tree, max_memory_size, '', '')
return tree
out = print_tree(tree, max_memory_size, '', '', build_by)
if (do_print):
print(out)
return tree, out

View File

@ -82,7 +82,16 @@ class TestMemoryProfiler(unittest.TestCase):
jt.fetch(batch_idx, loss, output, target, callback)
jt.sync_all(True)
jt.display_max_memory_info()
jt.get_max_memory_treemap()
_, out = jt.get_max_memory_treemap()
out_ = out.split('\n')
assert(out_[0] == 'root()')
assert(out_[3] == ' ├─mnist_net(MnistNet)')
assert(out_[7] == ' | └─model(ResNet)')
_, out = jt.get_max_memory_treemap(build_by=1)
out_ = out.split('\n')
assert(out_[0] == 'root()')
assert(out_[4] == ' ├─mnist_net(MnistNet)')
assert(out_[8] == ' | └─model(ResNet)')
if __name__ == "__main__":
unittest.main()