mirror of https://github.com/Jittor/Jittor
add cufft wrapper
This commit is contained in:
parent
b642b8f1d1
commit
254b8609fe
|
@ -0,0 +1,24 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <cuda_runtime.h>
|
||||
#include <cufftXt.h>
|
||||
#include "cufft_utils.h"
|
||||
|
||||
#include "utils/log.h"
|
||||
#include "helper_cuda.h"
|
||||
#include "fp16_emu.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
EXTERN_LIB cufftHandle cufft_handle;
|
||||
|
||||
} // jittor
|
|
@ -1,6 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
@ -11,6 +14,7 @@
|
|||
#include <cufft.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "cufft_fft_op.h"
|
||||
#include "cufft_wrapper.h"
|
||||
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
|
@ -54,13 +58,12 @@ void CufftFftOp::jit_run() {
|
|||
auto* __restrict__ xp = x->mem_ptr;
|
||||
auto* __restrict__ yp = y->mem_ptr;
|
||||
|
||||
cufftHandle plan;
|
||||
cufftHandle& plan = cufft_handle;
|
||||
int batch_size = x->shape[0];
|
||||
int n1 = x->shape[1], n2 = x->shape[2];
|
||||
int fft_size = batch_size * n1 * n2;
|
||||
std::array<int, 2> fft = {n1, n2};
|
||||
|
||||
CUFFT_CALL(cufftCreate(&plan));
|
||||
auto op_type = CUFFT_C2C;
|
||||
if (TS == "float32") {
|
||||
op_type = CUFFT_C2C;
|
||||
|
@ -82,7 +85,6 @@ void CufftFftOp::jit_run() {
|
|||
CUFFT_CALL(cufftExecZ2Z(plan, (cufftDoubleComplex *)xp, (cufftDoubleComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
|
||||
}
|
||||
|
||||
CUFFT_CALL(cufftDestroy(plan));
|
||||
}
|
||||
#endif // JIT_cpu
|
||||
#endif // JIT
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "cufft_wrapper.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
cufftHandle cufft_handle;
|
||||
|
||||
struct cufft_initer {
|
||||
|
||||
inline cufft_initer() {
|
||||
if (!get_device_count()) return;
|
||||
CUFFT_CALL(cufftCreate(&cufft_handle));
|
||||
LOGv << "cufftCreate finished";
|
||||
}
|
||||
|
||||
inline ~cufft_initer() {
|
||||
if (!get_device_count()) return;
|
||||
CUFFT_CALL(cufftDestroy(cufft_handle));
|
||||
LOGv << "cufftDestroy finished";
|
||||
}
|
||||
|
||||
} init;
|
||||
|
||||
} // jittor
|
Loading…
Reference in New Issue