This commit is contained in:
cxjyxx_me 2021-01-14 13:09:20 +08:00
parent 31c2be576e
commit 6c83e40883
10 changed files with 51 additions and 23 deletions

View File

@ -3,7 +3,7 @@
# Maintainers: # Maintainers:
# Dun Liang <randonlang@gmail.com>. # Dun Liang <randonlang@gmail.com>.
# Wenyang Zhou <576825820@qq.com> # Wenyang Zhou <576825820@qq.com>
# # Guoye Yang <498731903@qq.com>
# #
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.

View File

@ -1,9 +1,9 @@
# *************************************************************** # ***************************************************************
# Copyright (c) 2020 Jittor. Authors: # Copyright (c) 2021 Jittor. All Rights Reserved.
# Guowei Yang <471184555@qq.com> # Maintainers:
# Meng-Hao Guo <guomenghao1997@gmail.com> # Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>. # Dun Liang <randonlang@gmail.com>.
# All Rights Reserved. #
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
@ -53,7 +53,7 @@ class TestMemoryProfiler(unittest.TestCase):
jt.seed(seed) jt.seed(seed)
@unittest.skipIf(not jt.has_cuda, "Cuda not found") @unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1, use_stat_allocator=1, trace_py_var=2) @jt.flag_scope(use_cuda=1, use_stat_allocator=1, trace_py_var=3, profile_memory_enable=1)
def test_resnet(self): def test_resnet(self):
self.setup_seed(1) self.setup_seed(1)
loss_list=[] loss_list=[]
@ -85,13 +85,13 @@ class TestMemoryProfiler(unittest.TestCase):
_, out = jt.get_max_memory_treemap() _, out = jt.get_max_memory_treemap()
out_ = out.split('\n') out_ = out.split('\n')
assert(out_[0] == 'root()') assert(out_[0] == 'root()')
assert(out_[3] == ' ├─mnist_net(MnistNet)') assert(out_[3] == ' ├─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:194(_run_module_as_main)')
assert(out_[7] == ' | └─model(ResNet)') assert(out_[7] == ' | └─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:87(_run_code)')
_, out = jt.get_max_memory_treemap(build_by=1) _, out = jt.get_max_memory_treemap(build_by=1)
out_ = out.split('\n') out_ = out.split('\n')
assert(out_[0] == 'root()') assert(out_[0] == 'root()')
assert(out_[4] == ' ├─mnist_net(MnistNet)') assert(out_[4] == ' ├─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:194(_run_module_as_main)')
assert(out_[8] == ' | └─model(ResNet)') assert(out_[8] == ' | └─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:87(_run_code)')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,6 +1,9 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved. // Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>. // Maintainers:
// Dun Liang <randonlang@gmail.com>.
// Guoye Yang <498731903@qq.com>
//
// This file is subject to the terms and conditions defined in // This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
@ -29,6 +32,7 @@ namespace jittor {
Executor exe; Executor exe;
extern MemoryProfiler memory_profiler; extern MemoryProfiler memory_profiler;
DECLARE_FLAG(int, profile_memory_enable);
// from fetch_op.cc // from fetch_op.cc
extern list<VarPtr> fetcher_to_free; extern list<VarPtr> fetcher_to_free;
@ -424,6 +428,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
for (auto* var : op->outputs()) { for (auto* var : op->outputs()) {
var->alloc(allocator); var->alloc(allocator);
} }
if (PREDICT_BRANCH_NOT_TAKEN(profile_memory_enable))
memory_profiler.check(); memory_profiler.check();
LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs(); LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
op->do_prepare(jkl); op->do_prepare(jkl);

View File

@ -1,6 +1,9 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved. // Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>. // Maintainers:
// Dun Liang <randonlang@gmail.com>.
// Guoye Yang <498731903@qq.com>
//
// This file is subject to the terms and conditions defined in // This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************

View File

@ -1,5 +1,5 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved. // Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: // Maintainers:
// Guoye Yang <498731903@qq.com> // Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>. // Dun Liang <randonlang@gmail.com>.

View File

@ -1,5 +1,5 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved. // Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: // Maintainers:
// Guoye Yang <498731903@qq.com> // Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>. // Dun Liang <randonlang@gmail.com>.

View File

@ -1,3 +1,12 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "memory_profiler.h" #include "memory_profiler.h"
#include "graph.h" #include "graph.h"
#include "var_holder.h" #include "var_holder.h"
@ -19,7 +28,7 @@ struct FloatOutput_ {
string suffix; string suffix;
int p=4; int p=4;
}; };
std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) { inline std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) {
int w = 8; int w = 8;
os << std::setw(w-2-o.suffix.size()); os << std::setw(w-2-o.suffix.size());
os << std::setprecision(o.p); os << std::setprecision(o.p);
@ -47,6 +56,7 @@ void MemoryProfiler::clear() {
} }
std::pair<size_t, size_t> MemoryProfiler::get_memory_info() { std::pair<size_t, size_t> MemoryProfiler::get_memory_info() {
ASSERT(profile_memory_enable == 1);
size_t used = 0; size_t used = 0;
size_t unused = 0; size_t unused = 0;
//TODO add mssfrl allocator //TODO add mssfrl allocator
@ -58,6 +68,7 @@ std::pair<size_t, size_t> MemoryProfiler::get_memory_info() {
} }
void MemoryProfiler::check() { void MemoryProfiler::check() {
ASSERT(profile_memory_enable == 1);
std::pair<size_t, size_t> mem_info = get_memory_info(); std::pair<size_t, size_t> mem_info = get_memory_info();
if (mem_info.first > max_used_memory_size) { if (mem_info.first > max_used_memory_size) {
max_used_memory_size = mem_info.first; max_used_memory_size = mem_info.first;
@ -99,6 +110,7 @@ bool MemoryProfiler::cmp(const std::pair<std::pair<string, vector<Stack>>, size_
} }
void MemoryProfiler::display_max_memory_info() { void MemoryProfiler::display_max_memory_info() {
ASSERT(profile_memory_enable == 1);
Log log("", 'i', 0); Log log("", 'i', 0);
std::sort(max_live_vars.begin(), max_live_vars.end(), cmp); std::sort(max_live_vars.begin(), max_live_vars.end(), cmp);
log << "\n=====display_max_memory_info=====\n"; log << "\n=====display_max_memory_info=====\n";
@ -117,10 +129,12 @@ void MemoryProfiler::display_max_memory_info() {
} }
void display_max_memory_info() { void display_max_memory_info() {
ASSERT(profile_memory_enable == 1);
memory_profiler.display_max_memory_info(); memory_profiler.display_max_memory_info();
} }
string MemoryProfiler::get_max_memory_info() { string MemoryProfiler::get_max_memory_info() {
ASSERT(profile_memory_enable == 1);
std::stringstream out; std::stringstream out;
string div1 = "[!@#div1!@#]"; string div1 = "[!@#div1!@#]";
string div2 = "[!@#div2!@#]"; string div2 = "[!@#div2!@#]";
@ -142,6 +156,7 @@ string MemoryProfiler::get_max_memory_info() {
} }
string get_max_memory_info() { string get_max_memory_info() {
ASSERT(profile_memory_enable == 1);
return memory_profiler.get_max_memory_info(); return memory_profiler.get_max_memory_info();
} }

View File

@ -1,6 +1,9 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved. // Copyright (c) 2021 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>. // Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in // This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************

View File

@ -48,7 +48,6 @@ static void setitem_inplace(SetitemOp* op) {
} }
auto output = op->outputs().front(); auto output = op->outputs().front();
output->share_with(input); output->share_with(input);
return;
auto data = op->input(1); auto data = op->input(1);
// if setitem requires type conversion, don't inplace // if setitem requires type conversion, don't inplace

View File

@ -1,6 +1,9 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved. // Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>. // Maintainers:
// Dun Liang <randonlang@gmail.com>.
// Guoye Yang <498731903@qq.com>
//
// This file is subject to the terms and conditions defined in // This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
@ -186,7 +189,7 @@ void TraceData::record_node(Node* node, bool record_stack) {
NodeData data; NodeData data;
data.id = node_data_cnt++; data.id = node_data_cnt++;
id_map[node] = data.id; id_map[node] = data.id;
if (trace_py_var) { if (!node->is_var() || trace_py_var>=3) {
if (record_stack) { if (record_stack) {
if (trace_grad_op) { if (trace_grad_op) {
auto iter = trace_data.id_map.find(trace_grad_op); auto iter = trace_data.id_map.find(trace_grad_op);