diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index b63a3afa..da59bb86 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -7,7 +7,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.1.7.11' +__version__ = '1.1.7.12' from . import lock with lock.lock_scope(): from . import compiler diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index 95fa28a9..9ef86c0b 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -784,11 +784,16 @@ def try_find_exe(*args): def check_pybt(gdb_path, python_path): if gdb_path=='' or python_path=='': return False - ret = sp.getoutput(f"{gdb_path} --batch {python_path} -ex 'help py-bt'") - if 'python frame' in ret: - LOG.v("py-bt found in gdb.") - return True - return False + return True + # TODO: prev we use below code to check has py-bt or nor + # but it is too slow, so we comment it, + # find a better way to check py-bt exist + + # ret = sp.getoutput(f"{gdb_path} --batch {python_path} -ex 'help py-bt'") + # if 'python frame' in ret: + # LOG.v("py-bt found in gdb.") + # return True + # return False def check_debug_flags(): global is_debug diff --git a/python/jittor/pyjt_compiler.py b/python/jittor/pyjt_compiler.py index 7d2f3aac..542c9c6f 100644 --- a/python/jittor/pyjt_compiler.py +++ b/python/jittor/pyjt_compiler.py @@ -449,7 +449,7 @@ def compile_src(src, h, basename): continue else: defs.append(def_info) - LOG.vvv(json.dumps(def_info, indent=4)) + LOG.vvv(lambda: json.dumps(def_info, indent=4)) # deal with defs if len(defs) == 0: return # include_name = h[4:] # remove "src/" prefix diff --git a/python/jittor/test/test_init.py b/python/jittor/test/test_init.py new file mode 100644 index 00000000..58b9099a --- /dev/null +++ b/python/jittor/test/test_init.py @@ -0,0 +1,59 @@ +# *************************************************************** +# Copyright (c) Jittor 2020, Author: +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import unittest +import numpy as np +from jittor import models + +pass_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torchvision +except Exception as e: + pass_this_test = True + +def get_error(a, b): + return np.abs(a-b) / max(np.abs(a), np.abs(b), 1e-5) , np.abs(a-b) + +def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5, mean_atol=1e-5): + pa = [ p for p in jt_mod.parameters() if not p.is_stop_grad() ] + pb = list(torch_mod.parameters()) + assert len(pa) == len(pb) + error_count = 0 + for a,b in zip(pa, pb): + assert a.shape == list(b.shape), (a.shape, b.shape, a.name()) + stda, meana = np.std(a.numpy()), np.mean(a.numpy()) + stdb, meanb = np.std(b.detach().numpy()), np.mean(b.detach().numpy()) + + r_err, a_err = get_error(stda, stdb) + if r_err > rtol and a_err > atol: + error_count += 1 + print("compare std error", stda, stdb, r_err, a_err, a.name(), a.shape) + + r_err, a_err = get_error(meana, meanb) + if r_err > rtol and a_err > mean_atol: + error_count += 1 + print("compare mean error", meana, meanb, r_err, a_err, a.name(), a.shape) + assert error_count == 0 + +@unittest.skipIf(pass_this_test, f"pass init check, no torch found") +class TestInit(unittest.TestCase): + @classmethod + def setUpClass(self): + jt.seed(0) + np.random.seed(0) + torch.manual_seed(0) + + def test_conv(self): + check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3), rtol=1e-1, mean_atol=1e-3) + + def test_resnet(self): + check(models.resnet152(), torchvision.models.resnet152(), rtol=2e-2, mean_atol=1e-2) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor_utils/__init__.py b/python/jittor_utils/__init__.py index 994db7b9..a3d33543 100644 --- a/python/jittor_utils/__init__.py +++ b/python/jittor_utils/__init__.py @@ -31,18 +31,20 @@ class LogWarper: return cc.log_capture_read() def _log(self, level, verbose, *msg): - if len(msg): - msg = " ".join([ str(m) for m in msg ]) - else: - msg = str(msg) + if self.log_silent or verbose > self.log_v: + return + ss = "" + for m in msg: + if callable(m): + m = m() + ss += str(m) + msg = ss f = inspect.currentframe() fileline = inspect.getframeinfo(f.f_back.f_back) fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}" if cc and hasattr(cc, "log"): cc.log(fileline, level, verbose, msg) else: - if self.log_silent or verbose > self.log_v: - return time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f") tid = threading.get_ident()%100 v = f" v{verbose}" if verbose else ""