mirror of https://github.com/Jittor/Jittor
support float64 fft
This commit is contained in:
parent
e087a56d86
commit
fa62b3a217
|
@ -36,8 +36,13 @@ VarPtr CufftFftOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void CufftFftOp::jit_prepare(JK& jk) {
|
||||
if ((y->dtype() != "float32") && (y->dtype() != "float64")){
|
||||
printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
|
||||
ASSERT(false);
|
||||
}
|
||||
jk << _CS("[T:") << y->dtype();
|
||||
jk << _CS("][I:")<<inverse<<"]";
|
||||
jk << _CS("[TS:\"")<<y->dtype()<<"\"]";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
@ -56,19 +61,26 @@ void CufftFftOp::jit_run() {
|
|||
std::array<int, 2> fft = {n1, n2};
|
||||
|
||||
CUFFT_CALL(cufftCreate(&plan));
|
||||
auto op_type = CUFFT_C2C;
|
||||
if (TS == "float32") {
|
||||
op_type = CUFFT_C2C;
|
||||
} else if (TS == "float64") {
|
||||
op_type = CUFFT_Z2Z;
|
||||
}
|
||||
CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
|
||||
nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
|
||||
nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
|
||||
CUFFT_C2C, batch_size));
|
||||
op_type, batch_size));
|
||||
CUFFT_CALL(cufftSetStream(plan, 0));
|
||||
/*
|
||||
* Note:
|
||||
* Identical pointers to data and output arrays implies in-place transformation
|
||||
*/
|
||||
CUDA_RT_CALL(cudaStreamSynchronize(0));
|
||||
CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
|
||||
// CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, CUFFT_INVERSE));
|
||||
CUDA_RT_CALL(cudaStreamSynchronize(0));
|
||||
if (TS == "float32") {
|
||||
CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
|
||||
} else if (TS == "float64") {
|
||||
CUFFT_CALL(cufftExecZ2Z(plan, (cufftDoubleComplex *)xp, (cufftDoubleComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
|
||||
}
|
||||
|
||||
CUFFT_CALL(cufftDestroy(plan));
|
||||
}
|
||||
|
|
|
@ -19,8 +19,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_fft_forward(self):
|
||||
img = jt.rand(256, 300)
|
||||
img2 = jt.rand(256, 300)
|
||||
img = np.random.rand(256, 300)
|
||||
img2 = np.random.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
|
||||
# torch
|
||||
|
@ -41,8 +41,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_ifft_forward(self):
|
||||
img = jt.rand(256, 300)
|
||||
img2 = jt.rand(256, 300)
|
||||
img = np.random.rand(256, 300)
|
||||
img2 = np.random.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
|
||||
# torch
|
||||
|
@ -70,8 +70,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_fft_backward(self):
|
||||
img = jt.rand(256, 300)
|
||||
img2 = jt.rand(256, 300)
|
||||
img = np.random.rand(256, 300)
|
||||
img2 = np.random.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
T1 = np.random.rand(1,256,300)
|
||||
T2 = np.random.rand(1,256,300)
|
||||
|
@ -105,8 +105,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_ifft_backward(self):
|
||||
img = jt.rand(256, 300)
|
||||
img2 = jt.rand(256, 300)
|
||||
img = np.random.rand(256, 300)
|
||||
img2 = np.random.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
T1 = np.random.rand(1,256,300)
|
||||
T2 = np.random.rand(1,256,300)
|
||||
|
@ -137,5 +137,126 @@ class TestFFTOp(unittest.TestCase):
|
|||
grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
|
||||
assert(np.allclose(grad_x_jt, grad_x_torch))
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_fft_float64_forward(self):
|
||||
img = np.random.rand(256, 300)
|
||||
img2 = np.random.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
|
||||
# torch
|
||||
x = torch.DoubleTensor(X)
|
||||
y = torch.fft.fft2(x)
|
||||
y_torch_real = y.numpy().real
|
||||
y_torch_imag = y.numpy().imag
|
||||
|
||||
#jittor
|
||||
x = jt.array(X).float64()
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y = nn.fft2(x)
|
||||
y_jt_real = y[:, :, :, 0].data
|
||||
y_jt_imag = y[:, :, :, 1].data
|
||||
assert(np.allclose(y_torch_real, y_jt_real, atol=1))
|
||||
assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_ifft_float64_forward(self):
|
||||
img = np.random.rand(256, 300)
|
||||
img2 = np.random.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
|
||||
# torch
|
||||
x = torch.DoubleTensor(X)
|
||||
y = torch.fft.fft2(x)
|
||||
y_torch_real = y.numpy().real
|
||||
y_torch_imag = y.numpy().imag
|
||||
y_ori = torch.fft.ifft2(y)
|
||||
y_ori_torch_real = y_ori.real.numpy()
|
||||
assert(np.allclose(y_ori_torch_real, X, atol=1))
|
||||
|
||||
#jittor
|
||||
x = jt.array(X).float64()
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y = nn.fft2(x)
|
||||
y_ori = nn.fft2(y, True)
|
||||
y_jt_real = y[:, :, :, 0].data
|
||||
y_jt_imag = y[:, :, :, 1].data
|
||||
y_ori_jt_real = y_ori[:, :, :, 0].data
|
||||
assert(np.allclose(y_torch_real, y_jt_real, atol=1))
|
||||
assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
|
||||
assert(np.allclose(y_ori_jt_real, X, atol=1))
|
||||
assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1))
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_fft_float64_backward(self):
|
||||
img = np.random.rand(256, 300)
|
||||
img2 = np.random.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
T1 = np.random.rand(1,256,300)
|
||||
T2 = np.random.rand(1,256,300)
|
||||
|
||||
# torch
|
||||
x = torch.DoubleTensor(X)
|
||||
x.requires_grad = True
|
||||
t1 = torch.DoubleTensor(T1)
|
||||
t2 = torch.DoubleTensor(T2)
|
||||
y_mid = torch.fft.fft2(x)
|
||||
y = torch.fft.fft2(y_mid)
|
||||
real = y.real
|
||||
imag = y.imag
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
loss.backward()
|
||||
grad_x_torch = x.grad.detach().numpy()
|
||||
|
||||
#jittor
|
||||
x = jt.array(X).float64()
|
||||
t1 = jt.array(T1).float64()
|
||||
t2 = jt.array(T2).float64()
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y_mid = nn.fft2(x)
|
||||
y = nn.fft2(y_mid)
|
||||
real = y[:, :, :, 0]
|
||||
imag = y[:, :, :, 1]
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
|
||||
assert(np.allclose(grad_x_jt, grad_x_torch, atol=1))
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_ifft_float64_backward(self):
|
||||
img = np.random.rand(256, 300)
|
||||
img2 = np.random.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
T1 = np.random.rand(1,256,300)
|
||||
T2 = np.random.rand(1,256,300)
|
||||
|
||||
# torch
|
||||
x = torch.DoubleTensor(X)
|
||||
x.requires_grad = True
|
||||
t1 = torch.DoubleTensor(T1)
|
||||
t2 = torch.DoubleTensor(T2)
|
||||
y_mid = torch.fft.ifft2(x)
|
||||
y = torch.fft.ifft2(y_mid)
|
||||
real = y.real
|
||||
imag = y.imag
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
loss.backward()
|
||||
grad_x_torch = x.grad.detach().numpy()
|
||||
|
||||
#jittor
|
||||
x = jt.array(X).float64()
|
||||
t1 = jt.array(T1).float64()
|
||||
t2 = jt.array(T2).float64()
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y_mid = nn.fft2(x, True)
|
||||
y = nn.fft2(y_mid, True)
|
||||
real = y[:, :, :, 0]
|
||||
imag = y[:, :, :, 1]
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
|
||||
assert(np.allclose(grad_x_jt, grad_x_torch))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue