mirror of https://github.com/Jittor/Jittor
update fft op test
This commit is contained in:
parent
e9f681de53
commit
e087a56d86
|
@ -10,8 +10,7 @@
|
|||
import jittor as jt
|
||||
import unittest
|
||||
from .test_log import find_log_with_re
|
||||
import torch
|
||||
import cv2
|
||||
import torch # torch >= 1.9.0 needed
|
||||
import numpy as np
|
||||
from jittor import nn
|
||||
|
||||
|
@ -20,10 +19,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_fft_forward(self):
|
||||
img = cv2.imread("test.jpg")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.imread("test2.jpg")
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
img = jt.rand(256, 300)
|
||||
img2 = jt.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
|
||||
# torch
|
||||
|
@ -44,10 +41,8 @@ class TestFFTOp(unittest.TestCase):
|
|||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_ifft_forward(self):
|
||||
img = cv2.imread("test.jpg")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.imread("test2.jpg")
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
img = jt.rand(256, 300)
|
||||
img2 = jt.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
|
||||
# torch
|
||||
|
@ -75,13 +70,11 @@ class TestFFTOp(unittest.TestCase):
|
|||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_fft_backward(self):
|
||||
img = cv2.imread("test.jpg")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.imread("test2.jpg")
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
img = jt.rand(256, 300)
|
||||
img2 = jt.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
T1 = np.random.rand(1,512,512)
|
||||
T2 = np.random.rand(1,512,512)
|
||||
T1 = np.random.rand(1,256,300)
|
||||
T2 = np.random.rand(1,256,300)
|
||||
|
||||
# torch
|
||||
x = torch.Tensor(X)
|
||||
|
@ -112,13 +105,11 @@ class TestFFTOp(unittest.TestCase):
|
|||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_ifft_backward(self):
|
||||
img = cv2.imread("test.jpg")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img2 = cv2.imread("test2.jpg")
|
||||
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
img = jt.rand(256, 300)
|
||||
img2 = jt.rand(256, 300)
|
||||
X = np.stack([img, img2], 0)
|
||||
T1 = np.random.rand(1,512,512)
|
||||
T2 = np.random.rand(1,512,512)
|
||||
T1 = np.random.rand(1,256,300)
|
||||
T2 = np.random.rand(1,256,300)
|
||||
|
||||
# torch
|
||||
x = torch.Tensor(X)
|
||||
|
|
Loading…
Reference in New Issue