mirror of https://github.com/Jittor/Jittor
avoid import error when no torch installed
This commit is contained in:
parent df0ea12d7e
commit 556f0cfb1e
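Every test-file hunk below applies the same guard so the suite imports cleanly when PyTorch is absent: the torch imports move inside a try/except block, the outcome is recorded in skip_this_test, and the torch-dependent tests are marked with unittest.skipIf. A minimal sketch of the pattern follows; the test class and method names are illustrative only, not taken from the commit:

import unittest
import jittor as jt

skip_this_test = False
try:
    # the commit calls this before importing torch next to jittor
    jt.dirty_fix_pytorch_runtime_error()
    import torch
except:
    torch = None          # keep the name bound so module-level references still resolve
    skip_this_test = True


@unittest.skipIf(skip_this_test, "No Torch found")
class TestExample(unittest.TestCase):   # illustrative name only
    def test_add(self):
        # only runs when torch imported successfully above
        assert torch.ones(2).sum().item() == 2.0

if __name__ == "__main__":
    unittest.main()

With this layout the module always imports, and a missing torch shows up as skipped tests rather than a collection-time ImportError.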
@@ -10,11 +10,19 @@
 # ***************************************************************
 import unittest
 import jittor as jt
-import torch
-from torch.nn import functional as F
 import numpy as np
+
+skip_this_test = False
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+    from torch.nn import functional as F
+except:
+    torch = None
+    skip_this_test = True
 
 
+@unittest.skipIf(skip_this_test, "No Torch found")
 class TestBicubicInterpolate(unittest.TestCase):
     # this is for testing bicubic interpolate
     def test_bicubic(self):
@@ -9,11 +9,18 @@
 import unittest
 import jittor as jt
 import numpy as np
-import ctypes
-import sys
-import torch
-from torch.autograd import Variable
+
+skip_this_test = False
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+    from torch.autograd import Variable
+except:
+    torch = None
+    skip_this_test = True
+
 
+@unittest.skipIf(skip_this_test, "No Torch found")
 class TestCumprod(unittest.TestCase):
     def test_cumprod_cpu(self):
         for i in range(1,6):
@@ -12,6 +12,14 @@ import jittor as jt
 import numpy as np
 import jittor.distributions as jd
 
+skip_this_test = False
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+except:
+    torch = None
+    skip_this_test = True
+
 
 class TestOneHot(unittest.TestCase):
     def test_presum(self):
@@ -19,6 +27,7 @@ class TestOneHot(unittest.TestCase):
         b = jd.simple_presum(a)
         assert (b.data == [[0,1,3,6,10]]).all()
 
+    @unittest.skipIf(skip_this_test, "No Torch Found")
     def test_one_hot(self):
         a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25]))
         x = a.sample().numpy()
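In this test file only some methods need torch (test_presum above exercises jd.simple_presum with plain jittor arrays), so the decorator is applied per method rather than on the whole class and the torch-free tests keep running. A small sketch of that layout, with illustrative class and method names:

import unittest

skip_this_test = False
try:
    import torch
except ImportError:
    skip_this_test = True


class TestMixedBackends(unittest.TestCase):   # illustrative name only
    def test_pure_jittor(self):
        # runs whether or not torch is installed
        assert 1 + 1 == 2

    @unittest.skipIf(skip_this_test, "No Torch Found")
    def test_against_torch(self):
        # compares against the torch reference only when it is available
        assert torch.ones(3).sum().item() == 3.0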
@@ -30,7 +39,7 @@ class TestOneHot(unittest.TestCase):
         assert y.shape == [2,3,4]
         probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
         probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
-        import torch
         jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
         tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
         assert np.allclose(jc.entropy().data,tc.entropy().numpy())
@@ -51,8 +60,8 @@ class TestOneHot(unittest.TestCase):
         y.sync()
         assert y.shape == [2,3]
 
+    @unittest.skipIf(skip_this_test, "No Torch Found")
     def test_normal(self):
-        import torch
         for _ in range(4):
             mu = np.random.uniform(-1,1)
             sigma = np.random.uniform(0,2)
@@ -67,8 +76,8 @@ class TestOneHot(unittest.TestCase):
         tn2 = torch.distributions.Normal(mu2,sigma2)
         assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())
 
+    @unittest.skipIf(skip_this_test, "No Torch Found")
     def test_categorical1(self):
-        import torch
         for _ in range(4):
             probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
             probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
@@ -79,9 +88,9 @@ class TestOneHot(unittest.TestCase):
         np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
         assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
 
+    @unittest.skipIf(skip_this_test, "No Torch Found")
     def test_categorical2(self):
         def check(prob_shape, sample_shape):
-            import torch
             for _ in range(4):
                 probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
@@ -98,9 +107,9 @@ class TestOneHot(unittest.TestCase):
         check((2,3), (4,))
         check((3,4,5,6), (2,))
 
+    @unittest.skipIf(skip_this_test, "No Torch Found")
     def test_one_hot_categorical2(self):
         def check(prob_shape, sample_shape):
-            import torch
             for _ in range(4):
                 probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
@@ -117,8 +126,8 @@ class TestOneHot(unittest.TestCase):
         check((2,3), (4,))
         check((3,4,5,6), (2,))
 
+    @unittest.skipIf(skip_this_test, "No Torch Found")
     def test_uniform(self):
-        import torch
         for _ in range(4):
             low, low2 = np.random.randint(-1,2), np.random.randint(-1,2)
             leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2)
@@ -130,8 +139,8 @@ class TestOneHot(unittest.TestCase):
         assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x)))
         assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2))
 
+    @unittest.skipIf(skip_this_test, "No Torch Found")
     def test_geometric(self):
-        import torch
         for _ in range(4):
             prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
             jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
@@ -10,11 +10,19 @@
 # ***************************************************************
 import unittest
 import jittor as jt
-import torch
-from torch.nn import functional as F
 import numpy as np
+
+skip_this_test = False
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+    from torch.nn import functional as F
+except:
+    torch = None
+    skip_this_test = True
 
 
+@unittest.skipIf(skip_this_test, "No Torch Found")
 class TestFoldOp(unittest.TestCase):
     def test_fold(self):
         # test unfold first and the test fold.
@@ -8,13 +8,13 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-import torch
-from torch.autograd import Variable
 import jittor as jt
 import numpy as np
 import unittest
 
 try:
+    import torch
+    from torch.autograd import Variable
     import autograd.numpy as anp
     from autograd import jacobian
 
@@ -18,12 +18,14 @@ import unittest
 from .test_reorder_tuner import simple_parser
 from .test_log import find_log_with_re
 
+skip_this_test = False
 try:
     jt.dirty_fix_pytorch_runtime_error()
     import torch
 except:
     skip_this_test = True
 
 
 class TestRandomOp(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1)
@@ -51,6 +53,7 @@ class TestRandomOp(unittest.TestCase):
         logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)")
         assert len(logs)==1
 
+    @unittest.skipIf(skip_this_test, "No Torch Found")
     def test_normal(self):
         from jittor import init
         n = 10000
@@ -18,12 +18,9 @@ try:
     jt.dirty_fix_pytorch_runtime_error()
     import torch
     import torch.nn as tnn
-    import torchvision
-    from torch.autograd import Variable
 except:
     torch = None
     tnn = None
-    torchvision = None
     skip_this_test = True
 
 # TODO: more test
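The hunk above drops import torchvision and from torch.autograd import Variable because the file no longer uses them, and removes the matching torchvision = None fallback, so the except branch keeps defining exactly the names the try branch would have bound. A minimal sketch of that invariant (the module aliases here are only examples):

try:
    import torch
    import torch.nn as tnn
except ImportError:
    # bind the same names the try branch would, so later code
    # can check "if torch is None" instead of catching NameError
    torch = None
    tnn = None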
@@ -9,11 +9,16 @@
 import unittest
 import jittor as jt
 import numpy as np
-import ctypes
-import sys
-import torch
-from torch.autograd import Variable
+
+skip_this_test = False
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+except:
+    skip_this_test = True
+
 
+@unittest.skipIf
 class TestSearchsorted(unittest.TestCase):
     def test_searchsorted_cpu(self):
         for i in range(1,3):
@@ -32,7 +32,7 @@ cc_flags = f" -g -O0 -DTEST --std=c++14 -I{jittor_path}/test -I{jittor_path}/src
 
 class TestUtils(unittest.TestCase):
     def test_cache_compile(self):
-        cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile"
+        cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/str_utils.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile"
         self.assertEqual(os.system(cmd), 0)
 
     def test_log(self):
@@ -1 +1 @@
-9edb1890b66321883c398591483d4d65377a848d
+939b29514b2e5cc591053aab614efd569772585d
@@ -4,7 +4,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-from multiprocessing import Pool, Value
+from multiprocessing import Pool
 import multiprocessing as mp
 import subprocess as sp
 import os
@@ -31,7 +31,7 @@ if __name__ == "__main__":
     libext = {
         'Linux': 'so',
         'Darwin': 'dylib',
-        'Window': 'DLL',
+        'Windows': 'DLL',
     }[platform.system()]
     ldflags = jittor_utils.run_cmd(jittor_utils.py3_config_path + " --ldflags")
     libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")]
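The final hunk fixes the dictionary key used to pick the shared-library extension: platform.system() returns 'Windows' on Windows (and 'Linux' or 'Darwin' elsewhere), so the old 'Window' key could never match and the lookup raised KeyError on that platform. A small standalone sketch of the corrected mapping:

import platform

libext = {
    'Linux': 'so',       # ELF shared object
    'Darwin': 'dylib',   # macOS dynamic library
    'Windows': 'DLL',    # platform.system() reports 'Windows', not 'Window'
}[platform.system()]
print(libext)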