mirror of https://github.com/Jittor/Jittor
Fix cuDNN convolution workspace size being too large
This commit is contained in:
parent
a9bb4567dd
commit
c657491a51
|
@ -6,6 +6,7 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mem/allocator.h"
|
||||
#include "var.h"
|
||||
#include "cudnn_conv_backward_w_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
|
@ -195,6 +196,8 @@ void CudnnConvBackwardWOp::jit_run() {
|
|||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mem/allocator.h"
|
||||
#include "var.h"
|
||||
#include "cudnn_conv_backward_x_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
|
@ -196,6 +197,8 @@ void CudnnConvBackwardXOp::jit_run() {
|
|||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
|
|
|
@ -199,9 +199,11 @@ void CudnnConvOp::jit_run() {
|
|||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize(
|
||||
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||
cudnnOdesc, algos[i], &sz);
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size && sz<512*1024*1024) max_ws_size = sz;
|
||||
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||
cudnnOdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
void* ws = exe.allocator->alloc(max_ws_size, allocation);
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
class TestMem(unittest.TestCase):
    # Memory tests: exercise the CUDA allocator's behavior when asked
    # for more memory than the device physically has.

    def tearDown(self):
        # Drop cached buffers and run garbage collection between tests.
        jt.clean()
        jt.gc()

    @unittest.skipIf(not jt.has_cuda, "no cuda found")
    @jt.flag_scope(use_cuda=1)
    def test_oom(self):
        # Allocate ~1.5x the reported device RAM in 1 GiB chunks; the
        # allocator is expected to keep working (fall back) rather than die.
        backups = []
        jt.flags.use_cuda = 1

        # One GiB of float32 data (4 bytes per element).
        one_gib = np.ones((1024 * 1024 * 1024 // 4,), "float32")

        total_cuda = jt.get_mem_info().total_cuda_ram
        chunk_count = int(total_cuda // (1024 ** 3) * 1.5)

        for _ in range(chunk_count):
            src = jt.array(one_gib)
            out = src + 1
            out.sync()
            backups.append((src, out))
        backups = []
|
||||
|
||||
|
||||
# Allow running this test file directly from the command line.
if __name__ == "__main__":
    unittest.main()
|
|
@ -5,6 +5,7 @@
|
|||
// ***************************************************************
|
||||
#include <typeinfo>
|
||||
#include <iomanip>
|
||||
#include <sys/sysinfo.h>
|
||||
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
|
@ -140,5 +141,18 @@ void display_memory_info(const char* fileline) {
|
|||
log.end();
|
||||
}
|
||||
|
||||
// Snapshot total host and CUDA device memory at construction time.
// Fills total_cpu_ram (bytes) from sysinfo(2) and total_cuda_ram (bytes)
// from the properties of CUDA device 0 (0 when built without CUDA).
MemInfo::MemInfo() {
    struct sysinfo info = {0};
    sysinfo(&info);
    // sysinfo(2): totalram is expressed in units of mem_unit bytes,
    // so multiply to get the real size in bytes (mem_unit may be > 1).
    total_cpu_ram = (size_t)info.totalram * info.mem_unit;
    total_cuda_ram = 0;
#ifdef HAS_CUDA
    // NOTE(review): assumes device 0 is the device in use — confirm.
    cudaDeviceProp prop = {0};
    cudaGetDeviceProperties(&prop, 0);
    total_cuda_ram = prop.totalGlobalMem;
#endif
}
|
||||
|
||||
MemInfo mem_info;
|
||||
|
||||
} // jittor
|
|
@ -54,4 +54,21 @@ void gc_all();
|
|||
// @pyjt(display_memory_info)
// Log a summary of current memory usage; fileline tags the call site.
void display_memory_info(const char* fileline="");

// @pyjt(MemInfo)
// Total system memory sizes, queried once by the default constructor.
struct MemInfo {
    // @pyjt(total_cpu_ram)
    // Total host RAM (as reported by sysinfo).
    int64 total_cpu_ram;
    // @pyjt(total_cuda_ram)
    // Total global memory of CUDA device 0 in bytes; 0 without CUDA.
    int64 total_cuda_ram;

    // Copy constructor needed by the pyjt to_py_object conversion.
    inline MemInfo(const MemInfo&) = default;

    MemInfo();
};

// Global singleton holding the sizes measured at startup.
extern MemInfo mem_info;

// @pyjt(get_mem_info)
// Return a copy of the global MemInfo (exposed to Python).
inline MemInfo get_mem_info() { return mem_info; }
|
||||
|
||||
} // jittor
|
|
@ -16,7 +16,14 @@ const char* CudaDeviceAllocator::name() const {return "cuda_device";}
|
|||
|
||||
// Allocate size bytes of CUDA device memory.
// First tries a plain cudaMalloc; if that throws (device RAM exhausted),
// falls back to cudaMallocManaged (unified memory), which keeps the
// program running at the cost of performance. The stray unconditional
// cudaMalloc that preceded the try block (a leftover of the old version)
// is removed — it would have thrown before the fallback could run.
void* CudaDeviceAllocator::alloc(size_t size, size_t& allocation) {
    void* ptr;
    try {
        checkCudaErrors(cudaMalloc(&ptr, size));
        return ptr;
    } catch (...) {}
    // Device allocation failed: warn, dump memory stats, and retry with
    // unified (managed) memory so the caller still gets usable memory.
    LOGw << "Unable to alloc cuda device memory, use unify memory instead. "
        "This may cause low performance.";
    display_memory_info(__FILELINE__);
    checkCudaErrors(cudaMallocManaged(&ptr, size));
    return ptr;
}
|
||||
|
||||
|
|
|
@ -158,6 +158,25 @@ DEF_IS(DumpGraphs, const T&) from_py_object(PyObject* obj) {
|
|||
return GET_RAW_PTR(T, obj);
|
||||
}
|
||||
|
||||
// MemInfo
// pyjt glue: convert MemInfo between C++ and its Python wrapper type.
struct MemInfo;
extern PyTypeObject PyjtMemInfo;
// True when obj is exactly the Python MemInfo wrapper type.
DEF_IS(MemInfo, bool) is_type(PyObject* obj) {
    return Py_TYPE(obj) == &PyjtMemInfo;
}


// Wrap a C++ MemInfo in a new Python object (copy-constructed in place).
DEF_IS(MemInfo, PyObject*) to_py_object(const T& a) {
    PyObjHolder obj(_PyObject_New(&PyjtMemInfo));
    auto ptr = GET_RAW_PTR(T, obj.obj);
    // Placement-new: copy-construct the MemInfo into the Python object.
    new (ptr) T(a);
    return obj.release();
}

// Borrow the C++ MemInfo stored inside a Python wrapper object.
DEF_IS(MemInfo, const T&) from_py_object(PyObject* obj) {
    return GET_RAW_PTR(T, obj);
}
|
||||
|
||||
|
||||
// NanoString
|
||||
struct NanoString;
|
||||
|
|
Loading…
Reference in New Issue