add cufft plan map

This commit is contained in:
cxjyxx_me 2022-03-28 09:36:10 -04:00
parent 254b8609fe
commit 0b5e367cb1
3 changed files with 20 additions and 10 deletions

View File

@ -19,6 +19,6 @@
namespace jittor {
EXTERN_LIB cufftHandle cufft_handle;
EXTERN_LIB unordered_map<string, cufftHandle> cufft_handle_cache;
} // jittor

View File

@ -58,7 +58,6 @@ void CufftFftOp::jit_run() {
auto* __restrict__ xp = x->mem_ptr;
auto* __restrict__ yp = y->mem_ptr;
cufftHandle& plan = cufft_handle;
int batch_size = x->shape[0];
int n1 = x->shape[1], n2 = x->shape[2];
int fft_size = batch_size * n1 * n2;
@ -70,11 +69,21 @@ void CufftFftOp::jit_run() {
} else if (TS == "float64") {
op_type = CUFFT_Z2Z;
}
CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
op_type, batch_size));
CUFFT_CALL(cufftSetStream(plan, 0));
JK& jk = get_jk();
jk.clear();
jk << fft[0] << "," << fft[1] << "," << TS << "," << batch_size;
auto iter = cufft_handle_cache.find(jk.to_string());
cufftHandle plan;
if (iter!=cufft_handle_cache.end()) plan = iter->second;
else {
CUFFT_CALL(cufftCreate(&plan));
CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
op_type, batch_size));
CUFFT_CALL(cufftSetStream(plan, 0));
cufft_handle_cache[jk.to_string()] = plan;
}
/*
* Note:
* Identical pointers to data and output arrays implies in-place transformation

View File

@ -12,19 +12,20 @@
namespace jittor {
cufftHandle cufft_handle;
unordered_map<string, cufftHandle> cufft_handle_cache;
struct cufft_initer {
inline cufft_initer() {
if (!get_device_count()) return;
CUFFT_CALL(cufftCreate(&cufft_handle));
LOGv << "cufftCreate finished";
}
inline ~cufft_initer() {
if (!get_device_count()) return;
CUFFT_CALL(cufftDestroy(cufft_handle));
for (auto it = cufft_handle_cache.begin(); it != cufft_handle_cache.end(); it++) {
CUFFT_CALL(cufftDestroy(it->second));
}
LOGv << "cufftDestroy finished";
}