fft polish

This commit is contained in:
cxjyxx_me 2022-03-30 01:29:25 -04:00
parent 65832ac10f
commit 45df1a4eee
2 changed files with 16 additions and 15 deletions

View File

@ -2791,7 +2791,8 @@ class Bilinear(Module):
return bilinear(in1, in2, self.weight, self.bias)
#TODO: support FFT2D only now.
def fft2(x, inverse=False):
def _fft2(x, inverse=False):
assert(jt.flags.use_cuda==1)
assert(len(x.shape) == 4)
assert(x.shape[3] == 2)
y = jt.compile_extern.cufft_ops.cufft_fft(x, inverse)

View File

@ -32,7 +32,7 @@ class TestFFTOp(unittest.TestCase):
#jittor
x = jt.array(X,dtype=jt.float32)
x = jt.stack([x, jt.zeros_like(x)], 3)
y = nn.fft2(x)
y = nn._fft2(x)
y_jt_real = y[:, :, :, 0].data
y_jt_imag = y[:, :, :, 1].data
assert(np.allclose(y_torch_real, y_jt_real, atol=1))
@ -57,8 +57,8 @@ class TestFFTOp(unittest.TestCase):
#jittor
x = jt.array(X,dtype=jt.float32)
x = jt.stack([x, jt.zeros_like(x)], 3)
y = nn.fft2(x)
y_ori = nn.fft2(y, True)
y = nn._fft2(x)
y_ori = nn._fft2(y, True)
y_jt_real = y[:, :, :, 0].data
y_jt_imag = y[:, :, :, 1].data
y_ori_jt_real = y_ori[:, :, :, 0].data
@ -94,8 +94,8 @@ class TestFFTOp(unittest.TestCase):
t1 = jt.array(T1,dtype=jt.float32)
t2 = jt.array(T2,dtype=jt.float32)
x = jt.stack([x, jt.zeros_like(x)], 3)
y_mid = nn.fft2(x)
y = nn.fft2(y_mid)
y_mid = nn._fft2(x)
y = nn._fft2(y_mid)
real = y[:, :, :, 0]
imag = y[:, :, :, 1]
loss = (real * t1).sum() + (imag * t2).sum()
@ -129,8 +129,8 @@ class TestFFTOp(unittest.TestCase):
t1 = jt.array(T1,dtype=jt.float32)
t2 = jt.array(T2,dtype=jt.float32)
x = jt.stack([x, jt.zeros_like(x)], 3)
y_mid = nn.fft2(x, True)
y = nn.fft2(y_mid, True)
y_mid = nn._fft2(x, True)
y = nn._fft2(y_mid, True)
real = y[:, :, :, 0]
imag = y[:, :, :, 1]
loss = (real * t1).sum() + (imag * t2).sum()
@ -153,7 +153,7 @@ class TestFFTOp(unittest.TestCase):
#jittor
x = jt.array(X).float64()
x = jt.stack([x, jt.zeros_like(x)], 3)
y = nn.fft2(x)
y = nn._fft2(x)
y_jt_real = y[:, :, :, 0].data
y_jt_imag = y[:, :, :, 1].data
assert(np.allclose(y_torch_real, y_jt_real, atol=1))
@ -178,8 +178,8 @@ class TestFFTOp(unittest.TestCase):
#jittor
x = jt.array(X).float64()
x = jt.stack([x, jt.zeros_like(x)], 3)
y = nn.fft2(x)
y_ori = nn.fft2(y, True)
y = nn._fft2(x)
y_ori = nn._fft2(y, True)
y_jt_real = y[:, :, :, 0].data
y_jt_imag = y[:, :, :, 1].data
y_ori_jt_real = y_ori[:, :, :, 0].data
@ -215,8 +215,8 @@ class TestFFTOp(unittest.TestCase):
t1 = jt.array(T1).float64()
t2 = jt.array(T2).float64()
x = jt.stack([x, jt.zeros_like(x)], 3)
y_mid = nn.fft2(x)
y = nn.fft2(y_mid)
y_mid = nn._fft2(x)
y = nn._fft2(y_mid)
real = y[:, :, :, 0]
imag = y[:, :, :, 1]
loss = (real * t1).sum() + (imag * t2).sum()
@ -250,8 +250,8 @@ class TestFFTOp(unittest.TestCase):
t1 = jt.array(T1).float64()
t2 = jt.array(T2).float64()
x = jt.stack([x, jt.zeros_like(x)], 3)
y_mid = nn.fft2(x, True)
y = nn.fft2(y_mid, True)
y_mid = nn._fft2(x, True)
y = nn._fft2(y_mid, True)
real = y[:, :, :, 0]
imag = y[:, :, :, 1]
loss = (real * t1).sum() + (imag * t2).sum()