improve import speed

This commit is contained in:
Dun Liang 2020-10-14 12:40:21 +08:00
parent 8d64d98a35
commit 0f97f26b89
3 changed files with 21 additions and 10 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.0.1' __version__ = '1.2.0.2'
from . import lock from . import lock
with lock.lock_scope(): with lock.lock_scope():
from . import compiler from . import compiler

View File

@ -18,6 +18,7 @@ import jittor_utils as jit_utils
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
from . import pyjt_compiler from . import pyjt_compiler
from . import lock from . import lock
from jittor import __version__
def find_jittor_path(): def find_jittor_path():
return os.path.dirname(__file__) return os.path.dirname(__file__)
@ -615,7 +616,7 @@ def compile_custom_ops(
if len(gen_name) > 100: if len(gen_name) > 100:
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name)) gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
includes = set(includes) includes = sorted(list(set(includes)))
includes = "".join(map(lambda x: f" -I'{x}' ", includes)) includes = "".join(map(lambda x: f" -I'{x}' ", includes))
LOG.vvvv(f"Include flags:{includes}") LOG.vvvv(f"Include flags:{includes}")
@ -828,6 +829,8 @@ jittor_path = find_jittor_path()
check_debug_flags() check_debug_flags()
sys.path.append(cache_path) sys.path.append(cache_path)
LOG.i(f"Jittor({__version__}) src: {jittor_path}")
LOG.i(f"cache_path: {cache_path}")
with jit_utils.import_scope(import_flags): with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core() jit_utils.try_import_jit_utils_core()

View File

@ -149,22 +149,30 @@ def do_compile(args):
pool_size = 0 pool_size = 0
def pool_cleanup():
global p
p.__exit__(None, None, None)
del p
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"): def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
global pool_size global pool_size, p
bk = mp.current_process()._config.get('daemon')
mp.current_process()._config['daemon'] = False
if pool_size == 0: if pool_size == 0:
mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
mem_gib = mem_bytes/(1024.**3) mem_gib = mem_bytes/(1024.**3)
pool_size = min(16,max(int(mem_gib // 3), 1)) pool_size = min(16,max(int(mem_gib // 3), 1))
LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.") LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
p = Pool(pool_size)
p.__enter__()
import atexit
atexit.register(pool_cleanup)
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ] cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
bk = mp.current_process()._config.get('daemon')
mp.current_process()._config['daemon'] = False
try: try:
with Pool(pool_size) as p: n = len(cmds)
n = len(cmds) dp = DelayProgress(msg, n)
dp = DelayProgress(msg, n) for i,_ in enumerate(p.imap_unordered(do_compile, cmds)):
for i,_ in enumerate(p.imap_unordered(do_compile, cmds)): dp.update(i)
dp.update(i)
finally: finally:
mp.current_process()._config['daemon'] = bk mp.current_process()._config['daemon'] = bk