From 556f0cfb1eed9a817388facbb8b723a1f05a1ff4 Mon Sep 17 00:00:00 2001 From: lzhengning Date: Fri, 18 Jun 2021 15:49:06 +0800 Subject: [PATCH] avoid import error when no torch installed --- python/jittor/test/test_bicubic.py | 12 +++++++++-- python/jittor/test/test_cumprod_op.py | 15 ++++++++++---- python/jittor/test/test_distributions.py | 23 +++++++++++++++------- python/jittor/test/test_fold.py | 12 +++++++++-- python/jittor/test/test_linalg.py | 4 ++-- python/jittor/test/test_random_op.py | 3 +++ python/jittor/test/test_search_sorted.py | 3 --- python/jittor/test/test_searchsorted_op.py | 13 ++++++++---- python/jittor/test/test_utils.py | 2 +- python/jittor/version | 2 +- python/jittor_utils/__init__.py | 2 +- python/jittor_utils/config.py | 2 +- 12 files changed, 65 insertions(+), 28 deletions(-) diff --git a/python/jittor/test/test_bicubic.py b/python/jittor/test/test_bicubic.py index 69441274..00a3f346 100644 --- a/python/jittor/test/test_bicubic.py +++ b/python/jittor/test/test_bicubic.py @@ -10,11 +10,19 @@ # *************************************************************** import unittest import jittor as jt -import torch -from torch.nn import functional as F import numpy as np +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + from torch.nn import functional as F +except: + torch = None + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") class TestBicubicInterpolate(unittest.TestCase): # this is for testing bicubic interpolate def test_bicubic(self): diff --git a/python/jittor/test/test_cumprod_op.py b/python/jittor/test/test_cumprod_op.py index 1f82ce64..7d356ff6 100644 --- a/python/jittor/test/test_cumprod_op.py +++ b/python/jittor/test/test_cumprod_op.py @@ -9,11 +9,18 @@ import unittest import jittor as jt import numpy as np -import ctypes -import sys -import torch -from torch.autograd import Variable +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + from torch.autograd import Variable +except: + torch = None + skip_this_test = True + + +@unittest.skipIf(skip_this_test, "No Torch found") class TestCumprod(unittest.TestCase): def test_cumprod_cpu(self): for i in range(1,6): diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py index 1836ab65..b22d9108 100644 --- a/python/jittor/test/test_distributions.py +++ b/python/jittor/test/test_distributions.py @@ -12,6 +12,14 @@ import jittor as jt import numpy as np import jittor.distributions as jd +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch +except: + torch = None + skip_this_test = True + class TestOneHot(unittest.TestCase): def test_presum(self): @@ -19,6 +27,7 @@ class TestOneHot(unittest.TestCase): b = jd.simple_presum(a) assert (b.data == [[0,1,3,6,10]]).all() + @unittest.skipIf(skip_this_test, "No Torch Found") def test_one_hot(self): a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25])) x = a.sample().numpy() @@ -30,7 +39,7 @@ class TestOneHot(unittest.TestCase): assert y.shape == [2,3,4] probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() - import torch + jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2)) tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2)) assert np.allclose(jc.entropy().data,tc.entropy().numpy()) @@ -51,8 +60,8 @@ class TestOneHot(unittest.TestCase): y.sync() assert y.shape == [2,3] + @unittest.skipIf(skip_this_test, "No Torch Found") def test_normal(self): - import torch for _ in range(4): mu = np.random.uniform(-1,1) sigma = np.random.uniform(0,2) @@ -67,8 +76,8 @@ class TestOneHot(unittest.TestCase): tn2 = torch.distributions.Normal(mu2,sigma2) assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy()) + @unittest.skipIf(skip_this_test, "No Torch Found") def test_categorical1(self): - import torch for _ in range(4): probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() @@ -79,9 +88,9 @@ class TestOneHot(unittest.TestCase): np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5) assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) + @unittest.skipIf(skip_this_test, "No Torch Found") def test_categorical2(self): def check(prob_shape, sample_shape): - import torch for _ in range(4): probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape) @@ -98,9 +107,9 @@ class TestOneHot(unittest.TestCase): check((2,3), (4,)) check((3,4,5,6), (2,)) + @unittest.skipIf(skip_this_test, "No Torch Found") def test_one_hot_categorical2(self): def check(prob_shape, sample_shape): - import torch for _ in range(4): probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape) @@ -117,8 +126,8 @@ class TestOneHot(unittest.TestCase): check((2,3), (4,)) check((3,4,5,6), (2,)) + @unittest.skipIf(skip_this_test, "No Torch Found") def test_uniform(self): - import torch for _ in range(4): low, low2 = np.random.randint(-1,2), np.random.randint(-1,2) leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2) @@ -130,8 +139,8 @@ class TestOneHot(unittest.TestCase): assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x))) assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2)) + @unittest.skipIf(skip_this_test, "No Torch Found") def test_geometric(self): - import torch for _ in range(4): prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1) jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2) diff --git a/python/jittor/test/test_fold.py b/python/jittor/test/test_fold.py index bc394e47..7f8c2c86 100644 --- a/python/jittor/test/test_fold.py +++ b/python/jittor/test/test_fold.py @@ -10,11 +10,19 @@ # *************************************************************** import unittest import jittor as jt -import torch -from torch.nn import functional as F import numpy as np +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + from torch.nn import functional as F +except: + torch = None + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch Found") class TestFoldOp(unittest.TestCase): def test_fold(self): # test unfold first and the test fold. diff --git a/python/jittor/test/test_linalg.py b/python/jittor/test/test_linalg.py index bdbb7a54..f3d1c442 100644 --- a/python/jittor/test/test_linalg.py +++ b/python/jittor/test/test_linalg.py @@ -8,13 +8,13 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -import torch -from torch.autograd import Variable import jittor as jt import numpy as np import unittest try: + import torch + from torch.autograd import Variable import autograd.numpy as anp from autograd import jacobian diff --git a/python/jittor/test/test_random_op.py b/python/jittor/test/test_random_op.py index 597944a7..f6ca7cc7 100644 --- a/python/jittor/test/test_random_op.py +++ b/python/jittor/test/test_random_op.py @@ -18,12 +18,14 @@ import unittest from .test_reorder_tuner import simple_parser from .test_log import find_log_with_re +skip_this_test = False try: jt.dirty_fix_pytorch_runtime_error() import torch except: skip_this_test = True + class TestRandomOp(unittest.TestCase): @unittest.skipIf(not jt.has_cuda, "Cuda not found") @jt.flag_scope(use_cuda=1) @@ -51,6 +53,7 @@ class TestRandomOp(unittest.TestCase): logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)") assert len(logs)==1 + @unittest.skipIf(skip_this_test, "No Torch Found") def test_normal(self): from jittor import init n = 10000 diff --git a/python/jittor/test/test_search_sorted.py b/python/jittor/test/test_search_sorted.py index f10e56cc..d71d4f6a 100644 --- a/python/jittor/test/test_search_sorted.py +++ b/python/jittor/test/test_search_sorted.py @@ -18,12 +18,9 @@ try: jt.dirty_fix_pytorch_runtime_error() import torch import torch.nn as tnn - import torchvision - from torch.autograd import Variable except: torch = None tnn = None - torchvision = None skip_this_test = True # TODO: more test diff --git a/python/jittor/test/test_searchsorted_op.py b/python/jittor/test/test_searchsorted_op.py index de833195..65585b4f 100644 --- a/python/jittor/test/test_searchsorted_op.py +++ b/python/jittor/test/test_searchsorted_op.py @@ -9,11 +9,16 @@ import unittest import jittor as jt import numpy as np -import ctypes -import sys -import torch -from torch.autograd import Variable +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch +except: + skip_this_test = True + + +@unittest.skipIf class TestSearchsorted(unittest.TestCase): def test_searchsorted_cpu(self): for i in range(1,3): diff --git a/python/jittor/test/test_utils.py b/python/jittor/test/test_utils.py index cc67ed76..8b7fe194 100644 --- a/python/jittor/test/test_utils.py +++ b/python/jittor/test/test_utils.py @@ -32,7 +32,7 @@ cc_flags = f" -g -O0 -DTEST --std=c++14 -I{jittor_path}/test -I{jittor_path}/src class TestUtils(unittest.TestCase): def test_cache_compile(self): - cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile" + cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/str_utils.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile" self.assertEqual(os.system(cmd), 0) def test_log(self): diff --git a/python/jittor/version b/python/jittor/version index be2d89bc..98d3c70f 100644 --- a/python/jittor/version +++ b/python/jittor/version @@ -1 +1 @@ -9edb1890b66321883c398591483d4d65377a848d +939b29514b2e5cc591053aab614efd569772585d diff --git a/python/jittor_utils/__init__.py b/python/jittor_utils/__init__.py index 7a875c7b..2942f0c1 100644 --- a/python/jittor_utils/__init__.py +++ b/python/jittor_utils/__init__.py @@ -4,7 +4,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -from multiprocessing import Pool, Value +from multiprocessing import Pool import multiprocessing as mp import subprocess as sp import os diff --git a/python/jittor_utils/config.py b/python/jittor_utils/config.py index 4331afa7..fb7215b6 100644 --- a/python/jittor_utils/config.py +++ b/python/jittor_utils/config.py @@ -31,7 +31,7 @@ if __name__ == "__main__": libext = { 'Linux': 'so', 'Darwin': 'dylib', - 'Window': 'DLL', + 'Windows': 'DLL', }[platform.system()] ldflags = jittor_utils.run_cmd(jittor_utils.py3_config_path + " --ldflags") libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")]