avoid import error when no torch installed

This commit is contained in:
lzhengning 2021-06-18 15:49:06 +08:00
parent df0ea12d7e
commit 556f0cfb1e
12 changed files with 65 additions and 28 deletions

View File

@ -10,11 +10,19 @@
# ***************************************************************
import unittest
import jittor as jt
import torch
from torch.nn import functional as F
import numpy as np
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch.nn import functional as F
except:
torch = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch found")
class TestBicubicInterpolate(unittest.TestCase):
# this is for testing bicubic interpolate
def test_bicubic(self):

View File

@ -9,11 +9,18 @@
import unittest
import jittor as jt
import numpy as np
import ctypes
import sys
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch.autograd import Variable
except:
torch = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch found")
class TestCumprod(unittest.TestCase):
def test_cumprod_cpu(self):
for i in range(1,6):

View File

@ -12,6 +12,14 @@ import jittor as jt
import numpy as np
import jittor.distributions as jd
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
torch = None
skip_this_test = True
class TestOneHot(unittest.TestCase):
def test_presum(self):
@ -19,6 +27,7 @@ class TestOneHot(unittest.TestCase):
b = jd.simple_presum(a)
assert (b.data == [[0,1,3,6,10]]).all()
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_one_hot(self):
a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25]))
x = a.sample().numpy()
@ -30,7 +39,7 @@ class TestOneHot(unittest.TestCase):
assert y.shape == [2,3,4]
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
import torch
jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
@ -51,8 +60,8 @@ class TestOneHot(unittest.TestCase):
y.sync()
assert y.shape == [2,3]
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_normal(self):
import torch
for _ in range(4):
mu = np.random.uniform(-1,1)
sigma = np.random.uniform(0,2)
@ -67,8 +76,8 @@ class TestOneHot(unittest.TestCase):
tn2 = torch.distributions.Normal(mu2,sigma2)
assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_categorical1(self):
import torch
for _ in range(4):
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
@ -79,9 +88,9 @@ class TestOneHot(unittest.TestCase):
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_categorical2(self):
def check(prob_shape, sample_shape):
import torch
for _ in range(4):
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
@ -98,9 +107,9 @@ class TestOneHot(unittest.TestCase):
check((2,3), (4,))
check((3,4,5,6), (2,))
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_one_hot_categorical2(self):
def check(prob_shape, sample_shape):
import torch
for _ in range(4):
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
@ -117,8 +126,8 @@ class TestOneHot(unittest.TestCase):
check((2,3), (4,))
check((3,4,5,6), (2,))
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_uniform(self):
import torch
for _ in range(4):
low, low2 = np.random.randint(-1,2), np.random.randint(-1,2)
leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2)
@ -130,8 +139,8 @@ class TestOneHot(unittest.TestCase):
assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x)))
assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2))
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_geometric(self):
import torch
for _ in range(4):
prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)

View File

@ -10,11 +10,19 @@
# ***************************************************************
import unittest
import jittor as jt
import torch
from torch.nn import functional as F
import numpy as np
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch.nn import functional as F
except:
torch = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch Found")
class TestFoldOp(unittest.TestCase):
def test_fold(self):
# test unfold first and the test fold.

View File

@ -8,13 +8,13 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import torch
from torch.autograd import Variable
import jittor as jt
import numpy as np
import unittest
try:
import torch
from torch.autograd import Variable
import autograd.numpy as anp
from autograd import jacobian

View File

@ -18,12 +18,14 @@ import unittest
from .test_reorder_tuner import simple_parser
from .test_log import find_log_with_re
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
skip_this_test = True
class TestRandomOp(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1)
@ -51,6 +53,7 @@ class TestRandomOp(unittest.TestCase):
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)")
assert len(logs)==1
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_normal(self):
from jittor import init
n = 10000

View File

@ -18,12 +18,9 @@ try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torch.nn as tnn
import torchvision
from torch.autograd import Variable
except:
torch = None
tnn = None
torchvision = None
skip_this_test = True
# TODO: more test

View File

@ -9,11 +9,16 @@
import unittest
import jittor as jt
import numpy as np
import ctypes
import sys
import torch
from torch.autograd import Variable
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
skip_this_test = True
@unittest.skipIf
class TestSearchsorted(unittest.TestCase):
def test_searchsorted_cpu(self):
for i in range(1,3):

View File

@ -32,7 +32,7 @@ cc_flags = f" -g -O0 -DTEST --std=c++14 -I{jittor_path}/test -I{jittor_path}/src
class TestUtils(unittest.TestCase):
def test_cache_compile(self):
cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile"
cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/str_utils.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile"
self.assertEqual(os.system(cmd), 0)
def test_log(self):

View File

@ -1 +1 @@
9edb1890b66321883c398591483d4d65377a848d
939b29514b2e5cc591053aab614efd569772585d

View File

@ -4,7 +4,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from multiprocessing import Pool, Value
from multiprocessing import Pool
import multiprocessing as mp
import subprocess as sp
import os

View File

@ -31,7 +31,7 @@ if __name__ == "__main__":
libext = {
'Linux': 'so',
'Darwin': 'dylib',
'Window': 'DLL',
'Windows': 'DLL',
}[platform.system()]
ldflags = jittor_utils.run_cmd(jittor_utils.py3_config_path + " --ldflags")
libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")]