avoid import error when no torch installed

This commit is contained in:
lzhengning 2021-06-18 15:49:06 +08:00
parent df0ea12d7e
commit 556f0cfb1e
12 changed files with 65 additions and 28 deletions

View File

@ -10,11 +10,19 @@
# *************************************************************** # ***************************************************************
import unittest import unittest
import jittor as jt import jittor as jt
import torch
from torch.nn import functional as F
import numpy as np import numpy as np
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch.nn import functional as F
except:
torch = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch found")
class TestBicubicInterpolate(unittest.TestCase): class TestBicubicInterpolate(unittest.TestCase):
# this is for testing bicubic interpolate # this is for testing bicubic interpolate
def test_bicubic(self): def test_bicubic(self):

View File

@ -9,11 +9,18 @@
import unittest import unittest
import jittor as jt import jittor as jt
import numpy as np import numpy as np
import ctypes
import sys
import torch
from torch.autograd import Variable
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch.autograd import Variable
except:
torch = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch found")
class TestCumprod(unittest.TestCase): class TestCumprod(unittest.TestCase):
def test_cumprod_cpu(self): def test_cumprod_cpu(self):
for i in range(1,6): for i in range(1,6):

View File

@ -12,6 +12,14 @@ import jittor as jt
import numpy as np import numpy as np
import jittor.distributions as jd import jittor.distributions as jd
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
torch = None
skip_this_test = True
class TestOneHot(unittest.TestCase): class TestOneHot(unittest.TestCase):
def test_presum(self): def test_presum(self):
@ -19,6 +27,7 @@ class TestOneHot(unittest.TestCase):
b = jd.simple_presum(a) b = jd.simple_presum(a)
assert (b.data == [[0,1,3,6,10]]).all() assert (b.data == [[0,1,3,6,10]]).all()
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_one_hot(self): def test_one_hot(self):
a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25])) a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25]))
x = a.sample().numpy() x = a.sample().numpy()
@ -30,7 +39,7 @@ class TestOneHot(unittest.TestCase):
assert y.shape == [2,3,4] assert y.shape == [2,3,4]
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
import torch
jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2)) jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2)) tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data,tc.entropy().numpy()) assert np.allclose(jc.entropy().data,tc.entropy().numpy())
@ -51,8 +60,8 @@ class TestOneHot(unittest.TestCase):
y.sync() y.sync()
assert y.shape == [2,3] assert y.shape == [2,3]
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_normal(self): def test_normal(self):
import torch
for _ in range(4): for _ in range(4):
mu = np.random.uniform(-1,1) mu = np.random.uniform(-1,1)
sigma = np.random.uniform(0,2) sigma = np.random.uniform(0,2)
@ -67,8 +76,8 @@ class TestOneHot(unittest.TestCase):
tn2 = torch.distributions.Normal(mu2,sigma2) tn2 = torch.distributions.Normal(mu2,sigma2)
assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy()) assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_categorical1(self): def test_categorical1(self):
import torch
for _ in range(4): for _ in range(4):
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
@ -79,9 +88,9 @@ class TestOneHot(unittest.TestCase):
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5) np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_categorical2(self): def test_categorical2(self):
def check(prob_shape, sample_shape): def check(prob_shape, sample_shape):
import torch
for _ in range(4): for _ in range(4):
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape) probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
@ -98,9 +107,9 @@ class TestOneHot(unittest.TestCase):
check((2,3), (4,)) check((2,3), (4,))
check((3,4,5,6), (2,)) check((3,4,5,6), (2,))
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_one_hot_categorical2(self): def test_one_hot_categorical2(self):
def check(prob_shape, sample_shape): def check(prob_shape, sample_shape):
import torch
for _ in range(4): for _ in range(4):
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape) probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
@ -117,8 +126,8 @@ class TestOneHot(unittest.TestCase):
check((2,3), (4,)) check((2,3), (4,))
check((3,4,5,6), (2,)) check((3,4,5,6), (2,))
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_uniform(self): def test_uniform(self):
import torch
for _ in range(4): for _ in range(4):
low, low2 = np.random.randint(-1,2), np.random.randint(-1,2) low, low2 = np.random.randint(-1,2), np.random.randint(-1,2)
leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2) leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2)
@ -130,8 +139,8 @@ class TestOneHot(unittest.TestCase):
assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x))) assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x)))
assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2)) assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2))
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_geometric(self): def test_geometric(self):
import torch
for _ in range(4): for _ in range(4):
prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1) prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2) jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)

View File

@ -10,11 +10,19 @@
# *************************************************************** # ***************************************************************
import unittest import unittest
import jittor as jt import jittor as jt
import torch
from torch.nn import functional as F
import numpy as np import numpy as np
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch.nn import functional as F
except:
torch = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch Found")
class TestFoldOp(unittest.TestCase): class TestFoldOp(unittest.TestCase):
def test_fold(self): def test_fold(self):
# test unfold first and the test fold. # test unfold first and the test fold.

View File

@ -8,13 +8,13 @@
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
import torch
from torch.autograd import Variable
import jittor as jt import jittor as jt
import numpy as np import numpy as np
import unittest import unittest
try: try:
import torch
from torch.autograd import Variable
import autograd.numpy as anp import autograd.numpy as anp
from autograd import jacobian from autograd import jacobian

View File

@ -18,12 +18,14 @@ import unittest
from .test_reorder_tuner import simple_parser from .test_reorder_tuner import simple_parser
from .test_log import find_log_with_re from .test_log import find_log_with_re
skip_this_test = False
try: try:
jt.dirty_fix_pytorch_runtime_error() jt.dirty_fix_pytorch_runtime_error()
import torch import torch
except: except:
skip_this_test = True skip_this_test = True
class TestRandomOp(unittest.TestCase): class TestRandomOp(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "Cuda not found") @unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1) @jt.flag_scope(use_cuda=1)
@ -51,6 +53,7 @@ class TestRandomOp(unittest.TestCase):
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)") logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)")
assert len(logs)==1 assert len(logs)==1
@unittest.skipIf(skip_this_test, "No Torch Found")
def test_normal(self): def test_normal(self):
from jittor import init from jittor import init
n = 10000 n = 10000

View File

@ -18,12 +18,9 @@ try:
jt.dirty_fix_pytorch_runtime_error() jt.dirty_fix_pytorch_runtime_error()
import torch import torch
import torch.nn as tnn import torch.nn as tnn
import torchvision
from torch.autograd import Variable
except: except:
torch = None torch = None
tnn = None tnn = None
torchvision = None
skip_this_test = True skip_this_test = True
# TODO: more test # TODO: more test

View File

@ -9,11 +9,16 @@
import unittest import unittest
import jittor as jt import jittor as jt
import numpy as np import numpy as np
import ctypes
import sys
import torch
from torch.autograd import Variable
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
skip_this_test = True
@unittest.skipIf
class TestSearchsorted(unittest.TestCase): class TestSearchsorted(unittest.TestCase):
def test_searchsorted_cpu(self): def test_searchsorted_cpu(self):
for i in range(1,3): for i in range(1,3):

View File

@ -32,7 +32,7 @@ cc_flags = f" -g -O0 -DTEST --std=c++14 -I{jittor_path}/test -I{jittor_path}/src
class TestUtils(unittest.TestCase): class TestUtils(unittest.TestCase):
def test_cache_compile(self): def test_cache_compile(self):
cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile" cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/str_utils.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile"
self.assertEqual(os.system(cmd), 0) self.assertEqual(os.system(cmd), 0)
def test_log(self): def test_log(self):

View File

@ -1 +1 @@
9edb1890b66321883c398591483d4d65377a848d 939b29514b2e5cc591053aab614efd569772585d

View File

@ -4,7 +4,7 @@
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
from multiprocessing import Pool, Value from multiprocessing import Pool
import multiprocessing as mp import multiprocessing as mp
import subprocess as sp import subprocess as sp
import os import os

View File

@ -31,7 +31,7 @@ if __name__ == "__main__":
libext = { libext = {
'Linux': 'so', 'Linux': 'so',
'Darwin': 'dylib', 'Darwin': 'dylib',
'Window': 'DLL', 'Windows': 'DLL',
}[platform.system()] }[platform.system()]
ldflags = jittor_utils.run_cmd(jittor_utils.py3_config_path + " --ldflags") ldflags = jittor_utils.run_cmd(jittor_utils.py3_config_path + " --ldflags")
libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")] libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")]