update fft op test

This commit is contained in:
cxjyxx_me 2022-03-26 22:27:01 -04:00
parent e9f681de53
commit e087a56d86
1 changed files with 13 additions and 22 deletions

View File

@ -10,8 +10,7 @@
import jittor as jt
import unittest
from .test_log import find_log_with_re
import torch
import cv2
import torch # torch >= 1.9.0 needed
import numpy as np
from jittor import nn
@ -20,10 +19,8 @@ class TestFFTOp(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1)
def test_fft_forward(self):
img = cv2.imread("test.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img2 = cv2.imread("test2.jpg")
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
img = jt.rand(256, 300)
img2 = jt.rand(256, 300)
X = np.stack([img, img2], 0)
# torch
@ -44,10 +41,8 @@ class TestFFTOp(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1)
def test_ifft_forward(self):
img = cv2.imread("test.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img2 = cv2.imread("test2.jpg")
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
img = jt.rand(256, 300)
img2 = jt.rand(256, 300)
X = np.stack([img, img2], 0)
# torch
@ -75,13 +70,11 @@ class TestFFTOp(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1)
def test_fft_backward(self):
img = cv2.imread("test.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img2 = cv2.imread("test2.jpg")
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
img = jt.rand(256, 300)
img2 = jt.rand(256, 300)
X = np.stack([img, img2], 0)
T1 = np.random.rand(1,512,512)
T2 = np.random.rand(1,512,512)
T1 = np.random.rand(1,256,300)
T2 = np.random.rand(1,256,300)
# torch
x = torch.Tensor(X)
@ -112,13 +105,11 @@ class TestFFTOp(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1)
def test_ifft_backward(self):
img = cv2.imread("test.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img2 = cv2.imread("test2.jpg")
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
img = jt.rand(256, 300)
img2 = jt.rand(256, 300)
X = np.stack([img, img2], 0)
T1 = np.random.rand(1,512,512)
T2 = np.random.rand(1,512,512)
T1 = np.random.rand(1,256,300)
T2 = np.random.rand(1,256,300)
# torch
x = torch.Tensor(X)