diff --git a/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc b/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc
index 9078ae68..e7432cfb 100644
--- a/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc
+++ b/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc
@@ -36,8 +36,13 @@ VarPtr CufftFftOp::grad(Var* out, Var* dout, Var* v, int v_index) {
 }
 
 void CufftFftOp::jit_prepare(JK& jk) {
+    if ((y->dtype() != "float32") && (y->dtype() != "float64")){
+        printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
+        ASSERT(false);
+    }
     jk << _CS("[T:") << y->dtype();
     jk << _CS("][I:")<<inverse<<"]";
+    jk << _CS("[TS:\"")<<y->dtype()<<"\"]";
 }
 
 #else // JIT
@@ -56,19 +61,26 @@ void CufftFftOp::jit_run() {
     std::array<int, 2> fft = {n1, n2};
 
     CUFFT_CALL(cufftCreate(&plan));
+    auto op_type = CUFFT_C2C;
+    if (TS == "float32") {
+        op_type = CUFFT_C2C;
+    } else if (TS == "float64") {
+        op_type = CUFFT_Z2Z;
+    }
     CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
                              nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
                              nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
-                             CUFFT_C2C, batch_size));
+                             op_type, batch_size));
     CUFFT_CALL(cufftSetStream(plan, 0));
 
     /*
      * Note:
      *  Identical pointers to data and output arrays implies in-place transformation
      */
-    CUDA_RT_CALL(cudaStreamSynchronize(0));
-    CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
-    // CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, CUFFT_INVERSE));
-    CUDA_RT_CALL(cudaStreamSynchronize(0));
+    if (TS == "float32") {
+        CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
+    } else if (TS == "float64") {
+        CUFFT_CALL(cufftExecZ2Z(plan, (cufftDoubleComplex *)xp, (cufftDoubleComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
+    }
     CUFFT_CALL(cufftDestroy(plan));
 }
diff --git a/python/jittor/test/test_fft_op.py b/python/jittor/test/test_fft_op.py
index 74f1f5a4..ca6685cc 100644
--- a/python/jittor/test/test_fft_op.py
+++ b/python/jittor/test/test_fft_op.py
@@ -19,8 +19,8 @@ class TestFFTOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
     def test_fft_forward(self):
-        img = jt.rand(256, 300)
-        img2 = jt.rand(256, 300)
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
         X = np.stack([img, img2], 0)
 
         # torch
@@ -41,8 +41,8 @@ class TestFFTOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
     def test_ifft_forward(self):
-        img = jt.rand(256, 300)
-        img2 = jt.rand(256, 300)
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
         X = np.stack([img, img2], 0)
 
         # torch
@@ -70,8 +70,8 @@ class TestFFTOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
     def test_fft_backward(self):
-        img = jt.rand(256, 300)
-        img2 = jt.rand(256, 300)
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
         X = np.stack([img, img2], 0)
         T1 = np.random.rand(1,256,300)
         T2 = np.random.rand(1,256,300)
@@ -105,8 +105,8 @@ class TestFFTOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
     def test_ifft_backward(self):
-        img = jt.rand(256, 300)
-        img2 = jt.rand(256, 300)
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
         X = np.stack([img, img2], 0)
         T1 = np.random.rand(1,256,300)
         T2 = np.random.rand(1,256,300)
@@ -137,5 +137,126 @@ class TestFFTOp(unittest.TestCase):
         grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
         assert(np.allclose(grad_x_jt, grad_x_torch))
 
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_fft_float64_forward(self):
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
+        X = np.stack([img, img2], 0)
+
+        # torch
+        x = torch.DoubleTensor(X)
+        y = torch.fft.fft2(x)
+        y_torch_real = y.numpy().real
+        y_torch_imag = y.numpy().imag
+
+        #jittor
+        x = jt.array(X).float64()
+        x = jt.stack([x, jt.zeros_like(x)], 3)
+        y = nn.fft2(x)
+        y_jt_real = y[:, :, :, 0].data
+        y_jt_imag = y[:, :, :, 1].data
+        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
+        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
+
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_ifft_float64_forward(self):
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
+        X = np.stack([img, img2], 0)
+
+        # torch
+        x = torch.DoubleTensor(X)
+        y = torch.fft.fft2(x)
+        y_torch_real = y.numpy().real
+        y_torch_imag = y.numpy().imag
+        y_ori = torch.fft.ifft2(y)
+        y_ori_torch_real = y_ori.real.numpy()
+        assert(np.allclose(y_ori_torch_real, X, atol=1))
+
+        #jittor
+        x = jt.array(X).float64()
+        x = jt.stack([x, jt.zeros_like(x)], 3)
+        y = nn.fft2(x)
+        y_ori = nn.fft2(y, True)
+        y_jt_real = y[:, :, :, 0].data
+        y_jt_imag = y[:, :, :, 1].data
+        y_ori_jt_real = y_ori[:, :, :, 0].data
+        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
+        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
+        assert(np.allclose(y_ori_jt_real, X, atol=1))
+        assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1))
+
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_fft_float64_backward(self):
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
+        X = np.stack([img, img2], 0)
+        T1 = np.random.rand(1,256,300)
+        T2 = np.random.rand(1,256,300)
+
+        # torch
+        x = torch.DoubleTensor(X)
+        x.requires_grad = True
+        t1 = torch.DoubleTensor(T1)
+        t2 = torch.DoubleTensor(T2)
+        y_mid = torch.fft.fft2(x)
+        y = torch.fft.fft2(y_mid)
+        real = y.real
+        imag = y.imag
+        loss = (real * t1).sum() + (imag * t2).sum()
+        loss.backward()
+        grad_x_torch = x.grad.detach().numpy()
+
+        #jittor
+        x = jt.array(X).float64()
+        t1 = jt.array(T1).float64()
+        t2 = jt.array(T2).float64()
+        x = jt.stack([x, jt.zeros_like(x)], 3)
+        y_mid = nn.fft2(x)
+        y = nn.fft2(y_mid)
+        real = y[:, :, :, 0]
+        imag = y[:, :, :, 1]
+        loss = (real * t1).sum() + (imag * t2).sum()
+        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
+        assert(np.allclose(grad_x_jt, grad_x_torch, atol=1))
+
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_ifft_float64_backward(self):
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
+        X = np.stack([img, img2], 0)
+        T1 = np.random.rand(1,256,300)
+        T2 = np.random.rand(1,256,300)
+
+        # torch
+        x = torch.DoubleTensor(X)
+        x.requires_grad = True
+        t1 = torch.DoubleTensor(T1)
+        t2 = torch.DoubleTensor(T2)
+        y_mid = torch.fft.ifft2(x)
+        y = torch.fft.ifft2(y_mid)
+        real = y.real
+        imag = y.imag
+        loss = (real * t1).sum() + (imag * t2).sum()
+        loss.backward()
+        grad_x_torch = x.grad.detach().numpy()
+
+        #jittor
+        x = jt.array(X).float64()
+        t1 = jt.array(T1).float64()
+        t2 = jt.array(T2).float64()
+        x = jt.stack([x, jt.zeros_like(x)], 3)
+        y_mid = nn.fft2(x, True)
+        y = nn.fft2(y_mid, True)
+        real = y[:, :, :, 0]
+        imag = y[:, :, :, 1]
+        loss = (real * t1).sum() + (imag * t2).sum()
+        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
+        assert(np.allclose(grad_x_jt, grad_x_torch))
+
 if __name__ == "__main__":
     unittest.main()