support float64 fft

Author: cxjyxx_me
Date: 2022-03-26 23:20:16 -04:00
parent e087a56d86
commit fa62b3a217
2 changed files with 146 additions and 13 deletions


@@ -36,8 +36,13 @@ VarPtr CufftFftOp::grad(Var* out, Var* dout, Var* v, int v_index) {
}
void CufftFftOp::jit_prepare(JK& jk) {
    if ((y->dtype() != "float32") && (y->dtype() != "float64")){
        printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
        ASSERT(false);
    }
    jk << _CS("[T:") << y->dtype();
    jk << _CS("][I:")<<inverse<<"]";
    jk << _CS("[TS:\"")<<y->dtype()<<"\"]";
}
#else // JIT
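A note on the new key entry: jit_prepare() now records the dtype twice, once as the existing type entry [T:...] and once more as a quoted string in [TS:"..."]. Below is a rough illustration of what this gives the JIT-compiled section, assuming the usual Jittor convention that each [NAME:value] entry in the JIT key becomes a compile-time definition; this is only a sketch, not code from the commit.

// Sketch only (assumed expansion of the JIT key, not part of this commit):
// for a float64 input the key written above looks roughly like
//   [T:float64][I:0][TS:"float64"]
// and the JIT section below is compiled as if these were defined:
#define T  float64      // element type (Jittor's alias for double), used for pointer casts
#define I  0            // 1 when the op runs an inverse transform, 0 for a forward one
#define TS "float64"    // dtype name as a quoted string, which is what lets
                        // jit_run() below branch between the C2C and Z2Z paths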
@@ -56,19 +61,26 @@ void CufftFftOp::jit_run() {
    std::array<int, 2> fft = {n1, n2};
    CUFFT_CALL(cufftCreate(&plan));
    auto op_type = CUFFT_C2C;
    if (TS == "float32") {
        op_type = CUFFT_C2C;
    } else if (TS == "float64") {
        op_type = CUFFT_Z2Z;
    }
    CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
                             nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
                             nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
                             CUFFT_C2C, batch_size));
                             op_type, batch_size));
    CUFFT_CALL(cufftSetStream(plan, 0));
    /*
     * Note:
     *  Identical pointers to data and output arrays implies in-place transformation
     */
    CUDA_RT_CALL(cudaStreamSynchronize(0));
    CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
    // CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, CUFFT_INVERSE));
    CUDA_RT_CALL(cudaStreamSynchronize(0));
    if (TS == "float32") {
        CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
    } else if (TS == "float64") {
        CUFFT_CALL(cufftExecZ2Z(plan, (cufftDoubleComplex *)xp, (cufftDoubleComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
    }
    CUFFT_CALL(cufftDestroy(plan));
}
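For readers who want to see the float32/float64 dispatch outside Jittor's JIT machinery, here is a minimal standalone sketch of the same idea in plain cuFFT. The helper name fft2d_inplace and the CHECK_CUFFT macro are inventions of this sketch (the op above uses Jittor's CUFFT_CALL / CUDA_RT_CALL wrappers instead); only documented cuFFT calls are used.

#include <cufft.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Minimal error check used only by this sketch.
#define CHECK_CUFFT(call) do { \
    cufftResult r_ = (call); \
    if (r_ != CUFFT_SUCCESS) { \
        fprintf(stderr, "cuFFT error %d at %s:%d\n", (int)r_, __FILE__, __LINE__); \
        exit(1); \
    } \
} while (0)

// Batched in-place 2-D complex-to-complex FFT over device memory holding
// batch * n1 * n2 interleaved complex values. is_double selects the float64
// (Z2Z) path added by this commit; otherwise the float32 (C2C) path is used.
void fft2d_inplace(void* data, int batch, int n1, int n2, bool is_double, bool inverse) {
    int n[2] = {n1, n2};
    cufftType type = is_double ? CUFFT_Z2Z : CUFFT_C2C;   // plan precision
    cufftHandle plan;
    CHECK_CUFFT(cufftPlanMany(&plan, 2, n,
                              nullptr, 1, n1 * n2,  // *inembed, istride, idist
                              nullptr, 1, n1 * n2,  // *onembed, ostride, odist
                              type, batch));
    int dir = inverse ? CUFFT_INVERSE : CUFFT_FORWARD;
    if (is_double)
        CHECK_CUFFT(cufftExecZ2Z(plan, (cufftDoubleComplex*)data, (cufftDoubleComplex*)data, dir));
    else
        CHECK_CUFFT(cufftExecC2C(plan, (cufftComplex*)data, (cufftComplex*)data, dir));
    cudaStreamSynchronize(0);  // wait for the asynchronous transform to finish
    CHECK_CUFFT(cufftDestroy(plan));
}

The plan type and the exec call have to change together: a CUFFT_C2C plan only accepts cufftComplex buffers and a CUFFT_Z2Z plan only cufftDoubleComplex ones, which is exactly the pairing the op selects above. Also note that cuFFT transforms are unnormalized, so a forward transform followed by an inverse scales the data by n1 * n2.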


@@ -19,8 +19,8 @@ class TestFFTOp(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_fft_forward(self):
        img = jt.rand(256, 300)
        img2 = jt.rand(256, 300)
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        # torch
@@ -41,8 +41,8 @@ class TestFFTOp(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_ifft_forward(self):
        img = jt.rand(256, 300)
        img2 = jt.rand(256, 300)
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        # torch
@@ -70,8 +70,8 @@ class TestFFTOp(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_fft_backward(self):
        img = jt.rand(256, 300)
        img2 = jt.rand(256, 300)
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        T1 = np.random.rand(1,256,300)
        T2 = np.random.rand(1,256,300)
@@ -105,8 +105,8 @@ class TestFFTOp(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_ifft_backward(self):
        img = jt.rand(256, 300)
        img2 = jt.rand(256, 300)
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        T1 = np.random.rand(1,256,300)
        T2 = np.random.rand(1,256,300)
@@ -137,5 +137,126 @@ class TestFFTOp(unittest.TestCase):
        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
        assert(np.allclose(grad_x_jt, grad_x_torch))
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_fft_float64_forward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        # torch
        x = torch.DoubleTensor(X)
        y = torch.fft.fft2(x)
        y_torch_real = y.numpy().real
        y_torch_imag = y.numpy().imag
        #jittor
        x = jt.array(X).float64()
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y = nn.fft2(x)
        y_jt_real = y[:, :, :, 0].data
        y_jt_imag = y[:, :, :, 1].data
        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_ifft_float64_forward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        # torch
        x = torch.DoubleTensor(X)
        y = torch.fft.fft2(x)
        y_torch_real = y.numpy().real
        y_torch_imag = y.numpy().imag
        y_ori = torch.fft.ifft2(y)
        y_ori_torch_real = y_ori.real.numpy()
        assert(np.allclose(y_ori_torch_real, X, atol=1))
        #jittor
        x = jt.array(X).float64()
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y = nn.fft2(x)
        y_ori = nn.fft2(y, True)
        y_jt_real = y[:, :, :, 0].data
        y_jt_imag = y[:, :, :, 1].data
        y_ori_jt_real = y_ori[:, :, :, 0].data
        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
        assert(np.allclose(y_ori_jt_real, X, atol=1))
        assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1))
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_fft_float64_backward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        T1 = np.random.rand(1,256,300)
        T2 = np.random.rand(1,256,300)
        # torch
        x = torch.DoubleTensor(X)
        x.requires_grad = True
        t1 = torch.DoubleTensor(T1)
        t2 = torch.DoubleTensor(T2)
        y_mid = torch.fft.fft2(x)
        y = torch.fft.fft2(y_mid)
        real = y.real
        imag = y.imag
        loss = (real * t1).sum() + (imag * t2).sum()
        loss.backward()
        grad_x_torch = x.grad.detach().numpy()
        #jittor
        x = jt.array(X).float64()
        t1 = jt.array(T1).float64()
        t2 = jt.array(T2).float64()
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y_mid = nn.fft2(x)
        y = nn.fft2(y_mid)
        real = y[:, :, :, 0]
        imag = y[:, :, :, 1]
        loss = (real * t1).sum() + (imag * t2).sum()
        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
        assert(np.allclose(grad_x_jt, grad_x_torch, atol=1))
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_ifft_float64_backward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        T1 = np.random.rand(1,256,300)
        T2 = np.random.rand(1,256,300)
        # torch
        x = torch.DoubleTensor(X)
        x.requires_grad = True
        t1 = torch.DoubleTensor(T1)
        t2 = torch.DoubleTensor(T2)
        y_mid = torch.fft.ifft2(x)
        y = torch.fft.ifft2(y_mid)
        real = y.real
        imag = y.imag
        loss = (real * t1).sum() + (imag * t2).sum()
        loss.backward()
        grad_x_torch = x.grad.detach().numpy()
        #jittor
        x = jt.array(X).float64()
        t1 = jt.array(T1).float64()
        t2 = jt.array(T2).float64()
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y_mid = nn.fft2(x, True)
        y = nn.fft2(y_mid, True)
        real = y[:, :, :, 0]
        imag = y[:, :, :, 1]
        loss = (real * t1).sum() + (imag * t2).sum()
        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
        assert(np.allclose(grad_x_jt, grad_x_torch))
if __name__ == "__main__":
    unittest.main()