mirror of https://github.com/Jittor/Jittor
fft polish
This commit is contained in:
parent
65832ac10f
commit
45df1a4eee
|
@ -2791,7 +2791,8 @@ class Bilinear(Module):
|
|||
return bilinear(in1, in2, self.weight, self.bias)
|
||||
|
||||
#TODO: support FFT2D only now.
|
||||
def fft2(x, inverse=False):
|
||||
def _fft2(x, inverse=False):
|
||||
assert(jt.flags.use_cuda==1)
|
||||
assert(len(x.shape) == 4)
|
||||
assert(x.shape[3] == 2)
|
||||
y = jt.compile_extern.cufft_ops.cufft_fft(x, inverse)
|
||||
|
|
|
@ -32,7 +32,7 @@ class TestFFTOp(unittest.TestCase):
|
|||
#jittor
|
||||
x = jt.array(X,dtype=jt.float32)
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y = nn.fft2(x)
|
||||
y = nn._fft2(x)
|
||||
y_jt_real = y[:, :, :, 0].data
|
||||
y_jt_imag = y[:, :, :, 1].data
|
||||
assert(np.allclose(y_torch_real, y_jt_real, atol=1))
|
||||
|
@ -57,8 +57,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
#jittor
|
||||
x = jt.array(X,dtype=jt.float32)
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y = nn.fft2(x)
|
||||
y_ori = nn.fft2(y, True)
|
||||
y = nn._fft2(x)
|
||||
y_ori = nn._fft2(y, True)
|
||||
y_jt_real = y[:, :, :, 0].data
|
||||
y_jt_imag = y[:, :, :, 1].data
|
||||
y_ori_jt_real = y_ori[:, :, :, 0].data
|
||||
|
@ -94,8 +94,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
t1 = jt.array(T1,dtype=jt.float32)
|
||||
t2 = jt.array(T2,dtype=jt.float32)
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y_mid = nn.fft2(x)
|
||||
y = nn.fft2(y_mid)
|
||||
y_mid = nn._fft2(x)
|
||||
y = nn._fft2(y_mid)
|
||||
real = y[:, :, :, 0]
|
||||
imag = y[:, :, :, 1]
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
|
@ -129,8 +129,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
t1 = jt.array(T1,dtype=jt.float32)
|
||||
t2 = jt.array(T2,dtype=jt.float32)
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y_mid = nn.fft2(x, True)
|
||||
y = nn.fft2(y_mid, True)
|
||||
y_mid = nn._fft2(x, True)
|
||||
y = nn._fft2(y_mid, True)
|
||||
real = y[:, :, :, 0]
|
||||
imag = y[:, :, :, 1]
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
|
@ -153,7 +153,7 @@ class TestFFTOp(unittest.TestCase):
|
|||
#jittor
|
||||
x = jt.array(X).float64()
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y = nn.fft2(x)
|
||||
y = nn._fft2(x)
|
||||
y_jt_real = y[:, :, :, 0].data
|
||||
y_jt_imag = y[:, :, :, 1].data
|
||||
assert(np.allclose(y_torch_real, y_jt_real, atol=1))
|
||||
|
@ -178,8 +178,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
#jittor
|
||||
x = jt.array(X).float64()
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y = nn.fft2(x)
|
||||
y_ori = nn.fft2(y, True)
|
||||
y = nn._fft2(x)
|
||||
y_ori = nn._fft2(y, True)
|
||||
y_jt_real = y[:, :, :, 0].data
|
||||
y_jt_imag = y[:, :, :, 1].data
|
||||
y_ori_jt_real = y_ori[:, :, :, 0].data
|
||||
|
@ -215,8 +215,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
t1 = jt.array(T1).float64()
|
||||
t2 = jt.array(T2).float64()
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y_mid = nn.fft2(x)
|
||||
y = nn.fft2(y_mid)
|
||||
y_mid = nn._fft2(x)
|
||||
y = nn._fft2(y_mid)
|
||||
real = y[:, :, :, 0]
|
||||
imag = y[:, :, :, 1]
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
|
@ -250,8 +250,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
t1 = jt.array(T1).float64()
|
||||
t2 = jt.array(T2).float64()
|
||||
x = jt.stack([x, jt.zeros_like(x)], 3)
|
||||
y_mid = nn.fft2(x, True)
|
||||
y = nn.fft2(y_mid, True)
|
||||
y_mid = nn._fft2(x, True)
|
||||
y = nn._fft2(y_mid, True)
|
||||
real = y[:, :, :, 0]
|
||||
imag = y[:, :, :, 1]
|
||||
loss = (real * t1).sum() + (imag * t2).sum()
|
||||
|
|
Loading…
Reference in New Issue