mirror of https://github.com/Jittor/Jittor
fft
This commit is contained in:
parent
5d191e6247
commit
e9f681de53
|
@ -26,7 +26,7 @@ with lock.lock_scope():
|
|||
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
|
||||
if core.get_device_count() == 0:
|
||||
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
|
||||
from .compile_extern import cudnn, curand, cublas
|
||||
from .compile_extern import cudnn, curand, cublas, cufft
|
||||
from .init_cupy import numpy2cupy
|
||||
|
||||
import contextlib
|
||||
|
|
|
@ -2,7 +2,7 @@ from jittor_core import *
|
|||
from jittor_core.ops import *
|
||||
from .misc import *
|
||||
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse
|
||||
from .compile_extern import cublas as cublas, cudnn as cudnn, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
|
||||
from .compile_extern import cublas as cublas, cudnn as cudnn, curand as curand, cufft as cufft, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
|
||||
from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops
|
||||
from .contrib import concat as concat
|
||||
from .nn import matmul as matmul
|
||||
|
|
|
@ -209,7 +209,7 @@ def setup_cuda_extern():
|
|||
line = traceback.format_exc()
|
||||
LOG.w(f"CUDA found but cub is not loaded:\n{line}")
|
||||
|
||||
libs = ["cublas", "cudnn", "curand"]
|
||||
libs = ["cublas", "cudnn", "curand", "cufft"]
|
||||
for lib_name in libs:
|
||||
try:
|
||||
setup_cuda_lib(lib_name, extra_flags=link_cuda_extern)
|
||||
|
@ -547,7 +547,7 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
|||
except:
|
||||
pass
|
||||
|
||||
cudnn = cublas = curand = None
|
||||
cudnn = cublas = curand = cufft = None
|
||||
setup_mpi()
|
||||
in_mpi = inside_mpi()
|
||||
rank = mpi.world_rank() if in_mpi else 0
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/*
|
||||
* Copyright 2020 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO LICENSEE:
|
||||
*
|
||||
* This source code and/or documentation ("Licensed Deliverables") are
|
||||
* subject to NVIDIA intellectual property rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* These Licensed Deliverables contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
||||
* conditions of a form of NVIDIA software license agreement by and
|
||||
* between NVIDIA and Licensee ("License Agreement") or electronically
|
||||
* accepted by Licensee. Notwithstanding any terms or conditions to
|
||||
* the contrary in the License Agreement, reproduction or disclosure
|
||||
* of the Licensed Deliverables to any third party without the express
|
||||
* written consent of NVIDIA is prohibited.
|
||||
*
|
||||
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
||||
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
||||
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
||||
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
||||
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
||||
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
||||
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
||||
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
||||
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
||||
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
||||
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
||||
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
||||
* OF THESE LICENSED DELIVERABLES.
|
||||
*
|
||||
* U.S. Government End Users. These Licensed Deliverables are a
|
||||
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
||||
* 1995), consisting of "commercial computer software" and "commercial
|
||||
* computer software documentation" as such terms are used in 48
|
||||
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
||||
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
||||
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
||||
* U.S. Government End Users acquire the Licensed Deliverables with
|
||||
* only those rights set forth herein.
|
||||
*
|
||||
* Any use of the Licensed Deliverables in individual and commercial
|
||||
* software must include, in the user documentation and internal
|
||||
* comments to the code, the above Disclaimer and U.S. Government End
|
||||
* Users Notice.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// CUDA API error checking
|
||||
#ifndef CUDA_RT_CALL
|
||||
#define CUDA_RT_CALL( call ) \
|
||||
{ \
|
||||
auto status = static_cast<cudaError_t>( call ); \
|
||||
if ( status != cudaSuccess ) \
|
||||
fprintf( stderr, \
|
||||
"ERROR: CUDA RT call \"%s\" in line %d of file %s failed " \
|
||||
"with " \
|
||||
"%s (%d).\n", \
|
||||
#call, \
|
||||
__LINE__, \
|
||||
__FILE__, \
|
||||
cudaGetErrorString( status ), \
|
||||
status ); \
|
||||
}
|
||||
#endif // CUDA_RT_CALL
|
||||
|
||||
// cufft API error chekcing
|
||||
#ifndef CUFFT_CALL
|
||||
#define CUFFT_CALL( call ) \
|
||||
{ \
|
||||
auto status = static_cast<cufftResult>( call ); \
|
||||
if ( status != CUFFT_SUCCESS ) \
|
||||
fprintf( stderr, \
|
||||
"ERROR: CUFFT call \"%s\" in line %d of file %s failed " \
|
||||
"with " \
|
||||
"code (%d).\n", \
|
||||
#call, \
|
||||
__LINE__, \
|
||||
__FILE__, \
|
||||
status ); \
|
||||
}
|
||||
#endif // CUFFT_CALL
|
||||
|
||||
// template <> struct traits<CUFFT_C2C> {
|
||||
// // scalar type
|
||||
// typedef float T;
|
||||
|
||||
// using input_host_type = std::complex<T>;
|
||||
// using input_device_type = cufftComplex;
|
||||
|
||||
// using output_host_type = std::complex<T>;
|
||||
// using output_device_type = cufftComplex;
|
||||
|
||||
// static constexpr cufftType_t transformType = CUDA_R_64F;
|
||||
|
||||
// template <typename RNG> inline static T rand(RNG &gen) {
|
||||
// return make_cuFloatComplex((S)gen(), (S)gen());
|
||||
// }
|
||||
// };
|
|
@ -0,0 +1,78 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
||||
#include "var.h"
|
||||
#include "init.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <cufft.h>
|
||||
#include "helper_cuda.h"
|
||||
#include "cufft_fft_op.h"
|
||||
|
||||
#include <complex>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#include <cufftXt.h>
|
||||
#include "cufft_utils.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
static auto make_cufft_fft = get_op_info("cufft_fft")
|
||||
.get_constructor<VarPtr, Var*, bool>();
|
||||
CufftFftOp::CufftFftOp(Var* x, bool inverse) : x(x), inverse(inverse) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
y = create_output(x->shape, x->dtype());
|
||||
}
|
||||
|
||||
VarPtr CufftFftOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
return make_cufft_fft(dout, !inverse);
|
||||
}
|
||||
|
||||
void CufftFftOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << y->dtype();
|
||||
jk << _CS("][I:")<<inverse<<"]";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void CufftFftOp::jit_run() {
|
||||
}
|
||||
#else // JIT_cuda
|
||||
void CufftFftOp::jit_run() {
|
||||
auto* __restrict__ xp = x->mem_ptr;
|
||||
auto* __restrict__ yp = y->mem_ptr;
|
||||
|
||||
cufftHandle plan;
|
||||
int batch_size = x->shape[0];
|
||||
int n1 = x->shape[1], n2 = x->shape[2];
|
||||
int fft_size = batch_size * n1 * n2;
|
||||
std::array<int, 2> fft = {n1, n2};
|
||||
|
||||
CUFFT_CALL(cufftCreate(&plan));
|
||||
CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
|
||||
nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
|
||||
nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
|
||||
CUFFT_C2C, batch_size));
|
||||
CUFFT_CALL(cufftSetStream(plan, 0));
|
||||
/*
|
||||
* Note:
|
||||
* Identical pointers to data and output arrays implies in-place transformation
|
||||
*/
|
||||
CUDA_RT_CALL(cudaStreamSynchronize(0));
|
||||
CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
|
||||
// CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, CUFFT_INVERSE));
|
||||
CUDA_RT_CALL(cudaStreamSynchronize(0));
|
||||
|
||||
CUFFT_CALL(cufftDestroy(plan));
|
||||
}
|
||||
#endif // JIT_cpu
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,27 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
//TODO: support FFT2D only now.
|
||||
struct CufftFftOp : Op {
|
||||
bool inverse;
|
||||
Var* x, * y;
|
||||
NanoString type;
|
||||
CufftFftOp(Var* x, bool inverse=false);
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
|
||||
const char* name() const override { return "cufft_fft"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -2770,3 +2770,12 @@ class Bilinear(Module):
|
|||
|
||||
def execute(self, in1, in2):
|
||||
return bilinear(in1, in2, self.weight, self.bias)
|
||||
|
||||
#TODO: support FFT2D only now.
|
||||
def fft2(x, inverse=False):
|
||||
assert(len(x.shape) == 4)
|
||||
assert(x.shape[3] == 2)
|
||||
y = jt.compile_extern.cufft_ops.cufft_fft(x, inverse)
|
||||
if inverse:
|
||||
y /= x.shape[1] * x.shape[2]
|
||||
return y
|
|
@ -263,7 +263,7 @@ def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs
|
|||
help_name = ""+target_scope_name+'.'+name
|
||||
else:
|
||||
help_name = name
|
||||
if lib_name in ["mpi", "nccl", "cudnn", "curand", "cublas", "mkl"]:
|
||||
if lib_name in ["mpi", "nccl", "cudnn", "curand" "cufft", "cublas", "mkl"]:
|
||||
help_name = lib_name+'.'+help_name
|
||||
help_cmd = f"help(jt.{help_name})"
|
||||
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import unittest
|
||||
from .test_log import find_log_with_re
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
from jittor import nn
|
||||
|
||||
#requires torch>=1.10.1
|
||||
class TestFFTOp(unittest.TestCase):
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_fft_forward(self):
|
||||
img = cv2.imread("test.jpg")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.imread("test2.jpg")
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
X = np.stack([img, img2], 0)
|
||||
|
||||
# torch
|
||||
x = torch.Tensor(X)
|
||||
y = torch.fft.fft2(x)
|
||||
y_torch_real = y.numpy().real
|
||||
y_torch_imag = y.numpy().imag
|
||||
|
||||
#jittor
|
||||
x = jt.array(X,dtype=jt.float32)
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y = nn.fft2(x)
|
||||
y_jt_real = y[:, :, :, 0].data
|
||||
y_jt_imag = y[:, :, :, 1].data
|
||||
assert(np.allclose(y_torch_real, y_jt_real, atol=1))
|
||||
assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_ifft_forward(self):
|
||||
img = cv2.imread("test.jpg")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.imread("test2.jpg")
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
X = np.stack([img, img2], 0)
|
||||
|
||||
# torch
|
||||
x = torch.Tensor(X)
|
||||
y = torch.fft.fft2(x)
|
||||
y_torch_real = y.numpy().real
|
||||
y_torch_imag = y.numpy().imag
|
||||
y_ori = torch.fft.ifft2(y)
|
||||
y_ori_torch_real = y_ori.real.numpy()
|
||||
assert(np.allclose(y_ori_torch_real, X, atol=1))
|
||||
|
||||
#jittor
|
||||
x = jt.array(X,dtype=jt.float32)
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y = nn.fft2(x)
|
||||
y_ori = nn.fft2(y, True)
|
||||
y_jt_real = y[:, :, :, 0].data
|
||||
y_jt_imag = y[:, :, :, 1].data
|
||||
y_ori_jt_real = y_ori[:, :, :, 0].data
|
||||
assert(np.allclose(y_torch_real, y_jt_real, atol=1))
|
||||
assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
|
||||
assert(np.allclose(y_ori_jt_real, X, atol=1))
|
||||
assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1))
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_fft_backward(self):
|
||||
img = cv2.imread("test.jpg")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.imread("test2.jpg")
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
X = np.stack([img, img2], 0)
|
||||
T1 = np.random.rand(1,512,512)
|
||||
T2 = np.random.rand(1,512,512)
|
||||
|
||||
# torch
|
||||
x = torch.Tensor(X)
|
||||
x.requires_grad = True
|
||||
t1 = torch.Tensor(T1)
|
||||
t2 = torch.Tensor(T2)
|
||||
y_mid = torch.fft.fft2(x)
|
||||
y = torch.fft.fft2(y_mid)
|
||||
real = y.real
|
||||
imag = y.imag
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
loss.backward()
|
||||
grad_x_torch = x.grad.detach().numpy()
|
||||
|
||||
#jittor
|
||||
x = jt.array(X,dtype=jt.float32)
|
||||
t1 = jt.array(T1,dtype=jt.float32)
|
||||
t2 = jt.array(T2,dtype=jt.float32)
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y_mid = nn.fft2(x)
|
||||
y = nn.fft2(y_mid)
|
||||
real = y[:, :, :, 0]
|
||||
imag = y[:, :, :, 1]
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
|
||||
assert(np.allclose(grad_x_jt, grad_x_torch, atol=1))
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_ifft_backward(self):
|
||||
img = cv2.imread("test.jpg")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.imread("test2.jpg")
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
X = np.stack([img, img2], 0)
|
||||
T1 = np.random.rand(1,512,512)
|
||||
T2 = np.random.rand(1,512,512)
|
||||
|
||||
# torch
|
||||
x = torch.Tensor(X)
|
||||
x.requires_grad = True
|
||||
t1 = torch.Tensor(T1)
|
||||
t2 = torch.Tensor(T2)
|
||||
y_mid = torch.fft.ifft2(x)
|
||||
y = torch.fft.ifft2(y_mid)
|
||||
real = y.real
|
||||
imag = y.imag
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
loss.backward()
|
||||
grad_x_torch = x.grad.detach().numpy()
|
||||
|
||||
#jittor
|
||||
x = jt.array(X,dtype=jt.float32)
|
||||
t1 = jt.array(T1,dtype=jt.float32)
|
||||
t2 = jt.array(T2,dtype=jt.float32)
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y_mid = nn.fft2(x, True)
|
||||
y = nn.fft2(y_mid, True)
|
||||
real = y[:, :, :, 0]
|
||||
imag = y[:, :, :, 1]
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
|
||||
assert(np.allclose(grad_x_jt, grad_x_torch))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue