From 0f97f26b89641b84226972c106f8db270d489b8e Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Wed, 14 Oct 2020 12:40:21 +0800 Subject: [PATCH] improve import speed --- python/jittor/__init__.py | 2 +- python/jittor/compiler.py | 5 ++++- python/jittor_utils/__init__.py | 24 ++++++++++++++++-------- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index dd51e96f..9a4cd418 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -7,7 +7,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.0.1' +__version__ = '1.2.0.2' from . import lock with lock.lock_scope(): from . import compiler diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index 7170c468..e7c63d61 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -18,6 +18,7 @@ import jittor_utils as jit_utils from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path from . import pyjt_compiler from . import lock +from jittor import __version__ def find_jittor_path(): return os.path.dirname(__file__) @@ -615,7 +616,7 @@ def compile_custom_ops( if len(gen_name) > 100: gen_name = gen_name[:80] + "___hash" + str(hash(gen_name)) - includes = set(includes) + includes = sorted(list(set(includes))) includes = "".join(map(lambda x: f" -I'{x}' ", includes)) LOG.vvvv(f"Include flags:{includes}") @@ -828,6 +829,8 @@ jittor_path = find_jittor_path() check_debug_flags() sys.path.append(cache_path) +LOG.i(f"Jittor({__version__}) src: {jittor_path}") +LOG.i(f"cache_path: {cache_path}") with jit_utils.import_scope(import_flags): jit_utils.try_import_jit_utils_core() diff --git a/python/jittor_utils/__init__.py b/python/jittor_utils/__init__.py index 24d705c6..3ba43e34 100644 --- a/python/jittor_utils/__init__.py +++ b/python/jittor_utils/__init__.py @@ -149,22 +149,30 @@ def do_compile(args): pool_size = 0 +def pool_cleanup(): + global p + p.__exit__(None, None, None) + del p + def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"): - global pool_size + global pool_size, p + bk = mp.current_process()._config.get('daemon') + mp.current_process()._config['daemon'] = False if pool_size == 0: mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') mem_gib = mem_bytes/(1024.**3) pool_size = min(16,max(int(mem_gib // 3), 1)) LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.") + p = Pool(pool_size) + p.__enter__() + import atexit + atexit.register(pool_cleanup) cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ] - bk = mp.current_process()._config.get('daemon') - mp.current_process()._config['daemon'] = False try: - with Pool(pool_size) as p: - n = len(cmds) - dp = DelayProgress(msg, n) - for i,_ in enumerate(p.imap_unordered(do_compile, cmds)): - dp.update(i) + n = len(cmds) + dp = DelayProgress(msg, n) + for i,_ in enumerate(p.imap_unordered(do_compile, cmds)): + dp.update(i) finally: mp.current_process()._config['daemon'] = bk