add cufft wrapper

This commit is contained in:
cxjyxx_me 2022-03-28 07:45:33 -04:00
parent b642b8f1d1
commit 254b8609fe
3 changed files with 63 additions and 4 deletions

View File

@ -0,0 +1,24 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include <cuda_runtime.h>
#include <cufftXt.h>
#include "cufft_utils.h"
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_emu.h"
#include "common.h"
namespace jittor {
EXTERN_LIB cufftHandle cufft_handle;
} // jittor

View File

@ -1,6 +1,9 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
@ -11,6 +14,7 @@
#include <cufft.h>
#include "helper_cuda.h"
#include "cufft_fft_op.h"
#include "cufft_wrapper.h"
#include <complex>
#include <iostream>
@ -54,13 +58,12 @@ void CufftFftOp::jit_run() {
auto* __restrict__ xp = x->mem_ptr;
auto* __restrict__ yp = y->mem_ptr;
cufftHandle plan;
cufftHandle& plan = cufft_handle;
int batch_size = x->shape[0];
int n1 = x->shape[1], n2 = x->shape[2];
int fft_size = batch_size * n1 * n2;
std::array<int, 2> fft = {n1, n2};
CUFFT_CALL(cufftCreate(&plan));
auto op_type = CUFFT_C2C;
if (TS == "float32") {
op_type = CUFFT_C2C;
@ -82,7 +85,6 @@ void CufftFftOp::jit_run() {
CUFFT_CALL(cufftExecZ2Z(plan, (cufftDoubleComplex *)xp, (cufftDoubleComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
}
CUFFT_CALL(cufftDestroy(plan));
}
#endif // JIT_cpu
#endif // JIT

View File

@ -0,0 +1,33 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "cufft_wrapper.h"
#include "misc/cuda_flags.h"
namespace jittor {
cufftHandle cufft_handle;
struct cufft_initer {
inline cufft_initer() {
if (!get_device_count()) return;
CUFFT_CALL(cufftCreate(&cufft_handle));
LOGv << "cufftCreate finished";
}
inline ~cufft_initer() {
if (!get_device_count()) return;
CUFFT_CALL(cufftDestroy(cufft_handle));
LOGv << "cufftDestroy finished";
}
} init;
} // jittor