diff --git a/extern/cuda/cub/ops/cub_arg_reduce_op.cc b/extern/cuda/cub/ops/cub_arg_reduce_op.cc
index 8d5e2b6e..0216c525 100644
--- a/extern/cuda/cub/ops/cub_arg_reduce_op.cc
+++ b/extern/cuda/cub/ops/cub_arg_reduce_op.cc
@@ -87,7 +87,7 @@ void CubArgReduceOp::jit_run() {
         num_segments *= x->shape[i];
     }
     size_t allocation_dout;
-    cub::KeyValuePair *d_out = (cub::KeyValuePair *)exe.allocator->alloc(sizeof(cub::KeyValuePair) * num_segments, allocation_dout);
+    cub::KeyValuePair *d_out = (cub::KeyValuePair *)exe.temp_allocator->alloc(sizeof(cub::KeyValuePair) * num_segments, allocation_dout);

     // Determine temporary device storage requirements
     void *d_temp_storage = NULL;
@@ -96,7 +96,7 @@ void CubArgReduceOp::jit_run() {
         xp, d_out, num_segments, offsetsp, offsetsp + 1);
     // Allocate temporary storage
     size_t allocation;
-    d_temp_storage = exe.allocator->alloc(temp_storage_bytes, allocation);
+    d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, allocation);
     // Run sorting operation
     cub::DeviceSegmentedReduce::@FUNC@@(d_temp_storage, temp_storage_bytes,
         xp, d_out, num_segments, offsetsp, offsetsp + 1);
@@ -105,8 +105,8 @@ void CubArgReduceOp::jit_run() {
     auto* __restrict__ y_keyp = y_key->ptr();
     split<<>>(d_out, y_keyp, yp, num_segments);
-    exe.allocator->free(d_temp_storage, temp_storage_bytes, allocation);
-    exe.allocator->free(d_out, sizeof(cub::KeyValuePair) * num_segments, allocation_dout);
+    exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, allocation);
+    exe.temp_allocator->free(d_out, sizeof(cub::KeyValuePair) * num_segments, allocation_dout);
 }
 #endif // JIT_cuda
 #endif // JIT
diff --git a/extern/cuda/cub/ops/cub_argsort_op.cc b/extern/cuda/cub/ops/cub_argsort_op.cc
index 4e47d276..1ca57c15 100644
--- a/extern/cuda/cub/ops/cub_argsort_op.cc
+++ b/extern/cuda/cub/ops/cub_argsort_op.cc
@@ -85,12 +85,12 @@ void CubArgsortOp::jit_run() {
         num_items, num_segments, offsetsp, offsetsp + 1);
     // Allocate temporary storage
     size_t allocation;
-    d_temp_storage = exe.allocator->alloc(temp_storage_bytes, allocation);
+    d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, allocation);
     // Run sorting operation
     cub::DeviceSegmentedRadixSort::@FUNC@@(d_temp_storage, temp_storage_bytes,
         xp, y_keyp, indexesp, yp,
         num_items, num_segments, offsetsp, offsetsp + 1);
-    exe.allocator->free(d_temp_storage, temp_storage_bytes, allocation);
+    exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, allocation);
 }
 #endif // JIT_cuda
 #endif // JIT
diff --git a/extern/cuda/cub/ops/cub_where_op.cc b/extern/cuda/cub/ops/cub_where_op.cc
index bca1f8b3..65e459ac 100644
--- a/extern/cuda/cub/ops/cub_where_op.cc
+++ b/extern/cuda/cub/ops/cub_where_op.cc
@@ -82,7 +82,7 @@ void CubWhereOp::jit_run(){
     int N = cond->num;
     size_t temp_storage_bytes=0;
     size_t num_nonzeros_allocation;
-    auto num_nonzeros = exe.allocator->alloc(sizeof(To), num_nonzeros_allocation);
+    auto num_nonzeros = exe.temp_allocator->alloc(sizeof(To), num_nonzeros_allocation);
     size_t temp_storage_allocation;
     void* temp_storage;
@@ -93,9 +93,9 @@ void CubWhereOp::jit_run(){
     cub::TransformInputIterator, Ti*> itr(cond->ptr(), NonZeroOp());
     temp_storage_bytes = 0;
     checkCudaErrors(cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, out_temp, (To*)num_nonzeros, N));
-    temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation);
+    temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, temp_storage_allocation);
     checkCudaErrors(cub::DeviceSelect::Flagged(temp_storage,
         temp_storage_bytes, counting_itr, itr,out_temp, (To*)num_nonzeros, N));
-    exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);
+    exe.temp_allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);

     To num_nonzeros_h;
     cudaMemcpy(&num_nonzeros_h, num_nonzeros, sizeof(To), cudaMemcpyDeviceToHost);
@@ -110,7 +110,7 @@ void CubWhereOp::jit_run(){
         @for(i, 0, NDIM, 1, , cond->shape[@i], outs[@i]->ptr())
     );
     }
-    exe.allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation);
+    exe.temp_allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation);
 }
 #endif
diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc
index ed1c6f8c..ce1c18a5 100644
--- a/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc
+++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc
@@ -203,7 +203,7 @@ void CudnnConvBackwardWOp::jit_run() {
             if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
         }
         size_t allocation;
-        void* ws = exe.allocator->alloc(max_ws_size, allocation);
+        void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
         checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx(
             handle_,
             cudnnIdesc, x->ptr(),
@@ -215,7 +215,7 @@ void CudnnConvBackwardWOp::jit_run() {
             perf_results,
             ws,
             max_ws_size));
-        exe.allocator->free(ws, max_ws_size, allocation);
+        exe.temp_allocator->free(ws, max_ws_size, allocation);
     } else {
         checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
             handle_,
@@ -250,7 +250,7 @@ void CudnnConvBackwardWOp::jit_run() {
         cudnnFdesc, algo, &workSpaceSize));
     size_t allocation;
     if (workSpaceSize > 0) {
-        workSpace = exe.allocator->alloc(workSpaceSize, allocation);
+        workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
     }
     float alpha=1, beta=0;
     checkCudaErrors(cudnnConvolutionBackwardFilter(
@@ -265,7 +265,7 @@ void CudnnConvBackwardWOp::jit_run() {
         cudnnFdesc, w->ptr())
     );
     if (workSpace)
-        exe.allocator->free(workSpace, workSpaceSize, allocation);
+        exe.temp_allocator->free(workSpace, workSpaceSize, allocation);

     checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
     checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc
index 5ecb503a..bbf72a1f 100644
--- a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc
+++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc
@@ -204,7 +204,7 @@ void CudnnConvBackwardXOp::jit_run() {
             if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
         }
         size_t allocation;
-        void* ws = exe.allocator->alloc(max_ws_size, allocation);
+        void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
         checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx(
             handle_,
             cudnnFdesc, w->ptr(),
@@ -216,7 +216,7 @@ void CudnnConvBackwardXOp::jit_run() {
             perf_results,
             ws,
             max_ws_size));
-        exe.allocator->free(ws, max_ws_size, allocation);
+        exe.temp_allocator->free(ws, max_ws_size, allocation);
     } else {
         checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7(
             handle_,
@@ -251,7 +251,7 @@ void CudnnConvBackwardXOp::jit_run() {
         cudnnIdesc, algo, &workSpaceSize));
     size_t allocation;
     if (workSpaceSize > 0) {
-        workSpace = exe.allocator->alloc(workSpaceSize, allocation);
+        workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
     }
     float alpha=1, beta=0;
     checkCudaErrors(cudnnConvolutionBackwardData(
@@ -266,7 +266,7 @@ void CudnnConvBackwardXOp::jit_run() {
         cudnnIdesc, x->ptr())
     );
     if (workSpace)
-        exe.allocator->free(workSpace, workSpaceSize, allocation);
+        exe.temp_allocator->free(workSpace, workSpaceSize, allocation);

     checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
     checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
diff --git a/extern/cuda/cudnn/ops/cudnn_conv_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_op.cc
index 5798e789..548b5ae3 100644
--- a/extern/cuda/cudnn/ops/cudnn_conv_op.cc
+++ b/extern/cuda/cudnn/ops/cudnn_conv_op.cc
@@ -208,7 +208,7 @@ void CudnnConvOp::jit_run() {
             if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
         }
         size_t allocation;
-        void* ws = exe.allocator->alloc(max_ws_size, allocation);
+        void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
         checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx(
             handle_,
             cudnnIdesc, x->ptr(),
@@ -220,7 +220,7 @@ void CudnnConvOp::jit_run() {
             perf_results,
             ws,
             max_ws_size));
-        exe.allocator->free(ws, max_ws_size, allocation);
+        exe.temp_allocator->free(ws, max_ws_size, allocation);
     } else {
         checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7(
             handle_,
@@ -255,7 +255,7 @@ void CudnnConvOp::jit_run() {
         cudnnOdesc, algo, &workSpaceSize) );
     size_t allocation;
     if (workSpaceSize > 0) {
-        workSpace = exe.allocator->alloc(workSpaceSize, allocation);
+        workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
     }
     float alpha=1, beta=0;
     checkCudaErrors(cudnnConvolutionForward(
@@ -270,7 +270,7 @@ void CudnnConvOp::jit_run() {
         cudnnOdesc, y->ptr())
     );
     if (workSpace)
-        exe.allocator->free(workSpace, workSpaceSize, allocation);
+        exe.temp_allocator->free(workSpace, workSpaceSize, allocation);

     checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
     checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
diff --git a/src/executor.cc b/src/executor.cc
index 4401c72d..72d4ce18 100644
--- a/src/executor.cc
+++ b/src/executor.cc
@@ -92,7 +92,9 @@ void load_fused_op(FusedOp& fused_op, vector& fuse_ops, vector& ops, i

 void Executor::run_sync(vector vars, bool device_sync) {
     auto allocator = get_allocator();
+    auto temp_allocator = get_allocator(true);
     this->allocator = allocator;
+    this->temp_allocator = temp_allocator;
     // bfs find all ops need to run
     int op_num = 0;
     vector bfs_q;
diff --git a/src/executor.h b/src/executor.h
index 2126a880..9baef47c 100644
--- a/src/executor.h
+++ b/src/executor.h
@@ -16,6 +16,7 @@ namespace jittor {

 struct Executor {
     Allocator* allocator;
+    Allocator* temp_allocator;
     bool last_is_cuda = false;
     void run_sync(vector vars, bool device_sync);
 };
diff --git a/src/mem/allocator.cc b/src/mem/allocator.cc
index d3ec59f2..7145d511 100644
--- a/src/mem/allocator.cc
+++ b/src/mem/allocator.cc
@@ -15,6 +15,7 @@
 #include "mem/allocator/stat_allocator.h"
 #include "mem/allocator/sfrl_allocator.h"
 #include "mem/allocator/nfef_allocator.h"
+#include "mem/allocator/temp_allocator.h"

 namespace jittor {

@@ -46,7 +47,7 @@ Allocator* setup_allocator(Allocator* underlying) {

 Allocator* cpu_allocator = setup_allocator(&aligned_allocator);

-Allocator* get_allocator() {
+Allocator* get_allocator(bool temp_allocator) {
     Allocator* allocator = nullptr;
 #ifdef HAS_CUDA
     if (use_cuda && !allocator) {
@@ -72,7 +73,10 @@ Allocator* get_allocator() {
         allocator = setup_allocator(allocator);
         return allocator;
     }
-    if (use_sfrl_allocator) {
+    if (temp_allocator && use_temp_allocator) {
+        LOGvv << "Using temp_allocator";
+        allocator = setup_allocator(allocator);
+    } else if (use_sfrl_allocator) {
         LOGvv << "Using sfrl_allocator";
         allocator = setup_allocator(allocator);
     }
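For reference, every op-side hunk above follows the same scratch-memory pattern, only with `exe.allocator` swapped for `exe.temp_allocator`. Below is a minimal, self-contained sketch of that pattern; the `Allocator`, `MallocAllocator`, and `Executor` types here are illustrative stand-ins mirroring the interfaces touched by this patch, not the real Jittor headers.

```cpp
// Illustrative only: the alloc/use/free workspace pattern used by the
// CUB/cuDNN ops in this patch, with stand-in types instead of Jittor's.
#include <cstddef>
#include <cstdlib>
#include <cstring>

struct Allocator {
    virtual void* alloc(size_t size, size_t& allocation) = 0;
    virtual void free(void* mem_ptr, size_t size, const size_t& allocation) = 0;
    virtual ~Allocator() = default;
};

// Stand-in for the underlying device/host allocator.
struct MallocAllocator : Allocator {
    void* alloc(size_t size, size_t& allocation) override { allocation = 0; return std::malloc(size); }
    void free(void* mem_ptr, size_t, const size_t&) override { std::free(mem_ptr); }
};

// run_sync() in this patch fills both fields; temp_allocator serves scratch buffers.
struct Executor { Allocator* allocator; Allocator* temp_allocator; };

int main() {
    MallocAllocator backing;
    Executor exe{&backing, &backing};

    size_t workSpaceSize = 1 << 20, allocation;
    void* workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
    std::memset(workSpace, 0, workSpaceSize); // stands in for the cuDNN/CUB call using the workspace
    exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
    return 0;
}
```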
diff --git a/src/mem/allocator.h b/src/mem/allocator.h
index 34553800..8f8c637e 100644
--- a/src/mem/allocator.h
+++ b/src/mem/allocator.h
@@ -49,7 +49,7 @@ struct Allocation {
 };

 extern Allocator* cpu_allocator;
-Allocator* get_allocator();
+Allocator* get_allocator(bool temp_allocator=false);

 // @pyjt(gc)
 void gc_all();
diff --git a/src/mem/allocator/temp_allocator.cc b/src/mem/allocator/temp_allocator.cc
new file mode 100644
index 00000000..88c1398c
--- /dev/null
+++ b/src/mem/allocator/temp_allocator.cc
@@ -0,0 +1,116 @@
+// ***************************************************************
+// Copyright (c) 2020 Jittor. All Rights Reserved.
+// Maintainers:
+//     Guoye Yang <498731903@qq.com>
+//     Dun Liang .
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+
+#include "mem/allocator/temp_allocator.h"
+
+namespace jittor {
+
+DEFINE_FLAG(int, use_temp_allocator, 1, "Enable temp allocator");
+
+TempAllocator::~TempAllocator() {
+    while (!cached_blocks.empty()) {
+        auto it = cached_blocks.begin();
+        TempCachingBlock* block = it->second;
+        cached_blocks.erase(it);
+        delete block;
+    }
+}
+
+const char* TempAllocator::name() const {return "temp";}
+
+void TempAllocator::setup(Allocator* underlying) {
+    this->underlying = underlying;
+}
+
+size_t TempAllocator::align_size(size_t size) {
+    return (size + ALIGN_SIZE - 1) / ALIGN_SIZE * ALIGN_SIZE;
+}
+
+unsigned long long TempAllocator::get_key(TempCachingBlock* block) {
+    return ((unsigned long long)block->size) * ID_LIMIT + block->id;
+}
+
+void* TempAllocator::alloc(size_t size, size_t& allocation) {
+    size = align_size(size);
+
+    auto temp = TempCachingBlock(size);
+    auto it = cached_blocks.lower_bound(get_key(&temp));
+    TempCachingBlock* block = nullptr;
+    if (it != cached_blocks.end()) {
+        block = it->second;
+        cached_blocks.erase(it);
+        unused_memory -= block->size;
+    } else {
+        void* ptr = underlying->alloc(size, allocation);
+        block = new TempCachingBlock(size, ptr);
+        size_t id;
+        if (!block_ids.empty()) {
+            id = block_ids.back();
+            block_ids.pop_back();
+        } else {
+            ASSERT(tot_block_id < ID_LIMIT - 1) << "block id limit extended.";
+            id = ++tot_block_id;
+        }
+        block->id = id;
+    }
+
+    used_memory += block->size;
+    occupied_id_mapper[block->id] = block;
+    allocation = block->id;
+    return block->memory_ptr;
+}
+
+void TempAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
+    size = align_size(size);
+    ASSERT(occupied_id_mapper[allocation] != nullptr) << "allocation not found";
+    TempCachingBlock* block = occupied_id_mapper[allocation];
+    occupied_id_mapper[allocation] = nullptr;
+    used_memory -= block->size;
+    unused_memory += block->size;
+    bool can_add = true;
+    if (cached_blocks.size() > cache_blocks_limit-1) {
+        ASSERT(cached_blocks.size() == cache_blocks_limit);
+        auto it = cached_blocks.lower_bound(get_key(block));
+        if (it == cached_blocks.begin()) {
+            can_add = false;
+        } else {
+            --it;
+            TempCachingBlock* block = it->second;
+            underlying->free((void*)block->memory_ptr, block->size, 0);
+            unused_memory -= block->size;
+            block_ids.push_back(block->id);
+            cached_blocks.erase(it);
+            delete block;
+        }
+    }
+    if (can_add) {
+        cached_blocks[get_key(block)] = block;
+    }
+}
+
+void TempAllocator::gc() {
+    while (!cached_blocks.empty()) {
+        auto it = cached_blocks.begin();
+        TempCachingBlock* block = it->second;
+        underlying->free((void*)block->memory_ptr, block->size, 0);
+        unused_memory -= block->size;
+        block_ids.push_back(block->id);
+        cached_blocks.erase(it);
+        delete block;
+    }
+}
+
+bool TempAllocator::share_with(size_t size, size_t allocation) {
+    ASSERT(false);
+    return true;
+}
+
+} // jittor
+
diff --git a/src/mem/allocator/temp_allocator.h b/src/mem/allocator/temp_allocator.h
new file mode 100644
index 00000000..0402e421
--- /dev/null
+++ b/src/mem/allocator/temp_allocator.h
@@ -0,0 +1,57 @@
+// ***************************************************************
+// Copyright (c) 2020 Jittor. All Rights Reserved.
+// Maintainers:
+//     Guoye Yang <498731903@qq.com>
+//     Dun Liang .
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "mem/allocator.h"
+
+namespace jittor {
+
+struct TempCachingBlock {
+    size_t size;
+    size_t id;
+    void* memory_ptr;
+
+    TempCachingBlock(size_t size):size(size),id(0) {}
+    TempCachingBlock(size_t size, void* memory_ptr):size(size),id(0), memory_ptr(memory_ptr) {}
+};
+
+struct TempAllocator : Allocator {
+    static const size_t ALIGN_SIZE = 512;
+    static const size_t ID_LIMIT = 1 << 18;
+    Allocator* underlying;
+    size_t cache_blocks_limit, used_memory, unused_memory;
+    std::map cached_blocks;
+    std::vector block_ids;
+    size_t tot_block_id;
+    std::unique_ptr occupied_id_mapper;
+
+    inline TempAllocator(size_t cache_blocks_limit=2) : cache_blocks_limit(cache_blocks_limit), used_memory(0), unused_memory(0), tot_block_id(0), occupied_id_mapper(new TempCachingBlock*[ID_LIMIT]) {
+    }
+    inline TempAllocator(Allocator* underlying, size_t cache_blocks_limit=2) : TempAllocator(cache_blocks_limit) {
+        setup(underlying);
+    }
+    ~TempAllocator();
+
+    size_t align_size(size_t size);
+    unsigned long long get_key(TempCachingBlock* block);
+    // free all unused memory of all sfrl allocators.
+    void setup(Allocator* underlying);
+    uint64 flags() const override { return underlying->flags(); }
+    const char* name() const override;
+    void* alloc(size_t size, size_t& allocation) override;
+    void free(void* mem_ptr, size_t size, const size_t& allocation) override;
+    void gc() override;
+    virtual bool share_with(size_t size, size_t allocation) override;
+};
+
+DECLARE_FLAG(int, use_temp_allocator);
+
+}//jittor
+
diff --git a/src/mem/mem_info.cc b/src/mem/mem_info.cc
index 375f5f7d..f935685c 100644
--- a/src/mem/mem_info.cc
+++ b/src/mem/mem_info.cc
@@ -15,8 +15,10 @@
 #include "misc/cuda_flags.h"
 #include "mem/allocator/sfrl_allocator.h"
 #include "mem/allocator/stat_allocator.h"
+#include "mem/allocator/temp_allocator.h"
 #include "mem/mem_info.h"
 #include "update_queue.h"
+#include "executor.h"

 namespace jittor {

@@ -101,7 +103,13 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
     log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"}
         << "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"}
         << "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n';
-
+    if (use_temp_allocator) {
+        TempAllocator* temp_allocator = (TempAllocator*)exe.temp_allocator;
+        log << "\nname:" << temp_allocator->name() << "\n";
+        log << "used_memory:" << FloatOutput{(double)temp_allocator->used_memory, " KMG", 1024, "B"} << "\n";
+        log << "unused_memory:" << FloatOutput{(double)temp_allocator->unused_memory, " KMG", 1024, "B"} << "\n";
+
+    }
     if (dump_var) {
         vector queue;
         unordered_set visited;
diff --git a/src/ops/candidate_op.cc b/src/ops/candidate_op.cc
index c68fc76b..2cdda0c0 100644
--- a/src/ops/candidate_op.cc
+++ b/src/ops/candidate_op.cc
@@ -76,9 +76,9 @@ void CandidateOp::jit_run() {
     // define ys
     auto* __restrict__ yp = y->ptr();
     size_t n_allocation;
-    int* np = (int*)exe.allocator->alloc(4, n_allocation);
+    int* np = (int*)exe.temp_allocator->alloc(4, n_allocation);
     size_t mask_allocation;
-    bool* maskp = (bool*)exe.allocator->alloc(xshape0, mask_allocation);
+    bool* maskp = (bool*)exe.temp_allocator->alloc(xshape0, mask_allocation);
     checkCudaErrors(cudaMemsetAsync(maskp, 1, xshape0));

     candidate_kernel<<<1, std::max(1, std::min(1024, xshape0)) >>>(
@@ -93,8 +93,8 @@ void CandidateOp::jit_run() {
     // checkCudaErrors(cudaDeviceSynchronize());
     checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
     y->set_shape({n});
-    exe.allocator->free(np, 4, n_allocation);
-    exe.allocator->free(maskp, xshape0, mask_allocation);
+    exe.temp_allocator->free(np, 4, n_allocation);
+    exe.temp_allocator->free(maskp, xshape0, mask_allocation);
 }
 #else
 void CandidateOp::jit_run() {
diff --git a/src/ops/where_op.cc b/src/ops/where_op.cc
index 0cf31899..d3c5a3ca 100644
--- a/src/ops/where_op.cc
+++ b/src/ops/where_op.cc
@@ -196,7 +196,7 @@ void WhereOp::jit_run() {
     @for(i, 0, NDIM, auto* __restrict__ outs@i@@p = outs[@i]->ptr();)

     size_t n_allocation;
-    int* np = (int*)exe.allocator->alloc(4, n_allocation);
+    int* np = (int*)exe.temp_allocator->alloc(4, n_allocation);

     // one block kernel, result maybe unstable
     // int tnum = condshape@{NDIM-1};
@@ -232,7 +232,7 @@ void WhereOp::jit_run() {
     // checkCudaErrors(cudaDeviceSynchronize());
     checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
     @for(i, 0, NDIM, outs[@i]->set_shape({n});)
-    exe.allocator->free(np, 4, n_allocation);
+    exe.temp_allocator->free(np, 4, n_allocation);
 }

 #else
diff --git a/src/opt/gopt/setitem_gopt.cc b/src/opt/gopt/setitem_gopt.cc
index 44414457..0831609d 100644
--- a/src/opt/gopt/setitem_gopt.cc
+++ b/src/opt/gopt/setitem_gopt.cc
@@ -48,7 +48,7 @@ static void setitem_inplace(SetitemOp* op) {
     }
     auto output = op->outputs().front();
     output->share_with(input);
-    // return;
+    return;
     // LOGir << "pass setitem optim one";
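A note on how the new `TempAllocator::alloc` reuses cached blocks, as I read temp_allocator.cc above: keys in `cached_blocks` are `size * ID_LIMIT + id`, so probing with the aligned request size and id 0 via `lower_bound` lands on the smallest cached block at least as large as the request. The stand-alone sketch below (not Jittor code; block labels and sizes are made up) demonstrates that lookup.

```cpp
// Stand-alone illustration of the best-fit lookup in TempAllocator::alloc:
// cached blocks are keyed by size * ID_LIMIT + id, and the probe uses id 0,
// so lower_bound returns the smallest cached block whose size >= the
// (512-byte aligned) request.
#include <cstdio>
#include <map>

int main() {
    const unsigned long long ID_LIMIT = 1ULL << 18;            // matches temp_allocator.h
    std::map<unsigned long long, const char*> cached_blocks;   // key -> label of a cached block

    cached_blocks[1024ULL  * ID_LIMIT + 1] = "1 KB block";
    cached_blocks[4096ULL  * ID_LIMIT + 2] = "4 KB block";
    cached_blocks[16384ULL * ID_LIMIT + 3] = "16 KB block";

    unsigned long long request = 3072;                          // e.g. 3000 bytes aligned up to 512
    auto it = cached_blocks.lower_bound(request * ID_LIMIT);    // probe key has id 0
    if (it != cached_blocks.end())
        std::printf("reuse %s\n", it->second);                  // prints "reuse 4 KB block"
    else
        std::printf("cache miss, allocate from underlying allocator\n");
    return 0;
}
```

On free, as I read it, once the cache already holds `cache_blocks_limit` blocks, the next-smaller cached block is evicted to the underlying allocator (or the freed block is simply not cached if it is the smallest), which keeps the small cache biased toward larger, more reusable workspaces.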