mirror of https://github.com/Jittor/Jittor
improve import speed
This commit is contained in:
parent
8d64d98a35
commit
0f97f26b89
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.0.1'
|
||||
__version__ = '1.2.0.2'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
|
|
@ -18,6 +18,7 @@ import jittor_utils as jit_utils
|
|||
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
|
||||
from . import pyjt_compiler
|
||||
from . import lock
|
||||
from jittor import __version__
|
||||
|
||||
def find_jittor_path():
|
||||
return os.path.dirname(__file__)
|
||||
|
@ -615,7 +616,7 @@ def compile_custom_ops(
|
|||
if len(gen_name) > 100:
|
||||
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
|
||||
|
||||
includes = set(includes)
|
||||
includes = sorted(list(set(includes)))
|
||||
includes = "".join(map(lambda x: f" -I'{x}' ", includes))
|
||||
LOG.vvvv(f"Include flags:{includes}")
|
||||
|
||||
|
@ -828,6 +829,8 @@ jittor_path = find_jittor_path()
|
|||
check_debug_flags()
|
||||
|
||||
sys.path.append(cache_path)
|
||||
LOG.i(f"Jittor({__version__}) src: {jittor_path}")
|
||||
LOG.i(f"cache_path: {cache_path}")
|
||||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
|
|
|
@ -149,18 +149,26 @@ def do_compile(args):
|
|||
|
||||
pool_size = 0
|
||||
|
||||
def pool_cleanup():
|
||||
global p
|
||||
p.__exit__(None, None, None)
|
||||
del p
|
||||
|
||||
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
||||
global pool_size
|
||||
global pool_size, p
|
||||
bk = mp.current_process()._config.get('daemon')
|
||||
mp.current_process()._config['daemon'] = False
|
||||
if pool_size == 0:
|
||||
mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
|
||||
mem_gib = mem_bytes/(1024.**3)
|
||||
pool_size = min(16,max(int(mem_gib // 3), 1))
|
||||
LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
|
||||
p = Pool(pool_size)
|
||||
p.__enter__()
|
||||
import atexit
|
||||
atexit.register(pool_cleanup)
|
||||
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
|
||||
bk = mp.current_process()._config.get('daemon')
|
||||
mp.current_process()._config['daemon'] = False
|
||||
try:
|
||||
with Pool(pool_size) as p:
|
||||
n = len(cmds)
|
||||
dp = DelayProgress(msg, n)
|
||||
for i,_ in enumerate(p.imap_unordered(do_compile, cmds)):
|
||||
|
|
Loading…
Reference in New Issue