mirror of https://github.com/Jittor/Jittor
add cufft plan map
This commit is contained in:
parent
254b8609fe
commit
0b5e367cb1
|
@ -19,6 +19,6 @@
|
|||
|
||||
// Declarations of the cuFFT state used by the cuFFT op code in this commit.
namespace jittor {
|
||||
|
||||
// cuFFT handle declared for use by other translation units.
// NOTE(review): EXTERN_LIB is a project macro not visible here — presumably
// an extern/visibility specifier; confirm against its definition.
EXTERN_LIB cufftHandle cufft_handle;
|
||||
// Map from a string key to a cuFFT handle, used as a cache of plans so a plan
// can be looked up by key instead of being rebuilt.
EXTERN_LIB unordered_map<string, cufftHandle> cufft_handle_cache;
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -58,7 +58,6 @@ void CufftFftOp::jit_run() {
|
|||
auto* __restrict__ xp = x->mem_ptr;
|
||||
auto* __restrict__ yp = y->mem_ptr;
|
||||
|
||||
cufftHandle& plan = cufft_handle;
|
||||
int batch_size = x->shape[0];
|
||||
int n1 = x->shape[1], n2 = x->shape[2];
|
||||
int fft_size = batch_size * n1 * n2;
|
||||
|
@ -70,11 +69,21 @@ void CufftFftOp::jit_run() {
|
|||
} else if (TS == "float64") {
|
||||
op_type = CUFFT_Z2Z;
|
||||
}
|
||||
CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
|
||||
nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
|
||||
nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
|
||||
op_type, batch_size));
|
||||
CUFFT_CALL(cufftSetStream(plan, 0));
|
||||
JK& jk = get_jk();
|
||||
jk.clear();
|
||||
jk << fft[0] << "," << fft[1] << "," << TS << "," << batch_size;
|
||||
auto iter = cufft_handle_cache.find(jk.to_string());
|
||||
cufftHandle plan;
|
||||
if (iter!=cufft_handle_cache.end()) plan = iter->second;
|
||||
else {
|
||||
CUFFT_CALL(cufftCreate(&plan));
|
||||
CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
|
||||
nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
|
||||
nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
|
||||
op_type, batch_size));
|
||||
CUFFT_CALL(cufftSetStream(plan, 0));
|
||||
cufft_handle_cache[jk.to_string()] = plan;
|
||||
}
|
||||
/*
|
||||
* Note:
|
||||
* Identical pointers to data and output arrays imply an in-place transformation
|
||||
|
|
|
@ -12,19 +12,20 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
// cuFFT handle that cufft_initer creates in its constructor and destroys in
// its destructor.
cufftHandle cufft_handle;
|
||||
// Cache of cuFFT plan handles keyed by a string. CufftFftOp::jit_run fills it
// with keys built from the two FFT sizes, the type string and the batch size;
// cufft_initer's destructor destroys every handle stored here.
unordered_map<string, cufftHandle> cufft_handle_cache;
|
||||
|
||||
struct cufft_initer {
|
||||
|
||||
inline cufft_initer() {
|
||||
if (!get_device_count()) return;
|
||||
CUFFT_CALL(cufftCreate(&cufft_handle));
|
||||
LOGv << "cufftCreate finished";
|
||||
}
|
||||
|
||||
// Releases the shared cuFFT handle and every plan stored in
// cufft_handle_cache. Uses the same get_device_count() guard as the
// constructor, so nothing is destroyed when no device is reported.
inline ~cufft_initer() {
    if (get_device_count()) {
        CUFFT_CALL(cufftDestroy(cufft_handle));
        // Destroy each cached plan; the map entries themselves are left in place.
        for (auto& cached : cufft_handle_cache) {
            CUFFT_CALL(cufftDestroy(cached.second));
        }
        LOGv << "cufftDestroy finished";
    }
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue