mirror of https://github.com/Jittor/Jittor
support float64 fft
parent e087a56d86
commit fa62b3a217
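This commit extends Jittor's CUFFT-backed fft2 op so that float64 (double-precision) inputs are accepted alongside float32. As a rough usage sketch, assuming the interface exercised by the tests added below — nn here is taken to be Jittor's nn module as in the test file, real and imaginary parts are packed into a trailing channel, and a boolean second argument selects the inverse transform:

import numpy as np
import jittor as jt
from jittor import nn

with jt.flag_scope(use_cuda=1):                 # the cufft op runs on CUDA
    X = np.random.rand(2, 256, 300)             # batch of two 256x300 real images
    x = jt.array(X).float64()                   # double precision, enabled by this commit
    x = jt.stack([x, jt.zeros_like(x)], 3)      # pack (real, imag) into the last dimension
    y = nn.fft2(x)                              # forward 2D FFT, result shape (2, 256, 300, 2)
    x_back = nn.fft2(y, True)                   # True selects the inverse transform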
@@ -36,8 +36,13 @@ VarPtr CufftFftOp::grad(Var* out, Var* dout, Var* v, int v_index) {
 }
 
 void CufftFftOp::jit_prepare(JK& jk) {
+    if ((y->dtype() != "float32") && (y->dtype() != "float64")){
+        printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
+        ASSERT(false);
+    }
     jk << _CS("[T:") << y->dtype();
     jk << _CS("][I:")<<inverse<<"]";
+    jk << _CS("[TS:\"")<<y->dtype()<<"\"]";
 }
 
 #else // JIT
@@ -56,19 +61,26 @@ void CufftFftOp::jit_run() {
     std::array<int, 2> fft = {n1, n2};
 
     CUFFT_CALL(cufftCreate(&plan));
+    auto op_type = CUFFT_C2C;
+    if (TS == "float32") {
+        op_type = CUFFT_C2C;
+    } else if (TS == "float64") {
+        op_type = CUFFT_Z2Z;
+    }
     CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
         nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
         nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
-        CUFFT_C2C, batch_size));
+        op_type, batch_size));
     CUFFT_CALL(cufftSetStream(plan, 0));
     /*
      * Note:
      *  Identical pointers to data and output arrays implies in-place transformation
      */
-    CUDA_RT_CALL(cudaStreamSynchronize(0));
-    CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
-    // CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, CUFFT_INVERSE));
-    CUDA_RT_CALL(cudaStreamSynchronize(0));
+    if (TS == "float32") {
+        CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
+    } else if (TS == "float64") {
+        CUFFT_CALL(cufftExecZ2Z(plan, (cufftDoubleComplex *)xp, (cufftDoubleComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
+    }
 
     CUFFT_CALL(cufftDestroy(plan));
 }
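The TS entry written by jit_prepare reaches the generated kernel as a compile-time string, which is how jit_run above can branch between the single-precision pair (CUFFT_C2C / cufftExecC2C with cufftComplex) and the double-precision pair (CUFFT_Z2Z / cufftExecZ2Z with cufftDoubleComplex). A hedged sanity check of the float64 forward path against numpy, mirroring the layout used in the tests below:

import numpy as np
import jittor as jt
from jittor import nn

with jt.flag_scope(use_cuda=1):
    X = np.random.rand(2, 256, 300)
    x = jt.array(X).float64()
    x = jt.stack([x, jt.zeros_like(x)], 3)                 # trailing (real, imag) channel
    y = nn.fft2(x)
    y_jt = y[:, :, :, 0].data + 1j * y[:, :, :, 1].data    # reassemble a complex ndarray
    y_np = np.fft.fft2(X)                                   # numpy transforms the last two axes
    print(np.abs(y_jt - y_np).max())                        # expected to be tiny in float64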
@@ -19,8 +19,8 @@ class TestFFTOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
     def test_fft_forward(self):
-        img = jt.rand(256, 300)
-        img2 = jt.rand(256, 300)
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
         X = np.stack([img, img2], 0)
 
         # torch
@@ -41,8 +41,8 @@ class TestFFTOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
     def test_ifft_forward(self):
-        img = jt.rand(256, 300)
-        img2 = jt.rand(256, 300)
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
         X = np.stack([img, img2], 0)
 
         # torch
@@ -70,8 +70,8 @@ class TestFFTOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
     def test_fft_backward(self):
-        img = jt.rand(256, 300)
-        img2 = jt.rand(256, 300)
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
         X = np.stack([img, img2], 0)
         T1 = np.random.rand(1,256,300)
         T2 = np.random.rand(1,256,300)
@@ -105,8 +105,8 @@ class TestFFTOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
     def test_ifft_backward(self):
-        img = jt.rand(256, 300)
-        img2 = jt.rand(256, 300)
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
         X = np.stack([img, img2], 0)
         T1 = np.random.rand(1,256,300)
         T2 = np.random.rand(1,256,300)
@@ -137,5 +137,126 @@ class TestFFTOp(unittest.TestCase):
         grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
         assert(np.allclose(grad_x_jt, grad_x_torch))
 
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_fft_float64_forward(self):
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
+        X = np.stack([img, img2], 0)
+
+        # torch
+        x = torch.DoubleTensor(X)
+        y = torch.fft.fft2(x)
+        y_torch_real = y.numpy().real
+        y_torch_imag = y.numpy().imag
+
+        #jittor
+        x = jt.array(X).float64()
+        x = jt.stack([x, jt.zeros_like(x)], 3)
+        y = nn.fft2(x)
+        y_jt_real = y[:, :, :, 0].data
+        y_jt_imag = y[:, :, :, 1].data
+        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
+        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
+
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_ifft_float64_forward(self):
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
+        X = np.stack([img, img2], 0)
+
+        # torch
+        x = torch.DoubleTensor(X)
+        y = torch.fft.fft2(x)
+        y_torch_real = y.numpy().real
+        y_torch_imag = y.numpy().imag
+        y_ori = torch.fft.ifft2(y)
+        y_ori_torch_real = y_ori.real.numpy()
+        assert(np.allclose(y_ori_torch_real, X, atol=1))
+
+        #jittor
+        x = jt.array(X).float64()
+        x = jt.stack([x, jt.zeros_like(x)], 3)
+        y = nn.fft2(x)
+        y_ori = nn.fft2(y, True)
+        y_jt_real = y[:, :, :, 0].data
+        y_jt_imag = y[:, :, :, 1].data
+        y_ori_jt_real = y_ori[:, :, :, 0].data
+        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
+        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
+        assert(np.allclose(y_ori_jt_real, X, atol=1))
+        assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1))
+
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_fft_float64_backward(self):
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
+        X = np.stack([img, img2], 0)
+        T1 = np.random.rand(1,256,300)
+        T2 = np.random.rand(1,256,300)
+
+        # torch
+        x = torch.DoubleTensor(X)
+        x.requires_grad = True
+        t1 = torch.DoubleTensor(T1)
+        t2 = torch.DoubleTensor(T2)
+        y_mid = torch.fft.fft2(x)
+        y = torch.fft.fft2(y_mid)
+        real = y.real
+        imag = y.imag
+        loss = (real * t1).sum() + (imag * t2).sum()
+        loss.backward()
+        grad_x_torch = x.grad.detach().numpy()
+
+        #jittor
+        x = jt.array(X).float64()
+        t1 = jt.array(T1).float64()
+        t2 = jt.array(T2).float64()
+        x = jt.stack([x, jt.zeros_like(x)], 3)
+        y_mid = nn.fft2(x)
+        y = nn.fft2(y_mid)
+        real = y[:, :, :, 0]
+        imag = y[:, :, :, 1]
+        loss = (real * t1).sum() + (imag * t2).sum()
+        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
+        assert(np.allclose(grad_x_jt, grad_x_torch, atol=1))
+
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_ifft_float64_backward(self):
+        img = np.random.rand(256, 300)
+        img2 = np.random.rand(256, 300)
+        X = np.stack([img, img2], 0)
+        T1 = np.random.rand(1,256,300)
+        T2 = np.random.rand(1,256,300)
+
+        # torch
+        x = torch.DoubleTensor(X)
+        x.requires_grad = True
+        t1 = torch.DoubleTensor(T1)
+        t2 = torch.DoubleTensor(T2)
+        y_mid = torch.fft.ifft2(x)
+        y = torch.fft.ifft2(y_mid)
+        real = y.real
+        imag = y.imag
+        loss = (real * t1).sum() + (imag * t2).sum()
+        loss.backward()
+        grad_x_torch = x.grad.detach().numpy()
+
+        #jittor
+        x = jt.array(X).float64()
+        t1 = jt.array(T1).float64()
+        t2 = jt.array(T2).float64()
+        x = jt.stack([x, jt.zeros_like(x)], 3)
+        y_mid = nn.fft2(x, True)
+        y = nn.fft2(y_mid, True)
+        real = y[:, :, :, 0]
+        imag = y[:, :, :, 1]
+        loss = (real * t1).sum() + (imag * t2).sum()
+        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
+        assert(np.allclose(grad_x_jt, grad_x_torch))
 
 if __name__ == "__main__":
     unittest.main()