mirror of https://github.com/Jittor/Jittor
memory_profiler
parent 86b7aaaa4e
commit b170ba73ca
@@ -76,7 +76,7 @@ def repeat(x, *shape):
        x = x.broadcast(x_shape)
    elif len_x_shape > len_shape:
        rep_shape = (len_x_shape - len_shape) * [1] + shape

    #TODO if input.shape[i]=1, no add [1]
    reshape_shape = []
    broadcast_shape = []
    for x_s,r_s in zip(x_shape,rep_shape):
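An aside on what this hunk is setting up: torch-style repeat can be implemented as a reshape, then a broadcast, then a final reshape. The NumPy sketch below illustrates that trick; repeat_like and the appended [1, x_s] / [r_s, x_s] pairs are illustrative assumptions, since the loop body itself is cut off by the hunk.

import numpy as np

def repeat_like(x, rep_shape):
    # Sketch only: insert a tiling axis before every dim, broadcast it to the
    # repeat count, then flatten each (tile, dim) pair back into one axis.
    x_shape = list(x.shape)
    reshape_shape, broadcast_shape = [], []
    for x_s, r_s in zip(x_shape, rep_shape):
        reshape_shape += [1, x_s]
        broadcast_shape += [r_s, x_s]
    y = np.broadcast_to(x.reshape(reshape_shape), broadcast_shape)
    return y.reshape([x_s * r_s for x_s, r_s in zip(x_shape, rep_shape)])

x = np.arange(6).reshape(2, 3)
assert (repeat_like(x, (2, 2)) == np.tile(x, (2, 2))).all()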
@@ -0,0 +1,87 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Guowei Yang <471184555@qq.com>
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import nn, Module
from jittor.models import resnet
import numpy as np
import sys, os
import random
import math
import unittest
from jittor.test.test_reorder_tuner import simple_parser
from jittor.test.test_log import find_log_with_re
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import time

skip_this_test = False

class MnistNet(Module):
    def __init__(self):
        self.model = resnet.Resnet18()
        self.layer = nn.Linear(1000,10)
    def execute(self, x):
        x = self.model(x)
        x = self.layer(x)
        return x

@unittest.skipIf(skip_this_test, "skip_this_test")
class TestMemoryProfiler(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        # hyper-parameters
        self.batch_size = 100
        self.weight_decay = 0.0001
        self.momentum = 0.9
        self.learning_rate = 0.1
        # mnist dataset
        self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
            .set_attrs(batch_size=self.batch_size, shuffle=True)
        self.train_loader.num_workers = 4

    # setup random seed
    def setup_seed(self, seed):
        np.random.seed(seed)
        random.seed(seed)
        jt.seed(seed)

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1, use_stat_allocator=1, trace_py_var=2)
    def test_resnet(self):
        self.setup_seed(1)
        loss_list=[]
        acc_list=[]
        mnist_net = MnistNet()
        global prev
        prev = time.time()
        SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)

        iters = 50
        for batch_idx, (data, target) in enumerate(self.train_loader):
            if (batch_idx > iters):
                break
            jt.display_memory_info()
            output = mnist_net(data)
            loss = nn.cross_entropy_loss(output, target)
            SGD.step(loss)
            def callback(batch_idx, loss, output, target):
                global prev
                pred = np.argmax(output, axis=1)
                acc = np.mean(target==pred)
                loss_list.append(loss[0])
                acc_list.append(acc)
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
                    .format(0, batch_idx, iters,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
            jt.fetch(batch_idx, loss, output, target, callback)
        jt.sync_all(True)
        jt.display_max_memory_info()

if __name__ == "__main__":
    unittest.main()
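For reference, the same entry points can be driven outside the unittest. A minimal sketch, assuming only the calls exercised above (jt.display_memory_info, jt.sync_all, and the new jt.display_max_memory_info) and substituting a toy computation for the ResNet:

import jittor as jt

# Toy stand-in for the training loop; shapes and values are arbitrary.
with jt.flag_scope(use_stat_allocator=1):
    for _ in range(3):
        x = jt.random((64, 1000))
        y = (x * 2 + 1).sum()
        y.sync()
        jt.display_memory_info()       # existing per-step snapshot
    jt.sync_all(True)
    jt.display_max_memory_info()       # peak snapshot added by this commit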
@@ -21,10 +21,12 @@
#include "fuser.h"
#include "profiler/profiler_guard.h"
#include "parallel_compiler.h"
#include "memory_profiler.h"

namespace jittor {

Executor exe;
extern MemoryProfiler memory_profiler;

// from fetch_op.cc
extern list<VarPtr> fetcher_to_free;
@@ -415,6 +417,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
        for (auto* var : op->outputs()) {
            var->alloc(allocator);
        }
        memory_profiler.check();
        LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
        op->do_prepare(jkl);
        bool is_cuda = op->flags.get(NodeFlags::_cuda);
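The new check() call sits right after every output Var of an op has been allocated, so each allocation high-water mark is sampled before anything is freed. A toy Python sketch of that sampling pattern (assumed names, not Jittor API):

class PeakSampler:
    def __init__(self):
        self.max_used = 0
    def check(self, used_bytes):
        self.max_used = max(self.max_used, used_bytes)

sampler, used = PeakSampler(), 0
for out_bytes, freed_bytes in [(400, 0), (800, 300), (200, 900)]:
    used += out_bytes       # var->alloc(allocator) for each output
    sampler.check(used)     # memory_profiler.check()
    used -= freed_bytes     # memory released after the op has run
print(sampler.max_used)     # prints 1200: the transient peak, while only 200 remain at the end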
@@ -0,0 +1,117 @@
#include "memory_profiler.h"
#include "graph.h"
#include "var_holder.h"
#include "var.h"
#include "mem/allocator/sfrl_allocator.h"
#include <iomanip>
#include <algorithm>
#include <sys/sysinfo.h>
#include <sstream>

namespace jittor {

//TODO reuse from mem_info.cc
struct FloatOutput_ {
    double value;
    string scale;
    int base;
    string suffix;
    int p=4;
};
std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) {
    int w = 8;
    os << std::setw(w-2-o.suffix.size());
    os << std::setprecision(o.p);
    uint i=0;
    double k = o.value;
    for (; i+1<o.scale.size(); i++) {
        if (k<o.base) break;
        k /= o.base;
    }
    os << k << o.scale[i];
    return os << o.suffix;
}

MemoryProfiler memory_profiler;
DEFINE_FLAG(int, profile_memory_enable, 0, "Enable memory profiler.");

MemoryProfiler::MemoryProfiler() {
    clear();
}

void MemoryProfiler::clear() {
    allocations.clear();
    max_memory_size = 0;
    max_used_memory_size = 0;
}

std::pair<size_t, size_t> MemoryProfiler::get_memory_info() {
    size_t used = 0;
    size_t unused = 0;
    //TODO add mssfrl allocator
    for (auto& a : SFRLAllocator::sfrl_allocators) {
        used += a->used_memory;
        unused += a->unused_memory;
    }
    return std::make_pair(used, unused);
}

void MemoryProfiler::check() {
    std::pair<size_t, size_t> mem_info = get_memory_info();
    if (mem_info.first > max_used_memory_size) {
        max_used_memory_size = mem_info.first;

        allocations.clear();
        size_t memory_size = 0;
        vector<std::pair<string, size_t>> live_vars;
        vector<Node*> queue;

        auto t = ++Node::tflag_count;
        for (auto& vh : VarHolder::hold_vars)
            if (vh->var->tflag != t) {
                vh->var->tflag = t;
                queue.push_back(vh->var);
            }
        bfs_both(queue, [](Node*){return true;});
        for (Node* node : queue) {
            if (node->is_var()) {
                Var* var = (Var*)node;
                if (var->mem_ptr != nullptr) {
                    std::stringstream stream;
                    stream << var;
                    live_vars.push_back(std::make_pair(stream.str(), var->size));
                    if (!allocations.count(var->mem_ptr)) {
                        allocations[var->mem_ptr] = 1;
                        memory_size += var->size;
                    }
                }
            }
        }
        max_live_vars = live_vars;
        max_memory_size = memory_size;
    }
}

bool MemoryProfiler::cmp(const std::pair<string, size_t>& a, const std::pair<string, size_t>& b) {
    return a.second > b.second;
}

void MemoryProfiler::display_max_memory_info() {
    Log log("", 'i', 0);
    std::sort(max_live_vars.begin(), max_live_vars.end(), cmp);
    log << "\n=====display_max_memory_info=====\n";
    log << "max used memory" << FloatOutput_{(double)max_used_memory_size, " KMG", 1024, "B"} << "\n";
    log << "max var memory" << FloatOutput_{(double)max_memory_size, " KMG", 1024, "B"} << "\n\n";
    log << "[Size]" << "[Percent]" << "[Var Info]" << "\n";
    for (int i = 0; i < max_live_vars.size(); ++i) {
        log << FloatOutput_{(double)max_live_vars[i].second, " KMG", 1024, "B"} << double(max_live_vars[i].second) / max_memory_size * 100 << "%" << max_live_vars[i].first << "\n\n";
    }
    log << "=========================\n";
    log.end();
}

void display_max_memory_info() {
    memory_profiler.display_max_memory_info();
}

} // jittor
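In short, check() snapshots which live Vars hold memory whenever used memory reaches a new maximum, and display_max_memory_info() sorts that snapshot by size and prints it with per-var percentages. A conceptual Python distillation (a sketch, not the bound API; the C++ additionally de-duplicates vars by mem_ptr):

class PeakTracker:
    def __init__(self):
        self.max_used = 0       # max_used_memory_size
        self.max_vars = []      # max_live_vars: [(description, size), ...]

    def check(self, used_bytes, live_vars):
        # remember the snapshot taken at the highest used-memory reading so far
        if used_bytes > self.max_used:
            self.max_used = used_bytes
            self.max_vars = list(live_vars)

    def display(self):
        total = sum(size for _, size in self.max_vars) or 1
        for name, size in sorted(self.max_vars, key=lambda kv: kv[1], reverse=True):
            print(f"{size}B  {size / total * 100:.1f}%  {name}")

t = PeakTracker()
t.check(1200, [("conv1_output", 800), ("fc_weight", 400)])
t.display()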
@@ -0,0 +1,39 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
#include "mem/allocator.h"
#include <map>
#include <vector>
#include <string>
#include "var.h"
namespace jittor {

// @pyjt(display_max_memory_info)
void display_max_memory_info();

struct MemoryProfiler {
    std::map<void*, size_t> allocations;
    // Max Infos
    std::vector<std::pair<string, size_t>> max_live_vars;
    size_t max_used_memory_size;
    size_t max_memory_size;


    MemoryProfiler();
    static bool cmp(const std::pair<string, size_t>& a, const std::pair<string, size_t>& b);
    void clear();
    void check();
    std::pair<size_t, size_t> get_memory_info();
    void display_max_memory_info();
};

extern MemoryProfiler memory_profiler;

DECLARE_FLAG(int, profile_memory_enable);

} // jittor
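The // @pyjt(display_max_memory_info) annotation is what generates the Python binding called in the test above, and DECLARE_FLAG/DEFINE_FLAG register profile_memory_enable as a runtime flag. A hedged sketch of the Python-side view; Jittor normally surfaces such flags as attributes of jt.flags, though that attribute name for this new flag is an assumption:

import jittor as jt

# Assumption: DEFINE_FLAG-registered flags appear as attributes of jt.flags.
jt.flags.profile_memory_enable = 1
jt.display_max_memory_info()   # resolves to the @pyjt-bound C++ function above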