add grad memory check

This commit is contained in:
Dun Liang 2022-11-30 13:12:10 +08:00
parent e014f4f25c
commit be8faf4dfc
3 changed files with 37 additions and 5 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.3.5.39' __version__ = '1.3.5.40'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -55,6 +55,8 @@ std::ostream& operator<<(std::ostream& os, const FloatOutput& o) {
return os << o.suffix; return os << o.suffix;
} }
// Pointer to a map from an int64 id to a VarPtr; initialized to nullptr and
// not assigned anywhere in the code visible here.
// display_memory_info() only touches it when it is non-null: it logs the map
// size under the label "jtorch_grad_vars", then counts entries whose key is
// absent from lived_nodes_id and logs the first such key as "parent id:".
// NOTE(review): presumably holds gradient vars backed up by jtorch, keyed by
// the id of the var each gradient belongs to — confirm against the writer.
unordered_map<int64, VarPtr>* _grad_backup_ptr = nullptr;
void display_memory_info(const char* fileline, bool dump_var, bool red_color) { void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
int p = 3; int p = 3;
Log log(fileline, red_color?'e':'i', 0); Log log(fileline, red_color?'e':'i', 0);
@ -66,6 +68,8 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
log << "hold_vars:" << hold_vars.size() log << "hold_vars:" << hold_vars.size()
<< "lived_vars:" << Var::number_of_lived_vars << "lived_vars:" << Var::number_of_lived_vars
<< "lived_ops:" << Op::number_of_lived_ops >> '\n'; << "lived_ops:" << Op::number_of_lived_ops >> '\n';
if (_grad_backup_ptr)
log << "jtorch_grad_vars:" << _grad_backup_ptr->size() >> '\n';
// get the oldest var // get the oldest var
if (trace_py_var) { if (trace_py_var) {
@ -93,10 +97,38 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
cnt.erase(x.first); cnt.erase(x.first);
} }
} }
LOGe << stat << lived_nodes_id.size(); LOGe << "appear time -> node cnt:" << stat;
for (auto nid : lived_nodes_id) { if (lived_nodes_id.size()) {
if (!cnt.count(nid.first)) { LOGe << "lived_nodes cnt:" << lived_nodes_id.size();
LOGe << nid; Node* not_found=nullptr;
int not_found_cnt = 0;
for (auto nid : lived_nodes_id) {
if (!cnt.count(nid.first)) {
not_found_cnt ++;
if (!not_found) not_found = nid.second;
}
}
LOGe << "Total not_found:" << not_found_cnt;
if (not_found)
LOGe << "not found node:" << not_found;
if (_grad_backup_ptr) {
Node* not_found_grad=nullptr;
int parent_id = 0;
int not_found_grad_cnt = 0;
for (auto& gid : *_grad_backup_ptr) {
if (!lived_nodes_id.count(gid.first)) {
not_found_grad_cnt ++;
if (!not_found_grad) {
not_found_grad = gid.second.ptr;
parent_id = gid.first;
}
}
}
LOGe << "Grad not found cnt:" << not_found_grad_cnt;
if (not_found_grad) {
LOGe << "grad not found node" << not_found_grad;
LOGe << "parent id:" << parent_id;
}
} }
} }
} }

Binary file not shown.