From 86b7aaaa4e21aa6d53678d358dfd2080a642ef0a Mon Sep 17 00:00:00 2001
From: cxjyxx_me <498731903@qq.com>
Date: Tue, 8 Dec 2020 11:13:07 +0800
Subject: [PATCH 1/8] cos&step lr_scheduler

---
 python/jittor/lr_scheduler.py  | 56 +++++++++++++++++++++++++++++++++-
 python/jittor/models/resnet.py | 12 ++++++--
 src/opt/pass_manager.cc        |  2 +-
 3 files changed, 66 insertions(+), 4 deletions(-)

diff --git a/python/jittor/lr_scheduler.py b/python/jittor/lr_scheduler.py
index cc507e69..8c9eaf0b 100644
--- a/python/jittor/lr_scheduler.py
+++ b/python/jittor/lr_scheduler.py
@@ -83,4 +83,58 @@ class ReduceLROnPlateau(object):
             save = self.threshold + 1.0
             return a > b * save
         else:
-            return a > b + self.threshold
\ No newline at end of file
+            return a > b + self.threshold
+
+class CosineAnnealingLR(object):
+    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+        self.T_max = T_max
+        self.eta_min = eta_min
+        self.optimizer = optimizer
+        self.last_epoch = last_epoch
+        self.base_lr = optimizer.lr
+        #TODO: setting last_epoch is not supported yet
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            return self.base_lr
+        now_lr = self.optimizer.lr
+        if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
+            return (now_lr + (self.base_lr - self.eta_min) *
+                    (1 - math.cos(math.pi / self.T_max)) / 2)
+        return ((1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
+                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
+                (now_lr - self.eta_min) + self.eta_min)
+
+    def step(self):
+        self.last_epoch += 1
+        self.update_lr(self.get_lr())
+
+    def update_lr(self, new_lr):
+        self.optimizer.lr = new_lr
+        for param_group in self.optimizer.param_groups:
+            if param_group.get("lr") is not None:
+                param_group["lr"] = new_lr
+
+class MultiStepLR(object):
+    def __init__(self, optimizer, milestones=[], gamma=0.1, last_epoch=-1):
+        self.optimizer = optimizer
+        self.milestones = milestones
+        self.gamma = gamma
+        self.last_epoch = last_epoch
+        #TODO: setting last_epoch is not supported yet
+
+    def get_lr(self):
+        now_lr = self.optimizer.lr
+        if (self.last_epoch in self.milestones):
+            now_lr *= self.gamma
+        return now_lr
+
+    def step(self):
+        self.last_epoch += 1
+        self.update_lr(self.get_lr())
+
+    def update_lr(self, new_lr):
+        self.optimizer.lr = new_lr
+        for param_group in self.optimizer.param_groups:
+            if param_group.get("lr") is not None:
+                param_group["lr"] = new_lr
\ No newline at end of file
diff --git a/python/jittor/models/resnet.py b/python/jittor/models/resnet.py
index b9a93d16..9e3266cd 100644
--- a/python/jittor/models/resnet.py
+++ b/python/jittor/models/resnet.py
@@ -11,8 +11,8 @@
 import jittor as jt
 from jittor import nn
 
-__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
-        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
+__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet26', 'Resnet38', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
+        'resnet18', 'resnet34', 'resnet26', 'resnet38', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
 
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
@@ -166,6 +166,14 @@ def Resnet50(**kwargs):
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs) resnet50 = Resnet50 +def Resnet38(**kwargs): + return _resnet(Bottleneck, [2, 3, 5, 2], **kwargs) +resnet38 = Resnet38 + +def Resnet26(**kwargs): + return _resnet(Bottleneck, [1, 2, 4, 1], **kwargs) +resnet26 = Resnet26 + def Resnet101(**kwargs): """ ResNet-101 model architecture. diff --git a/src/opt/pass_manager.cc b/src/opt/pass_manager.cc index 788ac5ba..8be1e048 100644 --- a/src/opt/pass_manager.cc +++ b/src/opt/pass_manager.cc @@ -91,7 +91,7 @@ void PassManager::run_passes() { run_pass(); run_pass(); - run_pass(); + // run_pass(); run_pass(); From b170ba73ca365e321ad74b0cd32a10d537a2b478 Mon Sep 17 00:00:00 2001 From: cxjyxx_me <498731903@qq.com> Date: Mon, 28 Dec 2020 16:13:35 +0800 Subject: [PATCH 2/8] memory_profiler --- python/jittor/misc.py | 2 +- python/jittor/test/test_memory_profiler.py | 87 +++++++++++++++ src/executor.cc | 3 + src/memory_profiler.cc | 117 +++++++++++++++++++++ src/memory_profiler.h | 39 +++++++ 5 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 python/jittor/test/test_memory_profiler.py create mode 100644 src/memory_profiler.cc create mode 100644 src/memory_profiler.h diff --git a/python/jittor/misc.py b/python/jittor/misc.py index f026b9da..fd227752 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -76,7 +76,7 @@ def repeat(x, *shape): x = x.broadcast(x_shape) elif len_x_shape > len_shape: rep_shape = (len_x_shape - len_shape) * [1] + shape - + #TODO if input.shape[i]=1, no add [1] reshape_shape = [] broadcast_shape = [] for x_s,r_s in zip(x_shape,rep_shape): diff --git a/python/jittor/test/test_memory_profiler.py b/python/jittor/test/test_memory_profiler.py new file mode 100644 index 00000000..f19d01bb --- /dev/null +++ b/python/jittor/test/test_memory_profiler.py @@ -0,0 +1,87 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Guowei Yang <471184555@qq.com> +# Meng-Hao Guo +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. 
+# ***************************************************************
+import jittor as jt
+from jittor import nn, Module
+from jittor.models import resnet
+import numpy as np
+import sys, os
+import random
+import math
+import unittest
+from jittor.test.test_reorder_tuner import simple_parser
+from jittor.test.test_log import find_log_with_re
+from jittor.dataset.mnist import MNIST
+import jittor.transform as trans
+import time
+
+skip_this_test = False
+
+class MnistNet(Module):
+    def __init__(self):
+        self.model = resnet.Resnet18()
+        self.layer = nn.Linear(1000,10)
+    def execute(self, x):
+        x = self.model(x)
+        x = self.layer(x)
+        return x
+
+@unittest.skipIf(skip_this_test, "skip_this_test")
+class TestMemoryProfiler(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        # hyper-parameters
+        self.batch_size = 100
+        self.weight_decay = 0.0001
+        self.momentum = 0.9
+        self.learning_rate = 0.1
+        # mnist dataset
+        self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
+            .set_attrs(batch_size=self.batch_size, shuffle=True)
+        self.train_loader.num_workers = 4
+
+    # setup random seed
+    def setup_seed(self, seed):
+        np.random.seed(seed)
+        random.seed(seed)
+        jt.seed(seed)
+
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1, use_stat_allocator=1, trace_py_var=2)
+    def test_resnet(self):
+        self.setup_seed(1)
+        loss_list=[]
+        acc_list=[]
+        mnist_net = MnistNet()
+        global prev
+        prev = time.time()
+        SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
+
+        iters = 50
+        for batch_idx, (data, target) in enumerate(self.train_loader):
+            if (batch_idx > iters):
+                break
+            jt.display_memory_info()
+            output = mnist_net(data)
+            loss = nn.cross_entropy_loss(output, target)
+            SGD.step(loss)
+            def callback(batch_idx, loss, output, target):
+                global prev
+                pred = np.argmax(output, axis=1)
+                acc = np.mean(target==pred)
+                loss_list.append(loss[0])
+                acc_list.append(acc)
+                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
+                    .format(0, batch_idx, iters, 100. * batch_idx / iters, loss[0], acc, time.time()-prev))
+            jt.fetch(batch_idx, loss, output, target, callback)
+        jt.sync_all(True)
+        jt.display_max_memory_info()
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/executor.cc b/src/executor.cc
index 978a444e..7b783624 100644
--- a/src/executor.cc
+++ b/src/executor.cc
@@ -21,10 +21,12 @@
 #include "fuser.h"
 #include "profiler/profiler_guard.h"
 #include "parallel_compiler.h"
+#include "memory_profiler.h"
 
 namespace jittor {
 
 Executor exe;
+extern MemoryProfiler memory_profiler;
 
 // from fetch_op.cc
 extern list<VarPtr> fetcher_to_free;
@@ -415,6 +417,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
         for (auto* var : op->outputs()) {
             var->alloc(allocator);
         }
+        memory_profiler.check();
         LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
         op->do_prepare(jkl);
         bool is_cuda = op->flags.get(NodeFlags::_cuda);
diff --git a/src/memory_profiler.cc b/src/memory_profiler.cc
new file mode 100644
index 00000000..1324841c
--- /dev/null
+++ b/src/memory_profiler.cc
@@ -0,0 +1,117 @@
+#include "memory_profiler.h"
+#include "graph.h"
+#include "var_holder.h"
+#include "var.h"
+#include "mem/allocator/sfrl_allocator.h"
+#include <iostream>
+#include <sstream>
+#include <iomanip>
+#include <algorithm>
+
+namespace jittor {
+
+//TODO: reuse FloatOutput from mem_info.cc
+struct FloatOutput_ {
+    double value;
+    string scale;
+    int base;
+    string suffix;
+    int p=4;
+};
+std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) {
+    int w = 8;
+    os << std::setw(w-2-o.suffix.size());
+    os << std::setprecision(o.p);
+    uint i=0;
+    double k = o.value;
+    for (; i+1<o.scale.size(); i++) {
+        if (k<o.base) break;
+        k /= o.base;
+    }
+    os << k << o.scale[i] << o.suffix;
+    return os;
+}
+
+MemoryProfiler memory_profiler;
+DEFINE_FLAG(int, profile_memory_enable, 0, "Enable memory profiler.");
+
+MemoryProfiler::MemoryProfiler() {
+    clear();
+}
+
+void MemoryProfiler::clear() {
+    allocations.clear();
+    max_memory_size = 0;
+    max_used_memory_size = 0;
+}
+
+std::pair<size_t, size_t> MemoryProfiler::get_memory_info() {
+    size_t used = 0;
+    size_t unused = 0;
+    //TODO: add mssfrl allocator
+    for (auto& a : SFRLAllocator::sfrl_allocators) {
+        used += a->used_memory;
+        unused += a->unused_memory;
+    }
+    return std::make_pair(used, unused);
+}
+
+void MemoryProfiler::check() {
+    std::pair<size_t, size_t> mem_info = get_memory_info();
+    if (mem_info.first > max_used_memory_size) {
+        max_used_memory_size = mem_info.first;
+
+        allocations.clear();
+        size_t memory_size = 0;
+        vector<std::pair<string, size_t>> live_vars;
+        vector<Node*> queue;
+
+        auto t = ++Node::tflag_count;
+        for (auto& vh : VarHolder::hold_vars)
+            if (vh->var->tflag != t) {
+                vh->var->tflag = t;
+                queue.push_back(vh->var);
+            }
+        bfs_both(queue, [](Node*){return true;});
+        for (Node* node : queue) {
+            if (node->is_var()) {
+                Var* var = (Var*)node;
+                if (var->mem_ptr != nullptr) {
+                    std::stringstream stream;
+                    stream << var;
+                    live_vars.push_back(std::make_pair(stream.str(), var->size));
+                    if (!allocations.count(var->mem_ptr)) {
+                        allocations[var->mem_ptr] = 1;
+                        memory_size += var->size;
+                    }
+                }
+            }
+        }
+        max_live_vars = live_vars;
+        max_memory_size = memory_size;
+    }
+}
+
+bool MemoryProfiler::cmp(const std::pair<string, size_t>& a, const std::pair<string, size_t>& b) {
+    return a.second > b.second;
+}
+
+void MemoryProfiler::display_max_memory_info() {
+    Log log("", 'i', 0);
+    std::sort(max_live_vars.begin(), max_live_vars.end(), cmp);
+    log << "\n=====display_max_memory_info=====\n";
+    log << "max used memory" << FloatOutput_{(double)max_used_memory_size, " KMG", 1024, "B"} << "\n";
+    log << "max var memory" << FloatOutput_{(double)max_memory_size, " KMG", 1024, "B"} << "\n\n";
+    log << "[Size]" << "[Percent]" << "[Var Info]" << "\n";
+    for (int i = 0; i < max_live_vars.size(); ++i) {
+        log << FloatOutput_{(double)max_live_vars[i].second, " KMG", 1024, "B"} << double(max_live_vars[i].second) / max_memory_size * 100 << "%" << max_live_vars[i].first << "\n\n";
+    }
+    log << "=========================\n";
+    log.end();
+}
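+// How check() above stays cheap: it runs after every allocation (see the
+// executor.cc hunk in this patch), but the full walk over the var graph is
+// only paid when used memory reaches a new peak. `allocations` is keyed on
+// mem_ptr, so vars that share one buffer (e.g. via share_with) are counted
+// once in max_memory_size, while each of them still appears in max_live_vars.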
+
+void display_max_memory_info() {
+    memory_profiler.display_max_memory_info();
+}
+
+} // jittor
\ No newline at end of file
diff --git a/src/memory_profiler.h b/src/memory_profiler.h
new file mode 100644
index 00000000..b6cd2d88
--- /dev/null
+++ b/src/memory_profiler.h
@@ -0,0 +1,39 @@
+// ***************************************************************
+// Copyright (c) 2020 Jittor. All Rights Reserved.
+// Authors: Dun Liang .
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "common.h"
+#include "mem/allocator.h"
+#include <map>
+#include <vector>
+#include <string>
+#include "var.h"
+namespace jittor {
+
+// @pyjt(display_max_memory_info)
+void display_max_memory_info();
+
+struct MemoryProfiler {
+    std::map<void*, size_t> allocations;
+    // Max Infos
+    std::vector<std::pair<string, size_t>> max_live_vars;
+    size_t max_used_memory_size;
+    size_t max_memory_size;
+
+
+    MemoryProfiler();
+    static bool cmp(const std::pair<string, size_t>& a, const std::pair<string, size_t>& b);
+    void clear();
+    void check();
+    std::pair<size_t, size_t> get_memory_info();
+    void display_max_memory_info();
+};
+
+extern MemoryProfiler memory_profiler;
+
+DECLARE_FLAG(int, profile_memory_enable);
+
+} // jittor
\ No newline at end of file

From d80e4056f655a55c662d422500fc899a61fb031e Mon Sep 17 00:00:00 2001
From: cxjyxx_me <498731903@qq.com>
Date: Tue, 5 Jan 2021 17:35:01 +0800
Subject: [PATCH 3/8] memory_profiler

---
 python/jittor/misc.py                      | 87 ++++++++++++++++++++++
 python/jittor/test/test_memory_profiler.py |  3 +-
 src/memory_profiler.cc                     | 39 +++++++++-
 src/memory_profiler.h                      |  8 +-
 src/pybind/py_var_tracer.cc                | 17 ++++-
 src/pybind/py_var_tracer.h                 |  2 +-
 6 files changed, 146 insertions(+), 10 deletions(-)

diff --git a/python/jittor/misc.py b/python/jittor/misc.py
index 6e9d8292..33322083 100644
--- a/python/jittor/misc.py
+++ b/python/jittor/misc.py
@@ -721,3 +721,87 @@ def triu_(x,diagonal=0):
     return x.reindex(x.shape,indexs,overflow_conditions=overflow_conditions,overflow_value=0)
 
 jt.Var.triu_ = triu_
+
+def print_tree(now, max_memory_size, prefix1='', prefix2=''):
+    def format_size(s):
+        if (s < 1024):
+            s = str(s)
+            return s + ' B'
+
+        if (s < 1024*1024):
+            s = format(s/1024, '.2f')
+            return s + ' KB'
+
+        if (s < 1024*1024*1024):
+            s = format(s/1024/1024, '.2f')
+            return s + ' MB'
+
+        s = format(s/1024/1024/1024, '.2f')
+        return s + ' GB'
+
+    tab = ' '
+    print(prefix1+now['name']+'('+now['type']+')')
+    print(prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%]')
+    for p in now['path']:
+        print(prefix2+p)
+    if (len(now['children']) > 0):
+        print(prefix2 + tab + '| ')
+    else:
+        print(prefix2)
+    for i in range(len(now['children'])):
+        c = now['children'][i]
+        if i < len(now['children']) - 1:
+            prefix1_ = prefix2 + tab + '├─'
+            prefix2_ = prefix2 + tab + '| '
+        else:
+            prefix1_ = prefix2 + tab + '└─'
+            prefix2_ = prefix2 + tab + ' '
+        print_tree(c, max_memory_size, prefix1_, prefix2_)
+
+def get_max_memory_treemap():
+    div1 = "[!@#div1!@#]"
+    div2 = "[!@#div2!@#]"
+    div3 = "[!@#div3!@#]"
+    info = jt.get_max_memory_info()
+
+    vars = []
+    vars_ = info.split(div1)
+    max_memory_size = int(vars_[0])
+    vars_ = vars_[1:]
+    for v_ in vars_:
+        v__ = v_.split(div2)
+        var = {'size':int(v__[1]), 'stack':[]}
+        v__ = v__[2:-1]
+        for s_ in v__:
+            s__ = s_.split(div3)
+            s = {'path':s__[0], 'name':s__[1], 'type':s__[2]}
+            var['stack'].append(s)
+        vars.append(var)
+    tree = {'name':'root', "children":[], 'size':0, 'path':[], 'type':''}
+
+    def find_child(now, key):
+        for c in now['children']:
+            if (c['name'] == key):
+                return c
+        return None
+    for v in vars:
+        now = tree
+        now['size'] += v['size']
+        for s in v['stack']:
+            ch = find_child(now, s['name'])
+            if (ch is not None):
+                if (not s['path'] in ch['path']):
+                    ch['path'].append(s['path'])
+                assert(ch['type']==s['type'])
+                now = ch
+                now['size'] += v['size']
+            else:
+                now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':[s['path']], 'type':s['type']}
+                now['children'].append(now_)
+                now = now_
+    def sort_tree(now):
+        def takeSize(elem):
+            return elem['size']
+        now['children'].sort(key=takeSize, reverse=True)
+        for c in now['children']:
+            sort_tree(c)
+    sort_tree(tree)
+    print_tree(tree, max_memory_size, '', '')
+    return tree
\ No newline at end of file
diff --git a/python/jittor/test/test_memory_profiler.py b/python/jittor/test/test_memory_profiler.py
index f19d01bb..55613a09 100644
--- a/python/jittor/test/test_memory_profiler.py
+++ b/python/jittor/test/test_memory_profiler.py
@@ -63,7 +63,7 @@ class TestMemoryProfiler(unittest.TestCase):
         prev = time.time()
         SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
 
-        iters = 50
+        iters = 10
         for batch_idx, (data, target) in enumerate(self.train_loader):
             if (batch_idx > iters):
                 break
@@ -82,6 +82,7 @@ class TestMemoryProfiler(unittest.TestCase):
             jt.fetch(batch_idx, loss, output, target, callback)
         jt.sync_all(True)
         jt.display_max_memory_info()
+        jt.get_max_memory_treemap()
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/src/memory_profiler.cc b/src/memory_profiler.cc
index 1324841c..27dbcd8a 100644
--- a/src/memory_profiler.cc
+++ b/src/memory_profiler.cc
@@ -7,6 +7,7 @@
 #include <sstream>
 #include <iomanip>
 #include <algorithm>
+#include "pybind/py_var_tracer.h"
 
 namespace jittor {
 
@@ -63,7 +64,7 @@ void MemoryProfiler::check() {
 
         allocations.clear();
         size_t memory_size = 0;
-        vector<std::pair<string, size_t>> live_vars;
+        std::vector<std::pair<std::pair<string, vector<Stack>>, size_t>> live_vars;
         vector<Node*> queue;
 
         auto t = ++Node::tflag_count;
@@ -77,9 +78,10 @@ void MemoryProfiler::check() {
             if (node->is_var()) {
                 Var* var = (Var*)node;
                 if (var->mem_ptr != nullptr) {
+                    vector<Stack> stacks = get_node_trace(var);
                     std::stringstream stream;
                     stream << var;
-                    live_vars.push_back(std::make_pair(stream.str(), var->size));
+                    live_vars.push_back(std::make_pair(std::make_pair(stream.str(), stacks), var->size));
                     if (!allocations.count(var->mem_ptr)) {
                         allocations[var->mem_ptr] = 1;
                         memory_size += var->size;
@@ -92,7 +94,7 @@ void MemoryProfiler::check() {
     }
 }
 
-bool MemoryProfiler::cmp(const std::pair<string, size_t>& a, const std::pair<string, size_t>& b) {
+bool MemoryProfiler::cmp(const std::pair<std::pair<string, vector<Stack>>, size_t>& a, const std::pair<std::pair<string, vector<Stack>>, size_t>& b) {
     return a.second > b.second;
 }
 
@@ -104,7 +106,11 @@ void MemoryProfiler::display_max_memory_info() {
     log << "max var memory" << FloatOutput_{(double)max_memory_size, " KMG", 1024, "B"} << "\n\n";
     log << "[Size]" << "[Percent]" << "[Var Info]" << "\n";
     for (int i = 0; i < max_live_vars.size(); ++i) {
-        log << FloatOutput_{(double)max_live_vars[i].second, " KMG", 1024, "B"} << double(max_live_vars[i].second) / max_memory_size * 100 << "%" << max_live_vars[i].first << "\n\n";
+        log << FloatOutput_{(double)max_live_vars[i].second, " KMG", 1024, "B"}
+            << double(max_live_vars[i].second) / max_memory_size * 100 << "%"
+            << max_live_vars[i].first.first
+            << max_live_vars[i].first.second[0].file_path + ":" + std::to_string(max_live_vars[i].first.second[0].lineno)
+            << "\n\n";
     }
     log << "=========================\n";
     log.end();
@@ -114,4 +120,29 @@ void display_max_memory_info() {
     memory_profiler.display_max_memory_info();
 }
 
+string MemoryProfiler::get_max_memory_info() {
+    std::stringstream out;
+    string div1 = "[!@#div1!@#]";
+    string div2 = "[!@#div2!@#]";
+    string div3 = "[!@#div3!@#]";
+
+    std::sort(max_live_vars.begin(), max_live_vars.end(), cmp);
+    out << max_memory_size;
+    for (int i = 0; i < max_live_vars.size(); ++i) {
+        out << div1;
+        out << max_live_vars[i].first.first << div2;
+        out << max_live_vars[i].second << div2;
+        for (int j = 0; j < max_live_vars[i].first.second.size(); ++j) {
+            out << max_live_vars[i].first.second[j].file_path + ":" + std::to_string(max_live_vars[i].first.second[j].lineno) << div3
+                << max_live_vars[i].first.second[j].module_name << div3
+                << max_live_vars[i].first.second[j].module_type << div2;
+        }
+    }
+    return out.str();
+}
+
+string get_max_memory_info() {
+    return memory_profiler.get_max_memory_info();
+}
+
 } // jittor
\ No newline at end of file
diff --git a/src/memory_profiler.h b/src/memory_profiler.h
index b6cd2d88..7a0ce12b 100644
--- a/src/memory_profiler.h
+++ b/src/memory_profiler.h
@@ -11,25 +11,29 @@
 #include <vector>
 #include <string>
 #include "var.h"
+#include "pybind/py_var_tracer.h"
 namespace jittor {
 
 // @pyjt(display_max_memory_info)
 void display_max_memory_info();
+// @pyjt(get_max_memory_info)
+string get_max_memory_info();
 
 struct MemoryProfiler {
     std::map<void*, size_t> allocations;
     // Max Infos
-    std::vector<std::pair<string, size_t>> max_live_vars;
+    vector<std::pair<std::pair<string, vector<Stack>>, size_t>> max_live_vars;
     size_t max_used_memory_size;
     size_t max_memory_size;
 
 
     MemoryProfiler();
-    static bool cmp(const std::pair<string, size_t>& a, const std::pair<string, size_t>& b);
+    static bool cmp(const std::pair<std::pair<string, vector<Stack>>, size_t>& a, const std::pair<std::pair<string, vector<Stack>>, size_t>& b);
     void clear();
     void check();
     std::pair<size_t, size_t> get_memory_info();
     void display_max_memory_info();
+    string get_max_memory_info();
 };
 
 extern MemoryProfiler memory_profiler;
diff --git a/src/pybind/py_var_tracer.cc b/src/pybind/py_var_tracer.cc
index dafb5606..519b1619 100644
--- a/src/pybind/py_var_tracer.cc
+++ b/src/pybind/py_var_tracer.cc
@@ -97,7 +97,8 @@ static vector<Stack> get_stack_info() {
     PyObject* prev_obj = nullptr;
     if (trace_py_var >= 3) {
         // trace raw stack
-        auto start = std::max(0, n-5);
+        // auto start = std::max(0, n-5);
+        auto start = 0;
         for (int i=start; i<n; i++) {
             auto f = frames[i];
             auto filename = to_string(f->f_code->co_filename);
@@ -185,7 +186,7 @@ void TraceData::record_node(Node* node, bool record_stack) {
     NodeData data;
     data.id = node_data_cnt++;
     id_map[node] = data.id;
-    if (!node->is_var() || trace_py_var>=3) {
+    if (trace_py_var) {
         if (record_stack) {
             if (trace_grad_op) {
                 auto iter = trace_data.id_map.find(trace_grad_op);
@@ -363,4 +364,16 @@ void print_node_trace(const Node* node, std::ostream& os) {
     os << _get_stack_info((Node*)node);
 }
 
+vector<Stack> get_node_trace(Node* node) {
+    auto iter = trace_data.id_map.find(node);
+    if (iter == trace_data.id_map.end())
+        return vector<Stack>();
+    auto node_id = iter->second;
+    auto iter2 = trace_data.node_data.find(node_id);
+    if (iter2 == trace_data.node_data.end())
+        return vector<Stack>();
+    return iter2->second.stacks;
+}
+
+
 } // jittor
diff --git a/src/pybind/py_var_tracer.h b/src/pybind/py_var_tracer.h
index b1e0ddc9..ffa666e1 100644
--- a/src/pybind/py_var_tracer.h
+++ b/src/pybind/py_var_tracer.h
@@ -67,5 +67,5 @@ struct TraceData {
 extern TraceData trace_data;
 
 void print_node_trace(const Node* node, std::ostream& os);
-
+vector<Stack> get_node_trace(Node* node);
 } // jittor

From a3a09a48376b6d332596256c2ec16bf942fe1c91 Mon Sep 17 00:00:00 2001
From: guoye yang <498731903@qq.com>
Date: Thu, 7 Jan 2021 10:23:09 +0800
Subject: [PATCH 4/8] fix

---
 python/jittor/dataset/dataset.py | 3 ++-
 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 4d0160fd..774779af 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -128,7 +128,8 @@ class Dataset(object): if self.stop_grad else jt.array(x) if isinstance(batch, np.ndarray): return to_jt(batch) - assert isinstance(batch, Sequence) + if not isinstance(batch, (list, tuple)): + return batch new_batch = [] for a in batch: if isinstance(a, np.ndarray) or \ From a96ccab4bb37fde5696e9b7aba769f6a83d38603 Mon Sep 17 00:00:00 2001 From: cxjyxx_me <498731903@qq.com> Date: Tue, 12 Jan 2021 16:41:15 +0800 Subject: [PATCH 5/8] temp_allocator --- extern/cuda/cub/ops/cub_arg_reduce_op.cc | 8 +- extern/cuda/cub/ops/cub_argsort_op.cc | 4 +- extern/cuda/cub/ops/cub_where_op.cc | 8 +- .../cudnn/ops/cudnn_conv_backward_w_op.cc | 8 +- .../cudnn/ops/cudnn_conv_backward_x_op.cc | 8 +- extern/cuda/cudnn/ops/cudnn_conv_op.cc | 8 +- src/executor.cc | 2 + src/executor.h | 1 + src/mem/allocator.cc | 8 +- src/mem/allocator.h | 2 +- src/mem/allocator/temp_allocator.cc | 116 ++++++++++++++++++ src/mem/allocator/temp_allocator.h | 57 +++++++++ src/mem/mem_info.cc | 10 +- src/ops/candidate_op.cc | 8 +- src/ops/where_op.cc | 4 +- src/opt/gopt/setitem_gopt.cc | 2 +- 16 files changed, 221 insertions(+), 33 deletions(-) create mode 100644 src/mem/allocator/temp_allocator.cc create mode 100644 src/mem/allocator/temp_allocator.h diff --git a/extern/cuda/cub/ops/cub_arg_reduce_op.cc b/extern/cuda/cub/ops/cub_arg_reduce_op.cc index 8d5e2b6e..0216c525 100644 --- a/extern/cuda/cub/ops/cub_arg_reduce_op.cc +++ b/extern/cuda/cub/ops/cub_arg_reduce_op.cc @@ -87,7 +87,7 @@ void CubArgReduceOp::jit_run() { num_segments *= x->shape[i]; } size_t allocation_dout; - cub::KeyValuePair *d_out = (cub::KeyValuePair *)exe.allocator->alloc(sizeof(cub::KeyValuePair) * num_segments, allocation_dout); + cub::KeyValuePair *d_out = (cub::KeyValuePair *)exe.temp_allocator->alloc(sizeof(cub::KeyValuePair) * num_segments, allocation_dout); // Determine temporary device storage requirementse = NULL; void *d_temp_storage = NULL; @@ -96,7 +96,7 @@ void CubArgReduceOp::jit_run() { xp, d_out, num_segments, offsetsp, offsetsp + 1); // Allocate temporary storage size_t allocation; - d_temp_storage = exe.allocator->alloc(temp_storage_bytes, allocation); + d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, allocation); // Run sorting operation cub::DeviceSegmentedReduce::@FUNC@@(d_temp_storage, temp_storage_bytes, xp, d_out, num_segments, offsetsp, offsetsp + 1); @@ -105,8 +105,8 @@ void CubArgReduceOp::jit_run() { auto* __restrict__ y_keyp = y_key->ptr(); split<<>>(d_out, y_keyp, yp, num_segments); - exe.allocator->free(d_temp_storage, temp_storage_bytes, allocation); - exe.allocator->free(d_out, sizeof(cub::KeyValuePair) * num_segments, allocation_dout); + exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, allocation); + exe.temp_allocator->free(d_out, sizeof(cub::KeyValuePair) * num_segments, allocation_dout); } #endif // JIT_cuda #endif // JIT diff --git a/extern/cuda/cub/ops/cub_argsort_op.cc b/extern/cuda/cub/ops/cub_argsort_op.cc index 4e47d276..1ca57c15 100644 --- a/extern/cuda/cub/ops/cub_argsort_op.cc +++ b/extern/cuda/cub/ops/cub_argsort_op.cc @@ -85,12 +85,12 @@ void CubArgsortOp::jit_run() { num_items, num_segments, offsetsp, offsetsp + 1); // Allocate temporary storage size_t allocation; - d_temp_storage = exe.allocator->alloc(temp_storage_bytes, allocation); + 
d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, allocation); // Run sorting operation cub::DeviceSegmentedRadixSort::@FUNC@@(d_temp_storage, temp_storage_bytes, xp, y_keyp, indexesp, yp, num_items, num_segments, offsetsp, offsetsp + 1); - exe.allocator->free(d_temp_storage, temp_storage_bytes, allocation); + exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, allocation); } #endif // JIT_cuda #endif // JIT diff --git a/extern/cuda/cub/ops/cub_where_op.cc b/extern/cuda/cub/ops/cub_where_op.cc index bca1f8b3..65e459ac 100644 --- a/extern/cuda/cub/ops/cub_where_op.cc +++ b/extern/cuda/cub/ops/cub_where_op.cc @@ -82,7 +82,7 @@ void CubWhereOp::jit_run(){ int N = cond->num; size_t temp_storage_bytes=0; size_t num_nonzeros_allocation; - auto num_nonzeros = exe.allocator->alloc(sizeof(To), num_nonzeros_allocation); + auto num_nonzeros = exe.temp_allocator->alloc(sizeof(To), num_nonzeros_allocation); size_t temp_storage_allocation; void* temp_storage; @@ -93,9 +93,9 @@ void CubWhereOp::jit_run(){ cub::TransformInputIterator, Ti*> itr(cond->ptr(), NonZeroOp()); temp_storage_bytes = 0; checkCudaErrors(cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, out_temp, (To*)num_nonzeros, N)); - temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation); + temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, temp_storage_allocation); checkCudaErrors(cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr,out_temp, (To*)num_nonzeros, N)); - exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation); + exe.temp_allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation); To num_nonzeros_h; cudaMemcpy(&num_nonzeros_h, num_nonzeros, sizeof(To), cudaMemcpyDeviceToHost); @@ -110,7 +110,7 @@ void CubWhereOp::jit_run(){ @for(i, 0, NDIM, 1, , cond->shape[@i], outs[@i]->ptr()) ); } - exe.allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation); + exe.temp_allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation); } #endif diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc index ed1c6f8c..ce1c18a5 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc +++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc @@ -203,7 +203,7 @@ void CudnnConvBackwardWOp::jit_run() { if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; } size_t allocation; - void* ws = exe.allocator->alloc(max_ws_size, allocation); + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx( handle_, cudnnIdesc, x->ptr(), @@ -215,7 +215,7 @@ void CudnnConvBackwardWOp::jit_run() { perf_results, ws, max_ws_size)); - exe.allocator->free(ws, max_ws_size, allocation); + exe.temp_allocator->free(ws, max_ws_size, allocation); } else { checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7( handle_, @@ -250,7 +250,7 @@ void CudnnConvBackwardWOp::jit_run() { cudnnFdesc, algo, &workSpaceSize)); size_t allocation; if (workSpaceSize > 0) { - workSpace = exe.allocator->alloc(workSpaceSize, allocation); + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); } float alpha=1, beta=0; checkCudaErrors(cudnnConvolutionBackwardFilter( @@ -265,7 +265,7 @@ void CudnnConvBackwardWOp::jit_run() { cudnnFdesc, w->ptr()) ); if (workSpace) - exe.allocator->free(workSpace, workSpaceSize, allocation); + exe.temp_allocator->free(workSpace, workSpaceSize, 
allocation); checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc index 5ecb503a..bbf72a1f 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc +++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc @@ -204,7 +204,7 @@ void CudnnConvBackwardXOp::jit_run() { if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; } size_t allocation; - void* ws = exe.allocator->alloc(max_ws_size, allocation); + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx( handle_, cudnnFdesc, w->ptr(), @@ -216,7 +216,7 @@ void CudnnConvBackwardXOp::jit_run() { perf_results, ws, max_ws_size)); - exe.allocator->free(ws, max_ws_size, allocation); + exe.temp_allocator->free(ws, max_ws_size, allocation); } else { checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7( handle_, @@ -251,7 +251,7 @@ void CudnnConvBackwardXOp::jit_run() { cudnnIdesc, algo, &workSpaceSize)); size_t allocation; if (workSpaceSize > 0) { - workSpace = exe.allocator->alloc(workSpaceSize, allocation); + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); } float alpha=1, beta=0; checkCudaErrors(cudnnConvolutionBackwardData( @@ -266,7 +266,7 @@ void CudnnConvBackwardXOp::jit_run() { cudnnIdesc, x->ptr()) ); if (workSpace) - exe.allocator->free(workSpace, workSpaceSize, allocation); + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); diff --git a/extern/cuda/cudnn/ops/cudnn_conv_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_op.cc index 5798e789..548b5ae3 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_op.cc +++ b/extern/cuda/cudnn/ops/cudnn_conv_op.cc @@ -208,7 +208,7 @@ void CudnnConvOp::jit_run() { if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; } size_t allocation; - void* ws = exe.allocator->alloc(max_ws_size, allocation); + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx( handle_, cudnnIdesc, x->ptr(), @@ -220,7 +220,7 @@ void CudnnConvOp::jit_run() { perf_results, ws, max_ws_size)); - exe.allocator->free(ws, max_ws_size, allocation); + exe.temp_allocator->free(ws, max_ws_size, allocation); } else { checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7( handle_, @@ -255,7 +255,7 @@ void CudnnConvOp::jit_run() { cudnnOdesc, algo, &workSpaceSize) ); size_t allocation; if (workSpaceSize > 0) { - workSpace = exe.allocator->alloc(workSpaceSize, allocation); + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); } float alpha=1, beta=0; checkCudaErrors(cudnnConvolutionForward( @@ -270,7 +270,7 @@ void CudnnConvOp::jit_run() { cudnnOdesc, y->ptr()) ); if (workSpace) - exe.allocator->free(workSpace, workSpaceSize, allocation); + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); diff --git a/src/executor.cc b/src/executor.cc index 4401c72d..72d4ce18 100644 --- a/src/executor.cc +++ b/src/executor.cc @@ -92,7 +92,9 @@ void load_fused_op(FusedOp& fused_op, vector& fuse_ops, vector& ops, i void Executor::run_sync(vector vars, bool device_sync) { auto allocator = get_allocator(); + auto 
temp_allocator = get_allocator(true); this->allocator = allocator; + this->temp_allocator = temp_allocator; // bfs find all ops need to run int op_num = 0; vector bfs_q; diff --git a/src/executor.h b/src/executor.h index 2126a880..9baef47c 100644 --- a/src/executor.h +++ b/src/executor.h @@ -16,6 +16,7 @@ namespace jittor { struct Executor { Allocator* allocator; + Allocator* temp_allocator; bool last_is_cuda = false; void run_sync(vector vars, bool device_sync); }; diff --git a/src/mem/allocator.cc b/src/mem/allocator.cc index d3ec59f2..7145d511 100644 --- a/src/mem/allocator.cc +++ b/src/mem/allocator.cc @@ -15,6 +15,7 @@ #include "mem/allocator/stat_allocator.h" #include "mem/allocator/sfrl_allocator.h" #include "mem/allocator/nfef_allocator.h" +#include "mem/allocator/temp_allocator.h" namespace jittor { @@ -46,7 +47,7 @@ Allocator* setup_allocator(Allocator* underlying) { Allocator* cpu_allocator = setup_allocator(&aligned_allocator); -Allocator* get_allocator() { +Allocator* get_allocator(bool temp_allocator) { Allocator* allocator = nullptr; #ifdef HAS_CUDA if (use_cuda && !allocator) { @@ -72,7 +73,10 @@ Allocator* get_allocator() { allocator = setup_allocator(allocator); return allocator; } - if (use_sfrl_allocator) { + if (temp_allocator && use_temp_allocator) { + LOGvv << "Using temp_allocator"; + allocator = setup_allocator(allocator); + } else if (use_sfrl_allocator) { LOGvv << "Using sfrl_allocator"; allocator = setup_allocator(allocator); } diff --git a/src/mem/allocator.h b/src/mem/allocator.h index 34553800..8f8c637e 100644 --- a/src/mem/allocator.h +++ b/src/mem/allocator.h @@ -49,7 +49,7 @@ struct Allocation { }; extern Allocator* cpu_allocator; -Allocator* get_allocator(); +Allocator* get_allocator(bool temp_allocator=false); // @pyjt(gc) void gc_all(); diff --git a/src/mem/allocator/temp_allocator.cc b/src/mem/allocator/temp_allocator.cc new file mode 100644 index 00000000..88c1398c --- /dev/null +++ b/src/mem/allocator/temp_allocator.cc @@ -0,0 +1,116 @@ +// *************************************************************** +// Copyright (c) 2020 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** + +#include "mem/allocator/temp_allocator.h" + +namespace jittor { + +DEFINE_FLAG(int, use_temp_allocator, 1, "Enable temp allocator"); + +TempAllocator::~TempAllocator() { + while (!cached_blocks.empty()) { + auto it = cached_blocks.begin(); + TempCachingBlock* block = it->second; + cached_blocks.erase(it); + delete block; + } +} + +const char* TempAllocator::name() const {return "temp";} + +void TempAllocator::setup(Allocator* underlying) { + this->underlying = underlying; +} + +size_t TempAllocator::align_size(size_t size) { + return (size + ALIGN_SIZE - 1) / ALIGN_SIZE * ALIGN_SIZE; +} + +unsigned long long TempAllocator::get_key(TempCachingBlock* block) { + return ((unsigned long long)block->size) * ID_LIMIT + block->id; +} + +void* TempAllocator::alloc(size_t size, size_t& allocation) { + size = align_size(size); + + auto temp = TempCachingBlock(size); + auto it = cached_blocks.lower_bound(get_key(&temp)); + TempCachingBlock* block = nullptr; + if (it != cached_blocks.end()) { + block = it->second; + cached_blocks.erase(it); + unused_memory -= block->size; + } else { + void* ptr = underlying->alloc(size, allocation); + block = new TempCachingBlock(size, ptr); + size_t id; + if (!block_ids.empty()) { + id = block_ids.back(); + block_ids.pop_back(); + } else { + ASSERT(tot_block_id < ID_LIMIT - 1) << "block id limit extended."; + id = ++tot_block_id; + } + block->id = id; + } + + used_memory += block->size; + occupied_id_mapper[block->id] = block; + allocation = block->id; + return block->memory_ptr; +} + +void TempAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + size = align_size(size); + ASSERT(occupied_id_mapper[allocation] != nullptr) << "allocation not found"; + TempCachingBlock* block = occupied_id_mapper[allocation]; + occupied_id_mapper[allocation] = nullptr; + used_memory -= block->size; + unused_memory += block->size; + bool can_add = true; + if (cached_blocks.size() > cache_blocks_limit-1) { + ASSERT(cached_blocks.size() == cache_blocks_limit); + auto it = cached_blocks.lower_bound(get_key(block)); + if (it == cached_blocks.begin()) { + can_add = false; + } else { + --it; + TempCachingBlock* block = it->second; + underlying->free((void*)block->memory_ptr, block->size, 0); + unused_memory -= block->size; + block_ids.push_back(block->id); + cached_blocks.erase(it); + delete block; + } + } + if (can_add) { + cached_blocks[get_key(block)] = block; + } +} + +void TempAllocator::gc() { + while (!cached_blocks.empty()) { + auto it = cached_blocks.begin(); + TempCachingBlock* block = it->second; + underlying->free((void*)block->memory_ptr, block->size, 0); + unused_memory -= block->size; + block_ids.push_back(block->id); + cached_blocks.erase(it); + delete block; + } +} + +bool TempAllocator::share_with(size_t size, size_t allocation) { + ASSERT(false); + return true; +} + +} // jittor + diff --git a/src/mem/allocator/temp_allocator.h b/src/mem/allocator/temp_allocator.h new file mode 100644 index 00000000..0402e421 --- /dev/null +++ b/src/mem/allocator/temp_allocator.h @@ -0,0 +1,57 @@ +// *************************************************************** +// Copyright (c) 2020 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// ***************************************************************
+#pragma once
+#include "mem/allocator.h"
+
+namespace jittor {
+
+struct TempCachingBlock {
+    size_t size;
+    size_t id;
+    void* memory_ptr;
+
+    TempCachingBlock(size_t size):size(size),id(0) {}
+    TempCachingBlock(size_t size, void* memory_ptr):size(size),id(0), memory_ptr(memory_ptr) {}
+};
+
+struct TempAllocator : Allocator {
+    static const size_t ALIGN_SIZE = 512;
+    static const size_t ID_LIMIT = 1 << 18;
+    Allocator* underlying;
+    size_t cache_blocks_limit, used_memory, unused_memory;
+    std::map<unsigned long long, TempCachingBlock*> cached_blocks;
+    std::vector<size_t> block_ids;
+    size_t tot_block_id;
+    std::unique_ptr<TempCachingBlock*[]> occupied_id_mapper;
+
+
+    inline TempAllocator(size_t cache_blocks_limit=2) : cache_blocks_limit(cache_blocks_limit), used_memory(0), unused_memory(0), tot_block_id(0), occupied_id_mapper(new TempCachingBlock*[ID_LIMIT]) {
+    }
+    inline TempAllocator(Allocator* underlying, size_t cache_blocks_limit=2) : TempAllocator(cache_blocks_limit) {
+        setup(underlying);
+    }
+    ~TempAllocator();
+
+    size_t align_size(size_t size);
+    unsigned long long get_key(TempCachingBlock* block);
+    void setup(Allocator* underlying);
+    uint64 flags() const override { return underlying->flags(); }
+    const char* name() const override;
+    void* alloc(size_t size, size_t& allocation) override;
+    void free(void* mem_ptr, size_t size, const size_t& allocation) override;
+    // return all cached blocks to the underlying allocator.
+    void gc() override;
+    virtual bool share_with(size_t size, size_t allocation) override;
+};
+
+DECLARE_FLAG(int, use_temp_allocator);
+
+}//jittor
+
diff --git a/src/mem/mem_info.cc b/src/mem/mem_info.cc
index 375f5f7d..f935685c 100644
--- a/src/mem/mem_info.cc
+++ b/src/mem/mem_info.cc
@@ -15,8 +15,10 @@
 #include "misc/cuda_flags.h"
 #include "mem/allocator/sfrl_allocator.h"
 #include "mem/allocator/stat_allocator.h"
+#include "mem/allocator/temp_allocator.h"
 #include "mem/mem_info.h"
 #include "update_queue.h"
+#include "executor.h"
 
 namespace jittor {
 
@@ -101,7 +103,13 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
     log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"}
         << "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"}
         << "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n';
-
+    if (use_temp_allocator) {
+        TempAllocator* temp_allocator = (TempAllocator*)exe.temp_allocator;
+        log << "\nname:" << temp_allocator->name() << "\n";
+        log << "used_memory:" << FloatOutput{(double)temp_allocator->used_memory, " KMG", 1024, "B"} << "\n";
+        log << "unused_memory:" << FloatOutput{(double)temp_allocator->unused_memory, " KMG", 1024, "B"} << "\n";
+
+    }
     if (dump_var) {
         vector<Node*> queue;
         unordered_set<Node*> visited;
diff --git a/src/ops/candidate_op.cc b/src/ops/candidate_op.cc
index c68fc76b..2cdda0c0 100644
--- a/src/ops/candidate_op.cc
+++ b/src/ops/candidate_op.cc
@@ -76,9 +76,9 @@ void CandidateOp::jit_run() {
     // define ys
     auto* __restrict__ yp = y->ptr();
     size_t n_allocation;
-    int* np = (int*)exe.allocator->alloc(4, n_allocation);
+    int* np = (int*)exe.temp_allocator->alloc(4, n_allocation);
     size_t mask_allocation;
-    bool* maskp = (bool*)exe.allocator->alloc(xshape0, mask_allocation);
+    bool* maskp = (bool*)exe.temp_allocator->alloc(xshape0, mask_allocation);
     checkCudaErrors(cudaMemsetAsync(maskp, 1, xshape0));
 
     candidate_kernel<<<1, std::max(1, std::min(1024, xshape0)) >>>(
@@ -93,8 +93,8 @@ void CandidateOp::jit_run() {
     // checkCudaErrors(cudaDeviceSynchronize());
     checkCudaErrors(cudaMemcpy(&n, np, 4, 
cudaMemcpyDefault)); y->set_shape({n}); - exe.allocator->free(np, 4, n_allocation); - exe.allocator->free(maskp, xshape0, mask_allocation); + exe.temp_allocator->free(np, 4, n_allocation); + exe.temp_allocator->free(maskp, xshape0, mask_allocation); } #else void CandidateOp::jit_run() { diff --git a/src/ops/where_op.cc b/src/ops/where_op.cc index 0cf31899..d3c5a3ca 100644 --- a/src/ops/where_op.cc +++ b/src/ops/where_op.cc @@ -196,7 +196,7 @@ void WhereOp::jit_run() { @for(i, 0, NDIM, auto* __restrict__ outs@i@@p = outs[@i]->ptr();) size_t n_allocation; - int* np = (int*)exe.allocator->alloc(4, n_allocation); + int* np = (int*)exe.temp_allocator->alloc(4, n_allocation); // one block kernel, result maybe unstable // int tnum = condshape@{NDIM-1}; @@ -232,7 +232,7 @@ void WhereOp::jit_run() { // checkCudaErrors(cudaDeviceSynchronize()); checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault)); @for(i, 0, NDIM, outs[@i]->set_shape({n});) - exe.allocator->free(np, 4, n_allocation); + exe.temp_allocator->free(np, 4, n_allocation); } #else diff --git a/src/opt/gopt/setitem_gopt.cc b/src/opt/gopt/setitem_gopt.cc index 44414457..0831609d 100644 --- a/src/opt/gopt/setitem_gopt.cc +++ b/src/opt/gopt/setitem_gopt.cc @@ -48,7 +48,7 @@ static void setitem_inplace(SetitemOp* op) { } auto output = op->outputs().front(); output->share_with(input); - // return; + return; // LOGir << "pass setitem optim one"; From 8cd3dcb2349e5dfb89f025d066f4be64b18de7ff Mon Sep 17 00:00:00 2001 From: cxjyxx_me <498731903@qq.com> Date: Thu, 14 Jan 2021 11:07:51 +0800 Subject: [PATCH 6/8] test --- python/jittor/misc.py | 97 ++++++++++++++-------- python/jittor/test/test_memory_profiler.py | 11 ++- 2 files changed, 74 insertions(+), 34 deletions(-) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 33322083..12cc1756 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -722,7 +722,7 @@ def triu_(x,diagonal=0): jt.Var.triu_ = triu_ -def print_tree(now, max_memory_size, prefix1='', prefix2=''): +def print_tree(now, max_memory_size, prefix1, prefix2, build_by): def format_size(s): if (s < 1024): s = str(s) @@ -739,15 +739,19 @@ def print_tree(now, max_memory_size, prefix1='', prefix2=''): s = format(s/1024/1024/1024, '.2f') return s + ' GB' + out = '' tab = ' ' - print(prefix1+now['name']+'('+now['type']+')') - print(prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%]') - for p in now['path']: - print(prefix2+p) - if (len(now['children']) > 0): - print(prefix2 + tab + '| ') + out += prefix1+now['name']+'('+now['type']+')\n' + out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%]\n' + if (build_by == 0): + for p in now['path']: + out += prefix2+p+'\n' else: - print(prefix2) + out += prefix2+now['path'] + '\n' + if (len(now['children']) > 0): + out += prefix2 + tab + '| ' + '\n' + else: + out += prefix2 + '\n' for i in range(len(now['children'])): c = now['children'][i] if i < len(now['children']) - 1: @@ -756,9 +760,10 @@ def print_tree(now, max_memory_size, prefix1='', prefix2=''): else: prefix1_ = prefix2 + tab + '└─' prefix2_ = prefix2 + tab + ' ' - print_tree(c, max_memory_size, prefix1_, prefix2_) + out += print_tree(c, max_memory_size, prefix1_, prefix2_, build_by) + return out -def get_max_memory_treemap(): +def get_max_memory_treemap(build_by=0, do_print=True): div1 = "[!@#div1!@#]" div2 = "[!@#div2!@#]" div3 = "[!@#div3!@#]" @@ -777,28 +782,52 @@ def get_max_memory_treemap(): s = {'path':s__[0], 
'name':s__[1], 'type':s__[2]} var['stack'].append(s) vars.append(var) - tree = {'name':'root', "children":[], 'size':0, 'path':[], 'type':''} + if (build_by == 0): # build tree by name + tree = {'name':'root', "children":[], 'size':0, 'path':[], 'type':''} - def find_child(now, key): - for c in now['children']: - if (c['name'] == key): - return c - return None - for v in vars: - now = tree - now['size'] += v['size'] - for s in v['stack']: - ch = find_child(now, s['name']) - if (ch is not None): - if (not s['path'] in ch['path']): - ch['path'].append(s['path']) - assert(ch['type']==s['type']) - now = ch - now['size'] += v['size'] - else: - now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':[s['path']], 'type':s['type']} - now['children'].append(now_) - now = now_ + def find_child(now, key): + for c in now['children']: + if (c['name'] == key): + return c + return None + for v in vars: + now = tree + now['size'] += v['size'] + for s in v['stack']: + ch = find_child(now, s['name']) + if (ch is not None): + if (not s['path'] in ch['path']): + ch['path'].append(s['path']) + assert(ch['type']==s['type']) + now = ch + now['size'] += v['size'] + else: + now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':[s['path']], 'type':s['type']} + now['children'].append(now_) + now = now_ + elif (build_by == 1): # build tree by path + tree = {'name':'root', "children":[], 'size':0, 'path':'_root_', 'type':''} + + def find_child(now, key): + for c in now['children']: + if (c['path'] == key): + return c + return None + for v in vars: + now = tree + now['size'] += v['size'] + for s in v['stack']: + ch = find_child(now, s['path']) + if (ch is not None): + now = ch + now['size'] += v['size'] + else: + now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':s['path'], 'type':s['type']} + now['children'].append(now_) + now = now_ + else: + assert(False) + def sort_tree(now): def takeSize(elem): return elem['size'] @@ -806,5 +835,7 @@ def get_max_memory_treemap(): for c in now['children']: sort_tree(c) sort_tree(tree) - print_tree(tree, max_memory_size, '', '') - return tree \ No newline at end of file + out = print_tree(tree, max_memory_size, '', '', build_by) + if (do_print): + print(out) + return tree, out \ No newline at end of file diff --git a/python/jittor/test/test_memory_profiler.py b/python/jittor/test/test_memory_profiler.py index 55613a09..02085a93 100644 --- a/python/jittor/test/test_memory_profiler.py +++ b/python/jittor/test/test_memory_profiler.py @@ -82,7 +82,16 @@ class TestMemoryProfiler(unittest.TestCase): jt.fetch(batch_idx, loss, output, target, callback) jt.sync_all(True) jt.display_max_memory_info() - jt.get_max_memory_treemap() + _, out = jt.get_max_memory_treemap() + out_ = out.split('\n') + assert(out_[0] == 'root()') + assert(out_[3] == ' ├─mnist_net(MnistNet)') + assert(out_[7] == ' | └─model(ResNet)') + _, out = jt.get_max_memory_treemap(build_by=1) + out_ = out.split('\n') + assert(out_[0] == 'root()') + assert(out_[4] == ' ├─mnist_net(MnistNet)') + assert(out_[8] == ' | └─model(ResNet)') if __name__ == "__main__": unittest.main() From 6c83e40883e903063d3313df949a1452a305919e Mon Sep 17 00:00:00 2001 From: cxjyxx_me <498731903@qq.com> Date: Thu, 14 Jan 2021 13:09:20 +0800 Subject: [PATCH 7/8] fix --- python/jittor/misc.py | 2 +- python/jittor/test/test_memory_profiler.py | 18 +++++++++--------- src/executor.cc | 9 +++++++-- src/executor.h | 7 +++++-- src/mem/allocator/temp_allocator.cc | 2 +- src/mem/allocator/temp_allocator.h | 2 +- 
src/memory_profiler.cc | 17 ++++++++++++++++- src/memory_profiler.h | 7 +++++-- src/opt/gopt/setitem_gopt.cc | 1 - src/pybind/py_var_tracer.cc | 9 ++++++--- 10 files changed, 51 insertions(+), 23 deletions(-) diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 5f725476..2345a864 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -3,7 +3,7 @@ # Maintainers: # Dun Liang . # Wenyang Zhou <576825820@qq.com> -# +# Guoye Yang <498731903@qq.com> # # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. diff --git a/python/jittor/test/test_memory_profiler.py b/python/jittor/test/test_memory_profiler.py index 02085a93..3eb59b5d 100644 --- a/python/jittor/test/test_memory_profiler.py +++ b/python/jittor/test/test_memory_profiler.py @@ -1,9 +1,9 @@ # *************************************************************** -# Copyright (c) 2020 Jittor. Authors: -# Guowei Yang <471184555@qq.com> -# Meng-Hao Guo +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> # Dun Liang . -# All Rights Reserved. +# # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** @@ -53,7 +53,7 @@ class TestMemoryProfiler(unittest.TestCase): jt.seed(seed) @unittest.skipIf(not jt.has_cuda, "Cuda not found") - @jt.flag_scope(use_cuda=1, use_stat_allocator=1, trace_py_var=2) + @jt.flag_scope(use_cuda=1, use_stat_allocator=1, trace_py_var=3, profile_memory_enable=1) def test_resnet(self): self.setup_seed(1) loss_list=[] @@ -85,13 +85,13 @@ class TestMemoryProfiler(unittest.TestCase): _, out = jt.get_max_memory_treemap() out_ = out.split('\n') assert(out_[0] == 'root()') - assert(out_[3] == ' ├─mnist_net(MnistNet)') - assert(out_[7] == ' | └─model(ResNet)') + assert(out_[3] == ' ├─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:194(_run_module_as_main)') + assert(out_[7] == ' | └─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:87(_run_code)') _, out = jt.get_max_memory_treemap(build_by=1) out_ = out.split('\n') assert(out_[0] == 'root()') - assert(out_[4] == ' ├─mnist_net(MnistNet)') - assert(out_[8] == ' | └─model(ResNet)') + assert(out_[4] == ' ├─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:194(_run_module_as_main)') + assert(out_[8] == ' | └─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:87(_run_code)') if __name__ == "__main__": unittest.main() diff --git a/src/executor.cc b/src/executor.cc index b80905ae..3f1b87f4 100644 --- a/src/executor.cc +++ b/src/executor.cc @@ -1,6 +1,9 @@ // *************************************************************** // Copyright (c) 2021 Jittor. All Rights Reserved. -// Maintainers: Dun Liang . +// Maintainers: +// Dun Liang . +// Guoye Yang <498731903@qq.com> +// // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. 
// *************************************************************** @@ -29,6 +32,7 @@ namespace jittor { Executor exe; extern MemoryProfiler memory_profiler; +DECLARE_FLAG(int, profile_memory_enable); // from fetch_op.cc extern list fetcher_to_free; @@ -424,7 +428,8 @@ void Executor::run_sync(vector vars, bool device_sync) { for (auto* var : op->outputs()) { var->alloc(allocator); } - memory_profiler.check(); + if (PREDICT_BRANCH_NOT_TAKEN(profile_memory_enable)) + memory_profiler.check(); LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs(); op->do_prepare(jkl); bool is_cuda = op->flags.get(NodeFlags::_cuda); diff --git a/src/executor.h b/src/executor.h index cc5d4eb2..5c924e98 100644 --- a/src/executor.h +++ b/src/executor.h @@ -1,6 +1,9 @@ // *************************************************************** -// Copyright (c) 2021 Jittor. All Rights Reserved. -// Maintainers: Dun Liang . +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// Guoye Yang <498731903@qq.com> +// // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** diff --git a/src/mem/allocator/temp_allocator.cc b/src/mem/allocator/temp_allocator.cc index 88c1398c..b6b2c5a9 100644 --- a/src/mem/allocator/temp_allocator.cc +++ b/src/mem/allocator/temp_allocator.cc @@ -1,5 +1,5 @@ // *************************************************************** -// Copyright (c) 2020 Jittor. All Rights Reserved. +// Copyright (c) 2021 Jittor. All Rights Reserved. // Maintainers: // Guoye Yang <498731903@qq.com> // Dun Liang . diff --git a/src/mem/allocator/temp_allocator.h b/src/mem/allocator/temp_allocator.h index 0402e421..50a0969b 100644 --- a/src/mem/allocator/temp_allocator.h +++ b/src/mem/allocator/temp_allocator.h @@ -1,5 +1,5 @@ // *************************************************************** -// Copyright (c) 2020 Jittor. All Rights Reserved. +// Copyright (c) 2021 Jittor. All Rights Reserved. // Maintainers: // Guoye Yang <498731903@qq.com> // Dun Liang . diff --git a/src/memory_profiler.cc b/src/memory_profiler.cc index 27dbcd8a..572ff213 100644 --- a/src/memory_profiler.cc +++ b/src/memory_profiler.cc @@ -1,3 +1,12 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** #include "memory_profiler.h" #include "graph.h" #include "var_holder.h" @@ -19,7 +28,7 @@ struct FloatOutput_ { string suffix; int p=4; }; -std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) { +inline std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) { int w = 8; os << std::setw(w-2-o.suffix.size()); os << std::setprecision(o.p); @@ -47,6 +56,7 @@ void MemoryProfiler::clear() { } std::pair MemoryProfiler::get_memory_info() { + ASSERT(profile_memory_enable == 1); size_t used = 0; size_t unused = 0; //TODO add mssfrl allocator @@ -58,6 +68,7 @@ std::pair MemoryProfiler::get_memory_info() { } void MemoryProfiler::check() { + ASSERT(profile_memory_enable == 1); std::pair mem_info = get_memory_info(); if (mem_info.first > max_used_memory_size) { max_used_memory_size = mem_info.first; @@ -99,6 +110,7 @@ bool MemoryProfiler::cmp(const std::pair>, size_ } void MemoryProfiler::display_max_memory_info() { + ASSERT(profile_memory_enable == 1); Log log("", 'i', 0); std::sort(max_live_vars.begin(), max_live_vars.end(), cmp); log << "\n=====display_max_memory_info=====\n"; @@ -117,10 +129,12 @@ void MemoryProfiler::display_max_memory_info() { } void display_max_memory_info() { + ASSERT(profile_memory_enable == 1); memory_profiler.display_max_memory_info(); } string MemoryProfiler::get_max_memory_info() { + ASSERT(profile_memory_enable == 1); std::stringstream out; string div1 = "[!@#div1!@#]"; string div2 = "[!@#div2!@#]"; @@ -142,6 +156,7 @@ string MemoryProfiler::get_max_memory_info() { } string get_max_memory_info() { + ASSERT(profile_memory_enable == 1); return memory_profiler.get_max_memory_info(); } diff --git a/src/memory_profiler.h b/src/memory_profiler.h index 7a0ce12b..be1daac2 100644 --- a/src/memory_profiler.h +++ b/src/memory_profiler.h @@ -1,6 +1,9 @@ // *************************************************************** -// Copyright (c) 2020 Jittor. All Rights Reserved. -// Authors: Dun Liang . +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** diff --git a/src/opt/gopt/setitem_gopt.cc b/src/opt/gopt/setitem_gopt.cc index 16867b5a..c777d2a0 100644 --- a/src/opt/gopt/setitem_gopt.cc +++ b/src/opt/gopt/setitem_gopt.cc @@ -48,7 +48,6 @@ static void setitem_inplace(SetitemOp* op) { } auto output = op->outputs().front(); output->share_with(input); - return; auto data = op->input(1); // if setitem requires type conversion, don't inplace diff --git a/src/pybind/py_var_tracer.cc b/src/pybind/py_var_tracer.cc index 36fe8352..e46b2afe 100644 --- a/src/pybind/py_var_tracer.cc +++ b/src/pybind/py_var_tracer.cc @@ -1,6 +1,9 @@ // *************************************************************** -// Copyright (c) 2021 Jittor. All Rights Reserved. -// Maintainers: Dun Liang . +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// Guoye Yang <498731903@qq.com> +// // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. 
// *************************************************************** @@ -186,7 +189,7 @@ void TraceData::record_node(Node* node, bool record_stack) { NodeData data; data.id = node_data_cnt++; id_map[node] = data.id; - if (trace_py_var) { + if (!node->is_var() || trace_py_var>=3) { if (record_stack) { if (trace_grad_op) { auto iter = trace_data.id_map.find(trace_grad_op); From f50faa5048fb6f00e958a3cbee25fd5f04eff246 Mon Sep 17 00:00:00 2001 From: guoye yang <498731903@qq.com> Date: Thu, 14 Jan 2021 14:52:55 +0800 Subject: [PATCH 8/8] fix --- python/jittor/test/test_memory_profiler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/jittor/test/test_memory_profiler.py b/python/jittor/test/test_memory_profiler.py index 3eb59b5d..3ad49127 100644 --- a/python/jittor/test/test_memory_profiler.py +++ b/python/jittor/test/test_memory_profiler.py @@ -85,13 +85,13 @@ class TestMemoryProfiler(unittest.TestCase): _, out = jt.get_max_memory_treemap() out_ = out.split('\n') assert(out_[0] == 'root()') - assert(out_[3] == ' ├─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:194(_run_module_as_main)') - assert(out_[7] == ' | └─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:87(_run_code)') + assert(out_[3].endswith('(_run_module_as_main)')) + assert(out_[7].endswith('(_run_code)')) _, out = jt.get_max_memory_treemap(build_by=1) out_ = out.split('\n') assert(out_[0] == 'root()') - assert(out_[4] == ' ├─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:194(_run_module_as_main)') - assert(out_[8] == ' | └─/home/cxjyxx_me/miniconda3/lib/python3.8/runpy.py:87(_run_code)') + assert(out_[4].endswith('(_run_module_as_main)')) + assert(out_[8].endswith('(_run_code)')) if __name__ == "__main__": unittest.main()
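
Usage sketch (illustrative, not part of the patch series; the toy model, data,
and loss below are placeholders): the schedulers from PATCH 1/8 follow the
step()-per-epoch convention of torch.optim.lr_scheduler, and the profiler from
PATCH 2/8 through 7/8 is exposed to Python as jt.display_max_memory_info(),
jt.get_max_memory_info() and jt.get_max_memory_treemap(). Because PATCH 7/8
gates the profiler behind the profile_memory_enable flag (and stack recording
behind trace_py_var), the profiling calls stay inside the flag scope:

    import jittor as jt
    from jittor import nn
    from jittor.lr_scheduler import CosineAnnealingLR

    net = nn.Linear(10, 2)                            # placeholder model
    opt = nn.SGD(net.parameters(), lr=0.1, momentum=0.9)
    scheduler = CosineAnnealingLR(opt, T_max=10)

    with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
        for epoch in range(10):
            x = jt.random((16, 10))
            loss = (net(x) ** 2).mean()               # dummy loss
            opt.step(loss)
            scheduler.step()                          # anneal opt.lr along the cosine curve
        jt.sync_all(True)
        jt.display_max_memory_info()                  # sorted table of live vars at the peak
        tree, text = jt.get_max_memory_treemap()      # call tree built from the recorded stacks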