mirror of https://github.com/Jittor/Jittor

commit 6fafe44810
Merge branch 'master' of https://github.com/Jittor/jittor
@@ -87,7 +87,7 @@ void CubArgReduceOp::jit_run() {
         num_segments *= x->shape[i];
     }
     size_t allocation_dout;
-    cub::KeyValuePair<int, Tx> *d_out = (cub::KeyValuePair<int, Tx> *)exe.allocator->alloc(sizeof(cub::KeyValuePair<int, Tx>) * num_segments, allocation_dout);
+    cub::KeyValuePair<int, Tx> *d_out = (cub::KeyValuePair<int, Tx> *)exe.temp_allocator->alloc(sizeof(cub::KeyValuePair<int, Tx>) * num_segments, allocation_dout);

     // Determine temporary device storage requirements
     void *d_temp_storage = NULL;
@@ -96,7 +96,7 @@ void CubArgReduceOp::jit_run() {
         xp, d_out, num_segments, offsetsp, offsetsp + 1);
     // Allocate temporary storage
     size_t allocation;
-    d_temp_storage = exe.allocator->alloc(temp_storage_bytes, allocation);
+    d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, allocation);
     // Run sorting operation
     cub::DeviceSegmentedReduce::@FUNC@@(d_temp_storage, temp_storage_bytes,
         xp, d_out, num_segments, offsetsp, offsetsp + 1);
@@ -105,8 +105,8 @@ void CubArgReduceOp::jit_run() {
     auto* __restrict__ y_keyp = y_key->ptr<Tx>();
     split<<<max(1,num_segments/1024),1024>>>(d_out, y_keyp, yp, num_segments);

-    exe.allocator->free(d_temp_storage, temp_storage_bytes, allocation);
-    exe.allocator->free(d_out, sizeof(cub::KeyValuePair<int, Tx>) * num_segments, allocation_dout);
+    exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, allocation);
+    exe.temp_allocator->free(d_out, sizeof(cub::KeyValuePair<int, Tx>) * num_segments, allocation_dout);
 }
 #endif // JIT_cuda
 #endif // JIT

@@ -85,12 +85,12 @@ void CubArgsortOp::jit_run() {
         num_items, num_segments, offsetsp, offsetsp + 1);
     // Allocate temporary storage
     size_t allocation;
-    d_temp_storage = exe.allocator->alloc(temp_storage_bytes, allocation);
+    d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, allocation);
     // Run sorting operation
     cub::DeviceSegmentedRadixSort::@FUNC@@(d_temp_storage, temp_storage_bytes,
         xp, y_keyp, indexesp, yp,
         num_items, num_segments, offsetsp, offsetsp + 1);
-    exe.allocator->free(d_temp_storage, temp_storage_bytes, allocation);
+    exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, allocation);
 }
 #endif // JIT_cuda
 #endif // JIT

@@ -82,7 +82,7 @@ void CubWhereOp::jit_run(){
     int N = cond->num;
     size_t temp_storage_bytes=0;
     size_t num_nonzeros_allocation;
-    auto num_nonzeros = exe.allocator->alloc(sizeof(To), num_nonzeros_allocation);
+    auto num_nonzeros = exe.temp_allocator->alloc(sizeof(To), num_nonzeros_allocation);

     size_t temp_storage_allocation;
     void* temp_storage;
@@ -93,9 +93,9 @@ void CubWhereOp::jit_run(){
     cub::TransformInputIterator<bool, NonZeroOp<Ti>, Ti*> itr(cond->ptr<Ti>(), NonZeroOp<Ti>());
     temp_storage_bytes = 0;
     checkCudaErrors(cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, out_temp, (To*)num_nonzeros, N));
-    temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation);
+    temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, temp_storage_allocation);
     checkCudaErrors(cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr, out_temp, (To*)num_nonzeros, N));
-    exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);
+    exe.temp_allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);

     To num_nonzeros_h;
     cudaMemcpy(&num_nonzeros_h, num_nonzeros, sizeof(To), cudaMemcpyDeviceToHost);
@@ -110,7 +110,7 @@ void CubWhereOp::jit_run(){
         @for(i, 0, NDIM, 1, , cond->shape[@i], outs[@i]->ptr<To>())
     );
     }
-    exe.allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation);
+    exe.temp_allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation);

 }
 #endif

@@ -203,7 +203,7 @@ void CudnnConvBackwardWOp::jit_run() {
         if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
     }
     size_t allocation;
-    void* ws = exe.allocator->alloc(max_ws_size, allocation);
+    void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
     checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx(
         handle_,
         cudnnIdesc, x->ptr<Tx>(),
@@ -215,7 +215,7 @@ void CudnnConvBackwardWOp::jit_run() {
         perf_results,
         ws,
         max_ws_size));
-    exe.allocator->free(ws, max_ws_size, allocation);
+    exe.temp_allocator->free(ws, max_ws_size, allocation);
     } else {
     checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
         handle_,
@@ -250,7 +250,7 @@ void CudnnConvBackwardWOp::jit_run() {
         cudnnFdesc, algo, &workSpaceSize));
     size_t allocation;
     if (workSpaceSize > 0) {
-        workSpace = exe.allocator->alloc(workSpaceSize, allocation);
+        workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
     }
     float alpha=1, beta=0;
     checkCudaErrors(cudnnConvolutionBackwardFilter(
@@ -265,7 +265,7 @@ void CudnnConvBackwardWOp::jit_run() {
         cudnnFdesc, w->ptr<Tw>())
     );
     if (workSpace)
-        exe.allocator->free(workSpace, workSpaceSize, allocation);
+        exe.temp_allocator->free(workSpace, workSpaceSize, allocation);

     checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
     checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));

@@ -204,7 +204,7 @@ void CudnnConvBackwardXOp::jit_run() {
         if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
     }
     size_t allocation;
-    void* ws = exe.allocator->alloc(max_ws_size, allocation);
+    void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
     checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx(
         handle_,
         cudnnFdesc, w->ptr<Tw>(),
@@ -216,7 +216,7 @@ void CudnnConvBackwardXOp::jit_run() {
         perf_results,
         ws,
         max_ws_size));
-    exe.allocator->free(ws, max_ws_size, allocation);
+    exe.temp_allocator->free(ws, max_ws_size, allocation);
     } else {
     checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7(
         handle_,
@@ -251,7 +251,7 @@ void CudnnConvBackwardXOp::jit_run() {
         cudnnIdesc, algo, &workSpaceSize));
     size_t allocation;
     if (workSpaceSize > 0) {
-        workSpace = exe.allocator->alloc(workSpaceSize, allocation);
+        workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
     }
     float alpha=1, beta=0;
     checkCudaErrors(cudnnConvolutionBackwardData(
@@ -266,7 +266,7 @@ void CudnnConvBackwardXOp::jit_run() {
         cudnnIdesc, x->ptr<Tx>())
     );
     if (workSpace)
-        exe.allocator->free(workSpace, workSpaceSize, allocation);
+        exe.temp_allocator->free(workSpace, workSpaceSize, allocation);

     checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
     checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));

@@ -208,7 +208,7 @@ void CudnnConvOp::jit_run() {
         if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
     }
     size_t allocation;
-    void* ws = exe.allocator->alloc(max_ws_size, allocation);
+    void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
     checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx(
         handle_,
         cudnnIdesc, x->ptr<Tx>(),
@@ -220,7 +220,7 @@ void CudnnConvOp::jit_run() {
         perf_results,
         ws,
         max_ws_size));
-    exe.allocator->free(ws, max_ws_size, allocation);
+    exe.temp_allocator->free(ws, max_ws_size, allocation);
     } else {
     checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7(
         handle_,
@@ -255,7 +255,7 @@ void CudnnConvOp::jit_run() {
         cudnnOdesc, algo, &workSpaceSize));
     size_t allocation;
     if (workSpaceSize > 0) {
-        workSpace = exe.allocator->alloc(workSpaceSize, allocation);
+        workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
     }
     float alpha=1, beta=0;
     checkCudaErrors(cudnnConvolutionForward(
@@ -270,7 +270,7 @@ void CudnnConvOp::jit_run() {
         cudnnOdesc, y->ptr<Ty>())
     );
     if (workSpace)
-        exe.allocator->free(workSpace, workSpaceSize, allocation);
+        exe.temp_allocator->free(workSpace, workSpaceSize, allocation);

     checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
     checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));

@@ -128,7 +128,8 @@ class Dataset(object):
             if self.stop_grad else jt.array(x)
         if isinstance(batch, np.ndarray):
             return to_jt(batch)
-        assert isinstance(batch, Sequence)
+        if not isinstance(batch, (list, tuple)):
+            return batch
         new_batch = []
         for a in batch:
             if isinstance(a, np.ndarray) or \
@@ -84,4 +84,58 @@ class ReduceLROnPlateau(object):
             save = self.threshold + 1.0
             return a > b * save
         else:
             return a > b + self.threshold
+
+class CosineAnnealingLR(object):
+    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+        self.T_max = T_max
+        self.eta_min = eta_min
+        self.optimizer = optimizer
+        self.last_epoch = last_epoch
+        self.base_lr = optimizer.lr
+        #TODO: set last_epoch is not ready
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            return self.base_lr
+        now_lr = self.optimizer.lr
+        if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
+            return (now_lr + (self.base_lr - self.eta_min) *
+                    (1 - math.cos(math.pi / self.T_max)) / 2)
+        return ((1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
+                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
+                (now_lr - self.eta_min) + self.eta_min)
+
+    def step(self):
+        self.last_epoch += 1
+        self.update_lr(self.get_lr())
+
+    def update_lr(self, new_lr):
+        self.optimizer.lr = new_lr
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            if param_group.get("lr") != None:
+                param_group["lr"] = new_lr
+
+class MultiStepLR(object):
+    def __init__(self, optimizer, milestones=[], gamma=0.1, last_epoch=-1):
+        self.optimizer = optimizer
+        self.milestones = milestones
+        self.gamma = gamma
+        self.last_epoch = last_epoch
+        #TODO: set last_epoch is not ready
+
+    def get_lr(self):
+        now_lr = self.optimizer.lr
+        if (self.last_epoch in self.milestones):
+            now_lr *= self.gamma
+        return now_lr
+
+    def step(self):
+        self.last_epoch += 1
+        self.update_lr(self.get_lr())
+
+    def update_lr(self, new_lr):
+        self.optimizer.lr = new_lr
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            if param_group.get("lr") != None:
+                param_group["lr"] = new_lr
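For reference, a minimal usage sketch of the two schedulers added above (assuming they live in jittor.lr_scheduler alongside ReduceLROnPlateau; the model, loss, and step counts are placeholders):

    import jittor as jt
    from jittor import nn
    from jittor.lr_scheduler import CosineAnnealingLR, MultiStepLR

    model = nn.Linear(10, 2)
    opt = nn.SGD(model.parameters(), lr=0.1)
    scheduler = CosineAnnealingLR(opt, T_max=50, eta_min=1e-5)
    # scheduler = MultiStepLR(opt, milestones=[30, 80], gamma=0.1)
    for epoch in range(100):
        x = jt.random((4, 10))
        loss = model(x).sqr().mean()
        opt.step(loss)
        scheduler.step()  # bumps last_epoch, then writes the new lr into opt and its param_groups
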
@@ -3,7 +3,7 @@
 # Maintainers:
 #     Dun Liang <randonlang@gmail.com>.
 #     Wenyang Zhou <576825820@qq.com>
-#
+#     Guoye Yang <498731903@qq.com>
 #
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
@@ -77,7 +77,7 @@ def repeat(x, *shape):
         x = x.broadcast(x_shape)
     elif len_x_shape > len_shape:
         rep_shape = (len_x_shape - len_shape) * [1] + shape
-
+    #TODO: if input.shape[i]=1, no add [1]
     reshape_shape = []
     broadcast_shape = []
     for x_s,r_s in zip(x_shape,rep_shape):
@@ -722,6 +722,123 @@ def triu_(x,diagonal=0):

 jt.Var.triu_ = triu_

+def print_tree(now, max_memory_size, prefix1, prefix2, build_by):
+    def format_size(s):
+        if (s < 1024):
+            s = str(s)
+            return s + ' B'
+
+        if (s < 1024*1024):
+            s = format(s/1024, '.2f')
+            return s + ' KB'
+
+        if (s < 1024*1024*1024):
+            s = format(s/1024/1024, '.2f')
+            return s + ' MB'
+
+        s = format(s/1024/1024/1024, '.2f')
+        return s + ' GB'
+
+    out = ''
+    tab = '   '
+    out += prefix1+now['name']+'('+now['type']+')\n'
+    out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%]\n'
+    if (build_by == 0):
+        for p in now['path']:
+            out += prefix2+p+'\n'
+    else:
+        out += prefix2+now['path'] + '\n'
+    if (len(now['children']) > 0):
+        out += prefix2 + tab + '| ' + '\n'
+    else:
+        out += prefix2 + '\n'
+    for i in range(len(now['children'])):
+        c = now['children'][i]
+        if i < len(now['children']) - 1:
+            prefix1_ = prefix2 + tab + '├─'
+            prefix2_ = prefix2 + tab + '| '
+        else:
+            prefix1_ = prefix2 + tab + '└─'
+            prefix2_ = prefix2 + tab + '  '
+        out += print_tree(c, max_memory_size, prefix1_, prefix2_, build_by)
+    return out
+
+def get_max_memory_treemap(build_by=0, do_print=True):
+    div1 = "[!@#div1!@#]"
+    div2 = "[!@#div2!@#]"
+    div3 = "[!@#div3!@#]"
+    info = jt.get_max_memory_info()
+
+    vars = []
+    vars_ = info.split(div1)
+    max_memory_size = int(vars_[0])
+    vars_ = vars_[1:]
+    for v_ in vars_:
+        v__ = v_.split(div2)
+        var = {'size':int(v__[1]), 'stack':[]}
+        v__ = v__[2:-1]
+        for s_ in v__:
+            s__ = s_.split(div3)
+            s = {'path':s__[0], 'name':s__[1], 'type':s__[2]}
+            var['stack'].append(s)
+        vars.append(var)
+    if (build_by == 0): # build tree by name
+        tree = {'name':'root', "children":[], 'size':0, 'path':[], 'type':''}
+
+        def find_child(now, key):
+            for c in now['children']:
+                if (c['name'] == key):
+                    return c
+            return None
+        for v in vars:
+            now = tree
+            now['size'] += v['size']
+            for s in v['stack']:
+                ch = find_child(now, s['name'])
+                if (ch is not None):
+                    if (not s['path'] in ch['path']):
+                        ch['path'].append(s['path'])
+                    assert(ch['type']==s['type'])
+                    now = ch
+                    now['size'] += v['size']
+                else:
+                    now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':[s['path']], 'type':s['type']}
+                    now['children'].append(now_)
+                    now = now_
+    elif (build_by == 1): # build tree by path
+        tree = {'name':'root', "children":[], 'size':0, 'path':'_root_', 'type':''}
+
+        def find_child(now, key):
+            for c in now['children']:
+                if (c['path'] == key):
+                    return c
+            return None
+        for v in vars:
+            now = tree
+            now['size'] += v['size']
+            for s in v['stack']:
+                ch = find_child(now, s['path'])
+                if (ch is not None):
+                    now = ch
+                    now['size'] += v['size']
+                else:
+                    now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':s['path'], 'type':s['type']}
+                    now['children'].append(now_)
+                    now = now_
+    else:
+        assert(False)
+
+    def sort_tree(now):
+        def takeSize(elem):
+            return elem['size']
+        now['children'].sort(key=takeSize, reverse=True)
+        for c in now['children']:
+            sort_tree(c)
+    sort_tree(tree)
+    out = print_tree(tree, max_memory_size, '', '', build_by)
+    if (do_print):
+        print(out)
+    return tree, out
+
 def python_pass_warper(mod_func, args, kw):
     import importlib
     mod, func = mod_func.rsplit(".", 1)

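A short usage sketch of the treemap helpers added above, following the test added later in this commit (the Linear model and input shapes are placeholders):

    import jittor as jt
    from jittor import nn

    net = nn.Linear(10, 10)  # stand-in for a real model
    with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
        y = net(jt.random((8, 10)))
        jt.sync_all(True)
        # build_by=0 groups the peak-memory snapshot by stack name,
        # build_by=1 groups it by source path; both print a tree
        # whose root line reads like: root() [... ; 100.00%]
        tree, out = jt.get_max_memory_treemap()
        tree, out = jt.get_max_memory_treemap(build_by=1)
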
@@ -12,8 +12,8 @@
 import jittor as jt
 from jittor import nn

-__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
-           'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
+__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet26', 'Resnet38', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2',
+           'resnet18', 'resnet34', 'resnet26', 'resnet38', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']

 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
@@ -174,6 +174,14 @@ def Resnet50(pretrained=False, **kwargs):

 resnet50 = Resnet50

+def Resnet38(**kwargs):
+    return _resnet(Bottleneck, [2, 3, 5, 2], **kwargs)
+resnet38 = Resnet38
+
+def Resnet26(**kwargs):
+    return _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
+resnet26 = Resnet26
+
 def Resnet101(pretrained=False, **kwargs):
     """
     ResNet-101 model architecture.

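The two new depths reuse the existing Bottleneck recipe with smaller layer counts. A quick sanity check (batch size and input resolution are illustrative):

    import jittor as jt
    from jittor.models import resnet

    net = resnet.Resnet26()   # Bottleneck blocks, layers [1, 2, 4, 1]
    y = net(jt.random((2, 3, 224, 224)))
    print(y.shape)            # [2, 1000] with the default num_classes
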
@@ -0,0 +1,97 @@
+# ***************************************************************
+# Copyright (c) 2021 Jittor. All Rights Reserved.
+# Maintainers:
+#     Guoye Yang <498731903@qq.com>
+#     Dun Liang <randonlang@gmail.com>.
+#
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import jittor as jt
+from jittor import nn, Module
+from jittor.models import resnet
+import numpy as np
+import sys, os
+import random
+import math
+import unittest
+from jittor.test.test_reorder_tuner import simple_parser
+from jittor.test.test_log import find_log_with_re
+from jittor.dataset.mnist import MNIST
+import jittor.transform as trans
+import time
+
+skip_this_test = False
+
+class MnistNet(Module):
+    def __init__(self):
+        self.model = resnet.Resnet18()
+        self.layer = nn.Linear(1000, 10)
+    def execute(self, x):
+        x = self.model(x)
+        x = self.layer(x)
+        return x
+
+@unittest.skipIf(skip_this_test, "skip_this_test")
+class TestMemoryProfiler(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        # hyper-parameters
+        self.batch_size = 100
+        self.weight_decay = 0.0001
+        self.momentum = 0.9
+        self.learning_rate = 0.1
+        # mnist dataset
+        self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
+            .set_attrs(batch_size=self.batch_size, shuffle=True)
+        self.train_loader.num_workers = 4
+
+    # setup random seed
+    def setup_seed(self, seed):
+        np.random.seed(seed)
+        random.seed(seed)
+        jt.seed(seed)
+
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1, use_stat_allocator=1, trace_py_var=3, profile_memory_enable=1)
+    def test_resnet(self):
+        self.setup_seed(1)
+        loss_list = []
+        acc_list = []
+        mnist_net = MnistNet()
+        global prev
+        prev = time.time()
+        SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
+
+        iters = 10
+        for batch_idx, (data, target) in enumerate(self.train_loader):
+            if (batch_idx > iters):
+                break
+            jt.display_memory_info()
+            output = mnist_net(data)
+            loss = nn.cross_entropy_loss(output, target)
+            SGD.step(loss)
+            def callback(batch_idx, loss, output, target):
+                global prev
+                pred = np.argmax(output, axis=1)
+                acc = np.mean(target==pred)
+                loss_list.append(loss[0])
+                acc_list.append(acc)
+                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
+                    .format(0, batch_idx, iters, 1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
+            jt.fetch(batch_idx, loss, output, target, callback)
+        jt.sync_all(True)
+        jt.display_max_memory_info()
+        _, out = jt.get_max_memory_treemap()
+        out_ = out.split('\n')
+        assert(out_[0] == 'root()')
+        assert(out_[3].endswith('(_run_module_as_main)'))
+        assert(out_[7].endswith('(_run_code)'))
+        _, out = jt.get_max_memory_treemap(build_by=1)
+        out_ = out.split('\n')
+        assert(out_[0] == 'root()')
+        assert(out_[4].endswith('(_run_module_as_main)'))
+        assert(out_[8].endswith('(_run_code)'))
+
+if __name__ == "__main__":
+    unittest.main()
@@ -1,6 +1,9 @@
 // ***************************************************************
 // Copyright (c) 2021 Jittor. All Rights Reserved.
-// Maintainers: Dun Liang <randonlang@gmail.com>.
+// Maintainers:
+//     Dun Liang <randonlang@gmail.com>.
+//     Guoye Yang <498731903@qq.com>
+//
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
@@ -22,11 +25,14 @@
 #include "fuser.h"
 #include "profiler/profiler_guard.h"
 #include "parallel_compiler.h"
+#include "memory_profiler.h"
 #include "misc/nan_checker.h"

 namespace jittor {

 Executor exe;
+extern MemoryProfiler memory_profiler;
+DECLARE_FLAG(int, profile_memory_enable);

 // from fetch_op.cc
 extern list<VarPtr> fetcher_to_free;
@@ -90,7 +96,9 @@ void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, i

 void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     auto allocator = get_allocator();
+    auto temp_allocator = get_allocator(true);
     this->allocator = allocator;
+    this->temp_allocator = temp_allocator;
     // bfs find all ops need to run
     int op_num = 0;
     vector<Node*> bfs_q;
@@ -420,6 +428,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
         for (auto* var : op->outputs()) {
             var->alloc(allocator);
         }
+        if (PREDICT_BRANCH_NOT_TAKEN(profile_memory_enable))
+            memory_profiler.check();
         LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
         op->do_prepare(jkl);
         bool is_cuda = op->flags.get(NodeFlags::_cuda);

@@ -1,6 +1,9 @@
 // ***************************************************************
-// Copyright (c) 2021 Jittor. All Rights Reserved.
-// Maintainers: Dun Liang <randonlang@gmail.com>.
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Dun Liang <randonlang@gmail.com>.
+//     Guoye Yang <498731903@qq.com>
+//
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
@@ -16,6 +19,7 @@ namespace jittor {

 struct Executor {
     Allocator* allocator;
+    Allocator* temp_allocator;
     bool last_is_cuda = false;
     void run_sync(vector<Var*> vars, bool device_sync);
 };

@@ -15,6 +15,7 @@
 #include "mem/allocator/stat_allocator.h"
 #include "mem/allocator/sfrl_allocator.h"
 #include "mem/allocator/nfef_allocator.h"
+#include "mem/allocator/temp_allocator.h"

 namespace jittor {

@@ -46,7 +47,7 @@ Allocator* setup_allocator(Allocator* underlying) {

 Allocator* cpu_allocator = setup_allocator<SFRLAllocator>(&aligned_allocator);

-Allocator* get_allocator() {
+Allocator* get_allocator(bool temp_allocator) {
     Allocator* allocator = nullptr;
 #ifdef HAS_CUDA
     if (use_cuda && !allocator) {
@@ -72,7 +73,10 @@ Allocator* get_allocator() {
         allocator = setup_allocator<NFEFAllocator>(allocator);
         return allocator;
     }
-    if (use_sfrl_allocator) {
+    if (temp_allocator && use_temp_allocator) {
+        LOGvv << "Using temp_allocator";
+        allocator = setup_allocator<TempAllocator>(allocator);
+    } else if (use_sfrl_allocator) {
         LOGvv << "Using sfrl_allocator";
         allocator = setup_allocator<SFRLAllocator>(allocator);
     }

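Note that get_allocator(true) only wraps the chain in a TempAllocator when the use_temp_allocator flag (defined in temp_allocator.cc below) is on; otherwise temp allocations fall through to the sfrl allocator. Assuming the flag is exposed to Python like other DEFINE_FLAGs, it can be toggled as:

    import jittor as jt
    jt.flags.use_temp_allocator = 0  # assumed flag name; fall back to sfrl for temp buffers
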
@@ -49,7 +49,7 @@ struct Allocation {
 };

 extern Allocator* cpu_allocator;
-Allocator* get_allocator();
+Allocator* get_allocator(bool temp_allocator=false);
 // @pyjt(gc)
 void gc_all();

@@ -0,0 +1,116 @@
+// ***************************************************************
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Guoye Yang <498731903@qq.com>
+//     Dun Liang <randonlang@gmail.com>.
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+
+#include "mem/allocator/temp_allocator.h"
+
+namespace jittor {
+
+DEFINE_FLAG(int, use_temp_allocator, 1, "Enable temp allocator");
+
+TempAllocator::~TempAllocator() {
+    while (!cached_blocks.empty()) {
+        auto it = cached_blocks.begin();
+        TempCachingBlock* block = it->second;
+        cached_blocks.erase(it);
+        delete block;
+    }
+}
+
+const char* TempAllocator::name() const {return "temp";}
+
+void TempAllocator::setup(Allocator* underlying) {
+    this->underlying = underlying;
+}
+
+size_t TempAllocator::align_size(size_t size) {
+    return (size + ALIGN_SIZE - 1) / ALIGN_SIZE * ALIGN_SIZE;
+}
+
+unsigned long long TempAllocator::get_key(TempCachingBlock* block) {
+    return ((unsigned long long)block->size) * ID_LIMIT + block->id;
+}
+
+void* TempAllocator::alloc(size_t size, size_t& allocation) {
+    size = align_size(size);
+
+    auto temp = TempCachingBlock(size);
+    auto it = cached_blocks.lower_bound(get_key(&temp));
+    TempCachingBlock* block = nullptr;
+    if (it != cached_blocks.end()) {
+        block = it->second;
+        cached_blocks.erase(it);
+        unused_memory -= block->size;
+    } else {
+        void* ptr = underlying->alloc(size, allocation);
+        block = new TempCachingBlock(size, ptr);
+        size_t id;
+        if (!block_ids.empty()) {
+            id = block_ids.back();
+            block_ids.pop_back();
+        } else {
+            ASSERT(tot_block_id < ID_LIMIT - 1) << "block id limit exceeded.";
+            id = ++tot_block_id;
+        }
+        block->id = id;
+    }
+
+    used_memory += block->size;
+    occupied_id_mapper[block->id] = block;
+    allocation = block->id;
+    return block->memory_ptr;
+}
+
+void TempAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
+    size = align_size(size);
+    ASSERT(occupied_id_mapper[allocation] != nullptr) << "allocation not found";
+    TempCachingBlock* block = occupied_id_mapper[allocation];
+    occupied_id_mapper[allocation] = nullptr;
+    used_memory -= block->size;
+    unused_memory += block->size;
+    bool can_add = true;
+    if (cached_blocks.size() > cache_blocks_limit-1) {
+        ASSERT(cached_blocks.size() == cache_blocks_limit);
+        auto it = cached_blocks.lower_bound(get_key(block));
+        if (it == cached_blocks.begin()) {
+            can_add = false;
+        } else {
+            --it;
+            TempCachingBlock* block = it->second;
+            underlying->free((void*)block->memory_ptr, block->size, 0);
+            unused_memory -= block->size;
+            block_ids.push_back(block->id);
+            cached_blocks.erase(it);
+            delete block;
+        }
+    }
+    if (can_add) {
+        cached_blocks[get_key(block)] = block;
+    }
+}
+
+void TempAllocator::gc() {
+    while (!cached_blocks.empty()) {
+        auto it = cached_blocks.begin();
+        TempCachingBlock* block = it->second;
+        underlying->free((void*)block->memory_ptr, block->size, 0);
+        unused_memory -= block->size;
+        block_ids.push_back(block->id);
+        cached_blocks.erase(it);
+        delete block;
+    }
+}
+
+bool TempAllocator::share_with(size_t size, size_t allocation) {
+    ASSERT(false);
+    return true;
+}
+
+} // jittor

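The allocator above is a small best-fit cache: cached_blocks is keyed by size * ID_LIMIT + id, so a lower_bound on the requested size finds the smallest cached block that still fits, and when the cache is full on free() it either drops the incoming block (if everything cached is larger) or evicts the largest cached block below it. A self-contained Python sketch of the same policy (illustrative only, not Jittor's API; names are invented):

    import bisect

    class TempCache:
        """Best-fit reuse of freed blocks, mirroring TempAllocator's keying scheme."""
        def __init__(self, limit=2):
            self.limit = limit
            self.keys = []     # sorted (size, id) keys, like the std::map ordering
            self.blocks = {}   # (size, id) -> cached block
            self.next_id = 0

        def alloc(self, size):
            # lower_bound: smallest cached block with block size >= size
            i = bisect.bisect_left(self.keys, (size, 0))
            if i < len(self.keys):
                return self.blocks.pop(self.keys.pop(i))
            self.next_id += 1  # stands in for underlying->alloc plus id assignment
            return {"size": size, "id": self.next_id}

        def free(self, block):
            key = (block["size"], block["id"])
            if len(self.keys) >= self.limit:
                i = bisect.bisect_left(self.keys, key)
                if i == 0:
                    return  # all cached blocks are larger: release this one instead
                self.blocks.pop(self.keys.pop(i - 1))  # evict largest cached block below it
            bisect.insort(self.keys, key)
            self.blocks[key] = block

    cache = TempCache()
    a = cache.alloc(1024)
    cache.free(a)
    b = cache.alloc(512)  # 512 fits in the cached 1024-byte block, so it is reused
    assert a is b

Keeping only a couple of cached blocks is enough here because temp buffers (cub scratch, cudnn workspaces) are allocated and freed in tight pairs, so reuse hits are immediate.
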
@@ -0,0 +1,57 @@
+// ***************************************************************
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Guoye Yang <498731903@qq.com>
+//     Dun Liang <randonlang@gmail.com>.
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "mem/allocator.h"
+
+namespace jittor {
+
+struct TempCachingBlock {
+    size_t size;
+    size_t id;
+    void* memory_ptr;
+
+    TempCachingBlock(size_t size):size(size),id(0) {}
+    TempCachingBlock(size_t size, void* memory_ptr):size(size),id(0), memory_ptr(memory_ptr) {}
+};
+
+struct TempAllocator : Allocator {
+    static const size_t ALIGN_SIZE = 512;
+    static const size_t ID_LIMIT = 1 << 18;
+    Allocator* underlying;
+    size_t cache_blocks_limit, used_memory, unused_memory;
+    std::map<unsigned long long, TempCachingBlock*> cached_blocks;
+    std::vector<size_t> block_ids;
+    size_t tot_block_id;
+    std::unique_ptr<TempCachingBlock*[]> occupied_id_mapper;
+
+    inline TempAllocator(size_t cache_blocks_limit=2) : cache_blocks_limit(cache_blocks_limit), used_memory(0), unused_memory(0), tot_block_id(0), occupied_id_mapper(new TempCachingBlock*[ID_LIMIT]) {
+    }
+    inline TempAllocator(Allocator* underlying, size_t cache_blocks_limit=2) : TempAllocator(cache_blocks_limit) {
+        setup(underlying);
+    }
+    ~TempAllocator();
+
+    size_t align_size(size_t size);
+    unsigned long long get_key(TempCachingBlock* block);
+    void setup(Allocator* underlying);
+    uint64 flags() const override { return underlying->flags(); }
+    const char* name() const override;
+    void* alloc(size_t size, size_t& allocation) override;
+    void free(void* mem_ptr, size_t size, const size_t& allocation) override;
+    // return all cached blocks to the underlying allocator.
+    void gc() override;
+    virtual bool share_with(size_t size, size_t allocation) override;
+};
+
+DECLARE_FLAG(int, use_temp_allocator);
+
+} // jittor

@@ -15,8 +15,10 @@
 #include "misc/cuda_flags.h"
 #include "mem/allocator/sfrl_allocator.h"
 #include "mem/allocator/stat_allocator.h"
+#include "mem/allocator/temp_allocator.h"
 #include "mem/mem_info.h"
 #include "update_queue.h"
+#include "executor.h"

 namespace jittor {

@@ -101,7 +103,13 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
     log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"}
         << "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"}
         << "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n';

+    if (use_temp_allocator) {
+        TempAllocator* temp_allocator = (TempAllocator*)exe.temp_allocator;
+        log << "\nname:" << temp_allocator->name() << "\n";
+        log << "used_memory:" << FloatOutput{(double)temp_allocator->used_memory, " KMG", 1024, "B"} << "\n";
+        log << "unused_memory:" << FloatOutput{(double)temp_allocator->unused_memory, " KMG", 1024, "B"} << "\n";
+    }
     if (dump_var) {
         vector<Node*> queue;
         unordered_set<Node*> visited;

@@ -0,0 +1,163 @@
+// ***************************************************************
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Guoye Yang <498731903@qq.com>
+//     Dun Liang <randonlang@gmail.com>.
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#include "memory_profiler.h"
+#include "graph.h"
+#include "var_holder.h"
+#include "var.h"
+#include "mem/allocator/sfrl_allocator.h"
+#include <iomanip>
+#include <algorithm>
+#include <sys/sysinfo.h>
+#include <sstream>
+#include "pybind/py_var_tracer.h"
+
+namespace jittor {
+
+//TODO: reuse FloatOutput from mem_info.cc
+struct FloatOutput_ {
+    double value;
+    string scale;
+    int base;
+    string suffix;
+    int p=4;
+};
+
+inline std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) {
+    int w = 8;
+    os << std::setw(w-2-o.suffix.size());
+    os << std::setprecision(o.p);
+    uint i=0;
+    double k = o.value;
+    for (; i+1<o.scale.size(); i++) {
+        if (k<o.base) break;
+        k /= o.base;
+    }
+    os << k << o.scale[i];
+    return os << o.suffix;
+}
+
+MemoryProfiler memory_profiler;
+DEFINE_FLAG(int, profile_memory_enable, 0, "Enable memory profiler.");
+
+MemoryProfiler::MemoryProfiler() {
+    clear();
+}
+
+void MemoryProfiler::clear() {
+    allocations.clear();
+    max_memory_size = 0;
+    max_used_memory_size = 0;
+}
+
+std::pair<size_t, size_t> MemoryProfiler::get_memory_info() {
+    ASSERT(profile_memory_enable == 1);
+    size_t used = 0;
+    size_t unused = 0;
+    //TODO: add mssfrl allocator
+    for (auto& a : SFRLAllocator::sfrl_allocators) {
+        used += a->used_memory;
+        unused += a->unused_memory;
+    }
+    return std::make_pair(used, unused);
+}
+
+void MemoryProfiler::check() {
+    ASSERT(profile_memory_enable == 1);
+    std::pair<size_t, size_t> mem_info = get_memory_info();
+    if (mem_info.first > max_used_memory_size) {
+        max_used_memory_size = mem_info.first;
+
+        allocations.clear();
+        size_t memory_size = 0;
+        std::vector<std::pair<std::pair<string, vector<Stack>>, size_t>> live_vars;
+        vector<Node*> queue;
+
+        auto t = ++Node::tflag_count;
+        for (auto& vh : VarHolder::hold_vars)
+            if (vh->var->tflag != t) {
+                vh->var->tflag = t;
+                queue.push_back(vh->var);
+            }
+        bfs_both(queue, [](Node*){return true;});
+        for (Node* node : queue) {
+            if (node->is_var()) {
+                Var* var = (Var*)node;
+                if (var->mem_ptr != nullptr) {
+                    vector<Stack> stacks = get_node_trace(var);
+                    std::stringstream stream;
+                    stream << var;
+                    live_vars.push_back(std::make_pair(std::make_pair(stream.str(), stacks), var->size));
+                    if (!allocations.count(var->mem_ptr)) {
+                        allocations[var->mem_ptr] = 1;
+                        memory_size += var->size;
+                    }
+                }
+            }
+        }
+        max_live_vars = live_vars;
+        max_memory_size = memory_size;
+    }
+}
+
+bool MemoryProfiler::cmp(const std::pair<std::pair<string, vector<Stack>>, size_t>& a, const std::pair<std::pair<string, vector<Stack>>, size_t>& b) {
+    return a.second > b.second;
+}
+
+void MemoryProfiler::display_max_memory_info() {
+    ASSERT(profile_memory_enable == 1);
+    Log log("", 'i', 0);
+    std::sort(max_live_vars.begin(), max_live_vars.end(), cmp);
+    log << "\n=====display_max_memory_info=====\n";
+    log << "max used memory" << FloatOutput_{(double)max_used_memory_size, " KMG", 1024, "B"} << "\n";
+    log << "max var memory" << FloatOutput_{(double)max_memory_size, " KMG", 1024, "B"} << "\n\n";
+    log << "[Size]" << "[Percent]" << "[Var Info]" << "\n";
+    for (int i = 0; i < max_live_vars.size(); ++i) {
+        log << FloatOutput_{(double)max_live_vars[i].second, " KMG", 1024, "B"}
+            << double(max_live_vars[i].second) / max_memory_size * 100 << "%"
+            << max_live_vars[i].first.first
+            << max_live_vars[i].first.second[0].file_path + ":" + std::to_string(max_live_vars[i].first.second[0].lineno)
+            << "\n\n";
+    }
+    log << "=========================\n";
+    log.end();
+}
+
+void display_max_memory_info() {
+    ASSERT(profile_memory_enable == 1);
+    memory_profiler.display_max_memory_info();
+}
+
+string MemoryProfiler::get_max_memory_info() {
+    ASSERT(profile_memory_enable == 1);
+    std::stringstream out;
+    string div1 = "[!@#div1!@#]";
+    string div2 = "[!@#div2!@#]";
+    string div3 = "[!@#div3!@#]";
+
+    std::sort(max_live_vars.begin(), max_live_vars.end(), cmp);
+    out << max_memory_size;
+    for (int i = 0; i < max_live_vars.size(); ++i) {
+        out << div1;
+        out << max_live_vars[i].first.first << div2;
+        out << max_live_vars[i].second << div2;
+        for (int j = 0; j < max_live_vars[i].first.second.size(); ++j) {
+            out << max_live_vars[i].first.second[j].file_path + ":" + std::to_string(max_live_vars[i].first.second[j].lineno) << div3
+                << max_live_vars[i].first.second[j].module_name << div3
+                << max_live_vars[i].first.second[j].module_type << div2;
+        }
+    }
+    return out.str();
+}
+
+string get_max_memory_info() {
+    ASSERT(profile_memory_enable == 1);
+    return memory_profiler.get_max_memory_info();
+}
+
+} // jittor

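As wired in executor.cc above, memory_profiler.check() runs after each op's outputs are allocated whenever profile_memory_enable is set, so from Python the whole pipeline reduces to a few calls (mirroring the new test; the Linear model is a placeholder):

    import jittor as jt
    from jittor import nn

    net = nn.Linear(10, 10)  # stand-in for a real model
    with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
        net(jt.random((8, 10))).sync()
        jt.display_max_memory_info()    # pretty-prints the peak snapshot
        raw = jt.get_max_memory_info()  # div-separated string parsed by get_max_memory_treemap
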
@@ -0,0 +1,46 @@
+// ***************************************************************
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Guoye Yang <498731903@qq.com>
+//     Dun Liang <randonlang@gmail.com>.
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "common.h"
+#include "mem/allocator.h"
+#include <map>
+#include <vector>
+#include <string>
+#include "var.h"
+#include "pybind/py_var_tracer.h"
+namespace jittor {
+
+// @pyjt(display_max_memory_info)
+void display_max_memory_info();
+// @pyjt(get_max_memory_info)
+string get_max_memory_info();
+
+struct MemoryProfiler {
+    std::map<void*, size_t> allocations;
+    // max infos
+    vector<std::pair<std::pair<string, vector<Stack>>, size_t>> max_live_vars;
+    size_t max_used_memory_size;
+    size_t max_memory_size;
+
+    MemoryProfiler();
+    static bool cmp(const std::pair<std::pair<string, vector<Stack>>, size_t>& a, const std::pair<std::pair<string, vector<Stack>>, size_t>& b);
+    void clear();
+    void check();
+    std::pair<size_t, size_t> get_memory_info();
+    void display_max_memory_info();
+    string get_max_memory_info();
+};
+
+extern MemoryProfiler memory_profiler;
+
+DECLARE_FLAG(int, profile_memory_enable);
+
+} // jittor

@@ -76,9 +76,9 @@ void CandidateOp::jit_run() {
     // define ys
     auto* __restrict__ yp = y->ptr<Ty>();
     size_t n_allocation;
-    int* np = (int*)exe.allocator->alloc(4, n_allocation);
+    int* np = (int*)exe.temp_allocator->alloc(4, n_allocation);
     size_t mask_allocation;
-    bool* maskp = (bool*)exe.allocator->alloc(xshape0, mask_allocation);
+    bool* maskp = (bool*)exe.temp_allocator->alloc(xshape0, mask_allocation);
     checkCudaErrors(cudaMemsetAsync(maskp, 1, xshape0));

     candidate_kernel<<<1, std::max(1, std::min(1024, xshape0)) >>>(
@@ -93,8 +93,8 @@ void CandidateOp::jit_run() {
     // checkCudaErrors(cudaDeviceSynchronize());
     checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
     y->set_shape({n});
-    exe.allocator->free(np, 4, n_allocation);
-    exe.allocator->free(maskp, xshape0, mask_allocation);
+    exe.temp_allocator->free(np, 4, n_allocation);
+    exe.temp_allocator->free(maskp, xshape0, mask_allocation);
 }
 #else
 void CandidateOp::jit_run() {

@@ -196,7 +196,7 @@ void WhereOp::jit_run() {
     @for(i, 0, NDIM, auto* __restrict__ outs@i@@p = outs[@i]->ptr<To>();)

     size_t n_allocation;
-    int* np = (int*)exe.allocator->alloc(4, n_allocation);
+    int* np = (int*)exe.temp_allocator->alloc(4, n_allocation);

     // one block kernel, result maybe unstable
     // int tnum = condshape@{NDIM-1};
@@ -232,7 +232,7 @@ void WhereOp::jit_run() {
     // checkCudaErrors(cudaDeviceSynchronize());
     checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
     @for(i, 0, NDIM, outs[@i]->set_shape({n});)
-    exe.allocator->free(np, 4, n_allocation);
+    exe.temp_allocator->free(np, 4, n_allocation);
 }
 #else

@@ -48,7 +48,6 @@ static void setitem_inplace(SetitemOp* op) {
     }
     auto output = op->outputs().front();
     output->share_with(input);
-    // return;

     auto data = op->input(1);
     // if setitem requires type conversion, don't inplace

@@ -1,6 +1,9 @@
 // ***************************************************************
-// Copyright (c) 2021 Jittor. All Rights Reserved.
-// Maintainers: Dun Liang <randonlang@gmail.com>.
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers:
+//     Dun Liang <randonlang@gmail.com>.
+//     Guoye Yang <498731903@qq.com>
+//
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
@@ -97,7 +100,8 @@ static vector<Stack> get_stack_info() {
     PyObject* prev_obj = nullptr;
     if (trace_py_var >= 3) {
         // trace raw stack
-        auto start = std::max(0, n-5);
+        // auto start = std::max(0, n-5);
+        auto start = 0;
         for (int i=start; i<n; i++) {
             auto f = frames[i];
             auto filename = to_string(f->f_code->co_filename);
@@ -363,4 +367,16 @@ void print_node_trace(const Node* node, std::ostream& os) {
     os << _get_stack_info((Node*)node);
 }

+vector<Stack> get_node_trace(Node* node) {
+    auto iter = trace_data.id_map.find(node);
+    if (iter == trace_data.id_map.end())
+        return vector<Stack>();
+    auto node_id = iter->second;
+    auto iter2 = trace_data.node_data.find(node_id);
+    if (iter2 == trace_data.node_data.end())
+        return vector<Stack>();
+    return iter2->second.stacks;
+}
+
 } // jittor

@@ -67,5 +67,5 @@ struct TraceData {
 extern TraceData trace_data;

 void print_node_trace(const Node* node, std::ostream& os);
-
+vector<Stack> get_node_trace(Node* node);
 } // jittor