mirror of https://github.com/Jittor/Jittor
avoid import error when no torch installed
This commit is contained in:
parent
df0ea12d7e
commit
556f0cfb1e
|
@ -10,11 +10,19 @@
|
|||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
skip_this_test = False
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
except:
|
||||
torch = None
|
||||
skip_this_test = True
|
||||
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestBicubicInterpolate(unittest.TestCase):
|
||||
# this is for testing bicubic interpolate
|
||||
def test_bicubic(self):
|
||||
|
|
|
@ -9,11 +9,18 @@
|
|||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import ctypes
|
||||
import sys
|
||||
|
||||
skip_this_test = False
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
except:
|
||||
torch = None
|
||||
skip_this_test = True
|
||||
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestCumprod(unittest.TestCase):
|
||||
def test_cumprod_cpu(self):
|
||||
for i in range(1,6):
|
||||
|
|
|
@ -12,6 +12,14 @@ import jittor as jt
|
|||
import numpy as np
|
||||
import jittor.distributions as jd
|
||||
|
||||
skip_this_test = False
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
torch = None
|
||||
skip_this_test = True
|
||||
|
||||
|
||||
class TestOneHot(unittest.TestCase):
|
||||
def test_presum(self):
|
||||
|
@ -19,6 +27,7 @@ class TestOneHot(unittest.TestCase):
|
|||
b = jd.simple_presum(a)
|
||||
assert (b.data == [[0,1,3,6,10]]).all()
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
def test_one_hot(self):
|
||||
a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25]))
|
||||
x = a.sample().numpy()
|
||||
|
@ -30,7 +39,7 @@ class TestOneHot(unittest.TestCase):
|
|||
assert y.shape == [2,3,4]
|
||||
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
|
||||
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
|
||||
import torch
|
||||
|
||||
jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
|
||||
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
|
||||
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
|
||||
|
@ -51,8 +60,8 @@ class TestOneHot(unittest.TestCase):
|
|||
y.sync()
|
||||
assert y.shape == [2,3]
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
def test_normal(self):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
mu = np.random.uniform(-1,1)
|
||||
sigma = np.random.uniform(0,2)
|
||||
|
@ -67,8 +76,8 @@ class TestOneHot(unittest.TestCase):
|
|||
tn2 = torch.distributions.Normal(mu2,sigma2)
|
||||
assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
def test_categorical1(self):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
|
||||
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
|
||||
|
@ -79,9 +88,9 @@ class TestOneHot(unittest.TestCase):
|
|||
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
|
||||
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
def test_categorical2(self):
|
||||
def check(prob_shape, sample_shape):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
|
||||
|
||||
|
@ -98,9 +107,9 @@ class TestOneHot(unittest.TestCase):
|
|||
check((2,3), (4,))
|
||||
check((3,4,5,6), (2,))
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
def test_one_hot_categorical2(self):
|
||||
def check(prob_shape, sample_shape):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
|
||||
|
||||
|
@ -117,8 +126,8 @@ class TestOneHot(unittest.TestCase):
|
|||
check((2,3), (4,))
|
||||
check((3,4,5,6), (2,))
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
def test_uniform(self):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
low, low2 = np.random.randint(-1,2), np.random.randint(-1,2)
|
||||
leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2)
|
||||
|
@ -130,8 +139,8 @@ class TestOneHot(unittest.TestCase):
|
|||
assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x)))
|
||||
assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2))
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
def test_geometric(self):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
|
||||
jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
|
||||
|
|
|
@ -10,11 +10,19 @@
|
|||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
skip_this_test = False
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
except:
|
||||
torch = None
|
||||
skip_this_test = True
|
||||
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
class TestFoldOp(unittest.TestCase):
|
||||
def test_fold(self):
|
||||
# test unfold first and the test fold.
|
||||
|
|
|
@ -8,13 +8,13 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import autograd.numpy as anp
|
||||
from autograd import jacobian
|
||||
|
||||
|
|
|
@ -18,12 +18,14 @@ import unittest
|
|||
from .test_reorder_tuner import simple_parser
|
||||
from .test_log import find_log_with_re
|
||||
|
||||
skip_this_test = False
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
skip_this_test = True
|
||||
|
||||
|
||||
class TestRandomOp(unittest.TestCase):
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
|
@ -51,6 +53,7 @@ class TestRandomOp(unittest.TestCase):
|
|||
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)")
|
||||
assert len(logs)==1
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch Found")
|
||||
def test_normal(self):
|
||||
from jittor import init
|
||||
n = 10000
|
||||
|
|
|
@ -18,12 +18,9 @@ try:
|
|||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
import torch.nn as tnn
|
||||
import torchvision
|
||||
from torch.autograd import Variable
|
||||
except:
|
||||
torch = None
|
||||
tnn = None
|
||||
torchvision = None
|
||||
skip_this_test = True
|
||||
|
||||
# TODO: more test
|
||||
|
|
|
@ -9,11 +9,16 @@
|
|||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import ctypes
|
||||
import sys
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
skip_this_test = False
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
skip_this_test = True
|
||||
|
||||
|
||||
@unittest.skipIf
|
||||
class TestSearchsorted(unittest.TestCase):
|
||||
def test_searchsorted_cpu(self):
|
||||
for i in range(1,3):
|
||||
|
|
|
@ -32,7 +32,7 @@ cc_flags = f" -g -O0 -DTEST --std=c++14 -I{jittor_path}/test -I{jittor_path}/src
|
|||
|
||||
class TestUtils(unittest.TestCase):
|
||||
def test_cache_compile(self):
|
||||
cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile"
|
||||
cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/str_utils.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile"
|
||||
self.assertEqual(os.system(cmd), 0)
|
||||
|
||||
def test_log(self):
|
||||
|
|
|
@ -1 +1 @@
|
|||
9edb1890b66321883c398591483d4d65377a848d
|
||||
939b29514b2e5cc591053aab614efd569772585d
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
from multiprocessing import Pool, Value
|
||||
from multiprocessing import Pool
|
||||
import multiprocessing as mp
|
||||
import subprocess as sp
|
||||
import os
|
||||
|
|
|
@ -31,7 +31,7 @@ if __name__ == "__main__":
|
|||
libext = {
|
||||
'Linux': 'so',
|
||||
'Darwin': 'dylib',
|
||||
'Window': 'DLL',
|
||||
'Windows': 'DLL',
|
||||
}[platform.system()]
|
||||
ldflags = jittor_utils.run_cmd(jittor_utils.py3_config_path + " --ldflags")
|
||||
libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")]
|
||||
|
|
Loading…
Reference in New Issue