# Mirror of https://github.com/Jittor/Jittor (Python, 525 lines, 17 KiB)
# ***************************************************************
|
|
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
|
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
|
# This file is subject to the terms and conditions defined in
|
|
# file 'LICENSE.txt', which is part of this source code package.
|
|
# ***************************************************************
|
|
from multiprocessing import Pool
|
|
import multiprocessing as mp
|
|
import subprocess as sp
|
|
import os
|
|
import re
|
|
import sys
|
|
import inspect
|
|
import datetime
|
|
import contextlib
|
|
import platform
|
|
import threading
|
|
import time
|
|
from ctypes import cdll
|
|
import shutil
|
|
import urllib.request
|
|
|
|
# macOS: pin multiprocessing to the "fork" start method so the compile Pool
# created in run_cmds forks its workers from this process.
# NOTE(review): presumably needed because newer Pythons default to "spawn" on
# macOS. set_start_method raises RuntimeError if the application already set
# a start method before importing this module.
if 'Darwin' == platform.system():
    mp.set_start_method(method='fork')
|
|
|
|
class Logwrapper:
    """Small logger that forwards to ``jit_utils_core`` once it is loaded.

    Verbosity comes from the ``log_silent`` and ``log_v`` environment
    variables, read once at construction. While the module-level ``cc``
    (the compiled ``jit_utils_core`` module) is unavailable or has no ``log``
    attribute, messages are printed to stdout as
    ``[<level> <MMDD HH:MM:SS.us> <tid%100><vN> <file>:<line>] <message>``.
    """

    def __init__(self):
        # Non-zero suppresses all output.
        self.log_silent = int(os.environ.get("log_silent", "0"))
        # Messages whose verbosity exceeds this level are dropped.
        self.log_v = int(os.environ.get("log_v", "0"))

    def log_capture_start(self):
        """Forward to ``cc.log_capture_start`` (requires ``cc`` to be loaded)."""
        cc.log_capture_start()

    def log_capture_stop(self):
        """Forward to ``cc.log_capture_stop`` (requires ``cc`` to be loaded)."""
        cc.log_capture_stop()

    def log_capture_read(self):
        """Return ``cc.log_capture_read()`` (requires ``cc`` to be loaded)."""
        return cc.log_capture_read()

    def _log(self, level, verbose, *msg):
        """Format and emit one message.

        Args:
            level: single-letter level tag ('i', 'w', 'e', 'f').
            verbose: verbosity of this message; dropped if above ``log_v``.
            *msg: parts concatenated with no separator. Callables are invoked
                lazily, only when the message is actually emitted.
        """
        if self.log_silent or verbose > self.log_v:
            return
        text = "".join(str(m() if callable(m) else m) for m in msg)
        # Report whoever called the public helper (V, v, i, w, ...), i.e. two
        # frames up. Reading f_code/f_lineno directly avoids
        # inspect.getframeinfo, which also loads source lines we never use.
        frame = inspect.currentframe()
        try:
            caller = frame.f_back.f_back
            fileline = f"{os.path.basename(caller.f_code.co_filename)}:{caller.f_lineno}"
        finally:
            # A frame held in its own locals forms a reference cycle.
            del frame
        if cc and hasattr(cc, "log"):
            cc.log(fileline, level, verbose, text)
        else:
            # Named `stamp` so the module-level `time` import is not shadowed.
            stamp = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
            tid = threading.get_ident() % 100
            v = f" v{verbose}" if verbose else ""
            print(f"[{level} {stamp} {tid:02}{v} {fileline}] {text}")

    def V(self, verbose, *msg): self._log('i', verbose, *msg)
    def v(self, *msg): self._log('i', 1, *msg)
    def vv(self, *msg): self._log('i', 10, *msg)
    def vvv(self, *msg): self._log('i', 100, *msg)
    def vvvv(self, *msg): self._log('i', 1000, *msg)
    def i(self, *msg): self._log('i', 0, *msg)
    def w(self, *msg): self._log('w', 0, *msg)
    def e(self, *msg): self._log('e', 0, *msg)
    def f(self, *msg): self._log('f', 0, *msg)
|
|
|
|
class DelayProgress:
    """Single-line progress/ETA display that stays quiet for short jobs.

    Nothing is printed until more than 2 seconds have passed since
    construction, or when the module-level ``LOG`` is silent.
    """

    def __init__(self, msg, n):
        self.msg, self.n = msg, n
        self.time = time.time()

    def update(self, i):
        """Report that item ``i`` (0-based) of ``n`` has been reached."""
        if LOG.log_silent:
            return
        elapsed = time.time() - self.time
        if elapsed <= 2:
            return
        done = i + 1
        remaining = self.n - done
        eta = elapsed / done * remaining
        # '\r' keeps rewriting the same console line.
        print(f"{self.msg}({done}/{self.n}) used: {elapsed:.3f}s eta: {eta:.3f}s", end='\r')
        if done == self.n:
            print()
|
|
|
|
# Detect whether we are running inside a Jupyter notebook (IPython kernel).
def in_ipynb():
    """Return True when running under an IPython kernel (e.g. Jupyter).

    ``get_ipython`` is only available inside IPython sessions. Elsewhere the
    lookup raises NameError and the function returns False.
    """
    try:
        cfg = get_ipython().config
        return bool('IPKernelApp' in cfg)
    except Exception:
        # Not under IPython, or a shell object without the expected config.
        # Catching Exception (not a bare except) lets Ctrl-C propagate.
        return False
|
|
|
|
@contextlib.contextmanager
def simple_timer(name):
    """Print the wall-clock time spent inside the ``with`` body, labelled ``name``.

    The stop line is printed from a ``finally`` clause, so the elapsed time is
    reported even when the body raises. The exception still propagates.
    """
    print("Timer start", name)
    now = time.time()
    try:
        yield
    finally:
        print("Time stop", name, time.time()-now)
|
|
|
|
@contextlib.contextmanager
def import_scope(flags):
    """Temporarily set the interpreter's dlopen flags, e.g. around an extension import.

    ``flags`` is passed to ``sys.setdlopenflags`` (for example
    ``os.RTLD_NOW | os.RTLD_GLOBAL``). The previous flags are restored on exit,
    including when the code inside the ``with`` block raises. On Windows
    (``os.name == 'nt'``), where ``sys.setdlopenflags`` is not available, this
    is a no-op.
    """
    if os.name == 'nt':
        yield
        return
    prev = sys.getdlopenflags()
    sys.setdlopenflags(flags)
    try:
        yield
    finally:
        sys.setdlopenflags(prev)
|
|
|
|
def try_import_jit_utils_core(silent=None):
    """Best-effort import of the compiled ``jit_utils_core`` into the global ``cc``.

    Does nothing if ``cc`` is already set. Import failures are swallowed
    (printed when ``log_v`` > 0) so callers can fall back to pure-Python paths.

    Args:
        silent: if not None, ``log_silent`` is set to ``str(int(silent))`` in
            the environment for the duration of the import and restored
            afterwards (even if the import is interrupted).
    """
    global cc
    if cc: return
    if silent is not None:
        prev = os.environ.get("log_silent", "0")
        os.environ["log_silent"] = str(int(silent))
    try:
        # In a notebook, logging must be synchronous and the C++ streams are
        # redirected so output appears in the notebook.
        if is_in_ipynb: os.environ["log_sync"] = "1"
        import jit_utils_core as cc
        if is_in_ipynb:
            cc.ostream_redirect(True, True)
    except Exception as e:
        if int(os.environ.get("log_v", "0")) > 0:
            print(e)
    finally:
        if silent is not None:
            os.environ["log_silent"] = prev
|
|
|
|
def run_cmd(cmd, cwd=None, err_msg=None, print_error=True):
    """Run ``cmd`` through the shell and return its combined stdout/stderr.

    Output is decoded as UTF-8, falling back to GBK (presumably for
    Chinese-locale Windows consoles). One trailing newline is stripped.

    Args:
        cmd: shell command string, executed with ``shell=True`` (never pass
            untrusted input).
        cwd: working directory; falsy means the current directory.
        err_msg: message for the exception raised on failure. Defaults to
            ``"Run cmd failed: <cmd>"``.
        print_error: on failure, write the output to stderr. When False and
            ``err_msg`` is defaulted, the output is appended to the message.

    Raises:
        Exception: if the command exits with a non-zero status.
    """
    LOG.v(f"Run cmd: {cmd}")
    # `or None` keeps the original treatment of cwd="" as "not given".
    r = sp.run(cmd, cwd=cwd or None, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT)
    try:
        s = r.stdout.decode('utf8')
    except UnicodeDecodeError:
        s = r.stdout.decode('gbk')
    if r.returncode != 0:
        if print_error:
            sys.stderr.write(s)
        if err_msg is None:
            err_msg = f"Run cmd failed: {cmd}"
            if not print_error:
                err_msg += "\n"+s
        raise Exception(err_msg)
    if s.endswith('\n'):
        s = s[:-1]
    return s
|
|
|
|
|
|
def do_compile(args):
    """Compile one job; this is the function the compile Pool maps over.

    ``args`` is ``[cmd, cache_path, jittor_path]``. If ``jit_utils_core`` can
    be imported, returns ``cc.cache_compile(cmd, cache_path, jittor_path)``.
    Otherwise runs ``cmd`` directly via run_cmd and returns True.
    """
    cmd, cache_path, jittor_path = args
    try_import_jit_utils_core(True)
    if not cc:
        run_cmd(cmd)
        return True
    return cc.cache_compile(cmd, cache_path, jittor_path)
|
|
|
|
# Number of worker processes in the shared compile Pool. 0 means the Pool has
# not been created yet; run_cmds picks a size and creates the Pool on first use.
pool_size = 0
|
|
|
|
def pool_cleanup():
    """Shut down the shared compile Pool ``p`` and drop the global reference.

    run_cmds registers this with ``atexit`` right after creating the Pool.
    """
    global p
    pool = p
    pool.__exit__(None, None, None)
    del p
|
|
|
|
def pool_initializer():
    """Initializer run in every compile-Pool worker process.

    On Windows, silences logging and clears ``gdb_path`` in the worker's
    environment. Then ensures ``jit_utils_core`` is imported (if possible)
    and, when available, calls its ``init_subprocess`` hook.
    """
    if os.name == 'nt':
        os.environ['log_silent'] = '1'
        os.environ['gdb_path'] = ""
    if cc is None:
        try_import_jit_utils_core()
    if not cc:
        return
    cc.init_subprocess()
|
|
|
|
def _compile_pool_size():
    """Pick the compile Pool size: one worker per 3 GiB of RAM, clamped to [1, 16]."""
    try:
        mem_bytes = get_total_mem()
        mem_gib = mem_bytes/(1024.**3)
        size = min(16,max(int(mem_gib // 3), 1))
        LOG.i(f"Total mem: {mem_gib:.2f}GB, using {size} procs for compiling.")
    except ValueError:
        # On macOS, Python versions below 3.9 do not support SC_PHYS_PAGES.
        # Use a hard-coded pool size instead.
        size = 4
        LOG.i(f"using {size} procs for compiling.")
    return size


def _create_compile_pool(size):
    """Create and enter a Pool of ``size`` workers that run pool_initializer."""
    spec_hack = os.name == 'nt'
    if spec_hack:
        # Windows workaround: give __main__ a fake __spec__ named '__main__'.
        # multiprocessing's spawn.get_preparation_data then passes
        # init_main_from_name='__main__', which the child skips, instead of
        # init_main_from_path, which would re-run the main script.
        main_module = sys.modules['__main__']
        spec_bk = main_module.__spec__
        tmp = lambda x:x
        tmp.name = '__main__'
        main_module.__spec__ = tmp
    try:
        pool = Pool(size, initializer=pool_initializer)
        pool.__enter__()
    finally:
        if spec_hack:
            # Undo the hack even if the Pool failed to start.
            main_module.__spec__ = spec_bk
    return pool


def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
    """Compile ``cmds`` in parallel on a lazily created, process-wide Pool.

    The Pool is created on the first call (sized by _compile_pool_size),
    stored in the global ``p``, and shut down at exit via pool_cleanup.

    Args:
        cmds: compile commands. Each is handed to do_compile in a worker as
            ``[cmd, cache_path, jittor_path]``.
        cache_path, jittor_path: forwarded to every do_compile call.
        msg: label for the DelayProgress line.

    Raises:
        Any exception raised by do_compile in a worker; imap_unordered
        re-raises it here.
    """
    global pool_size, p
    proc_config = mp.current_process()._config
    bk = proc_config.get('daemon')
    # Clear the daemon flag so multiprocessing lets this process start Pool
    # workers even if it is itself daemonic. Restored in `finally`, which
    # also covers a failure while creating the Pool.
    proc_config['daemon'] = False
    try:
        if pool_size == 0:
            size = _compile_pool_size()
            p = _create_compile_pool(size)
            # Publish the size only once the Pool exists, so a failed start
            # is retried on the next call instead of leaving `p` undefined.
            pool_size = size
            import atexit
            atexit.register(pool_cleanup)
        jobs = [[cmd, cache_path, jittor_path] for cmd in cmds]
        dp = DelayProgress(msg, len(jobs))
        for i, _ in enumerate(p.imap_unordered(do_compile, jobs)):
            dp.update(i)
    finally:
        proc_config['daemon'] = bk
|
|
|
|
if os.name == 'nt' and getattr(mp.current_process(), '_inheriting', False):
    # This process is being bootstrapped by multiprocessing's spawn on
    # Windows: forbid nested worker pools and keep the child quiet.
    os.environ.update(DISABLE_MULTIPROCESSING='1', log_silent='1')
|
|
|
|
if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
    os.environ["use_parallel_op_compiler"] = '0'

    def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
        """Serial replacement for run_cmds, used when multiprocessing is disabled.

        Compiles each command in order in this process via do_compile.
        Progress is reported after each command finishes, so the counter and
        ETA reflect completed work, as in the Pool-based version.
        """
        jobs = [[cmd, cache_path, jittor_path] for cmd in cmds]
        dp = DelayProgress(msg, len(jobs))
        for i, job in enumerate(jobs):
            do_compile(job)
            dp.update(i)
|
|
|
|
|
|
def download(url, filename):
    """Download ``url`` to ``filename`` unless a plausible copy already exists.

    An existing file larger than 100 bytes is left alone. A smaller or
    missing file is downloaded again. Data is first written to
    ``filename + ".part"`` and moved into place only after the transfer
    completes. An interrupted download therefore never leaves a partial file
    that a later call could mistake for a complete one.
    """
    if os.path.isfile(filename) and os.path.getsize(filename) > 100:
        return
    # The logger concatenates parts without a separator, so include the space.
    LOG.v("Downloading ", url)
    part = os.fspath(filename) + ".part"
    try:
        urllib.request.urlretrieve(url, part)
        os.replace(part, filename)
    except BaseException:
        if os.path.exists(part):
            os.remove(part)
        raise
    LOG.v("Download finished")
|
|
|
|
def get_jittor_version(init_path=None):
    """Return the ``__version__`` string declared in jittor's ``__init__.py``.

    Args:
        init_path: file to read. Defaults to ``../jittor/__init__.py``
            relative to this file.

    Raises:
        RuntimeError: if no ``__version__ = '...'`` (or ``"..."``) line is found.
    """
    if init_path is None:
        path = os.path.dirname(__file__)
        init_path = os.path.join(path, "../jittor/__init__.py")
    # Accept either quote style: __version__ = '1.3.0' or "1.3.0".
    pattern = re.compile(r"""__version__\s*=\s*(['"])(.+?)\1""")
    with open(init_path, "r", encoding='utf8') as fh:
        for line in fh:
            if line.startswith('__version__'):
                m = pattern.match(line)
                if m:
                    return m.group(2)
    raise RuntimeError("Unable to find version string.")
|
|
|
|
def get_str_hash(s):
    """Return the hex MD5 digest of ``s`` encoded as UTF-8.

    Used here for short, stable suffixes in cache directory names, not for
    security.
    """
    import hashlib
    return hashlib.md5(s.encode()).hexdigest()
|
|
|
|
def get_cpu_version():
    """Return a human-readable CPU model name (used as a cache-path component).

    Starts from ``platform.processor()`` and, best effort, replaces it with a
    more specific name:
    - Windows: the registry ``ProcessorNameString``.
    - macOS: ``machdep.cpu.brand_string`` queried via ``sysctl``.
    - Otherwise: the first ``model name`` line of ``/proc/cpuinfo``.

    Any failure while probing keeps the ``platform.processor()`` value.
    """
    v = platform.processor()
    try:
        if os.name == 'nt':
            import winreg
            key_name = r"Hardware\Description\System\CentralProcessor\0"
            field_name = "ProcessorNameString"
            # The context manager closes the key even if the query fails.
            with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_name) as key:
                v = winreg.QueryValueEx(key, field_name)[0]
        elif platform.system() == "Darwin":
            # NOTE(review): the extra "-a" and "sysctl" arguments look
            # unintended; `sysctl -n machdep.cpu.brand_string` may be what was
            # meant. Left unchanged because the result feeds the cache path.
            r, s = sp.getstatusoutput("sysctl -a sysctl machdep.cpu.brand_string")
            if r==0:
                v = s.split(":")[-1].strip()
        else:
            with open("/proc/cpuinfo", 'r') as f:
                for line in f:
                    if line.startswith("model name"):
                        v = line.split(':')[-1].strip()
                        break
    except Exception:
        # Best effort: keep platform.processor() if probing fails. Catching
        # Exception (not a bare except) lets Ctrl-C propagate.
        pass
    return v
|
|
|
|
def short(s):
    """Reduce ``s`` to a compact, path-safe token.

    Keeps identifier characters, numerics, letters and ``.-+``. If more than
    14 characters remain, returns the first 14, then ``'x'``, then the first
    two hex digits of the MD5 of the filtered string.
    """
    allowed_extra = '.-+'
    kept = "".join(
        ch for ch in s
        if str.isidentifier(ch) or str.isnumeric(ch)
        or str.isalpha(ch) or ch in allowed_extra
    )
    if len(kept) <= 14:
        return kept
    return kept[:14] + 'x' + get_str_hash(kept)[:2]
|
|
|
|
def _current_git_branch():
    """Best-effort name of the git branch checked out where this file lives, or None."""
    try:
        r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
            stderr=sp.PIPE)
        if r.returncode != 0:
            return None
        for line in r.stdout.decode().splitlines():
            if line.startswith("* "):
                return line[2:]
    except Exception:
        # Deliberately best effort: git missing, not a checkout,
        # undecodable output, ... all mean "no branch".
        pass
    return None


def find_cache_path():
    """Create (if needed) and return the compile-cache directory for this setup.

    The path is ``~/.cache/jittor/<jittor ver>/<cc ver>/<py ver>/<os>/<cpu>``
    (each component passed through ``short``), followed by the components of
    the cache name. The cache name comes from ``$cache_name`` if set, else the
    current git branch (``" (){}"`` replaced with ``_``), else ``"default"``.
    When ``debug=1``, ``_debug`` is appended to the CPU component.

    Side effects: exports ``cache_name`` to the environment, creates the
    directory, and appends it to ``sys.path``.
    """
    from pathlib import Path
    path = str(Path.home())
    # jittor version key
    jtv = "jt"+get_jittor_version().rsplit('.', 1)[0]
    # cc version key
    ccv = cc_type+get_version(cc_path)[1:-1] \
        if cc_type != "cl" else cc_type
    # os version key
    osv = platform.platform() + platform.node()
    if len(osv)>14:
        osv = osv[:14] + 'x'+get_str_hash(osv)[:2]
    # py version
    pyv = "py"+platform.python_version()
    # cpu version
    cpuv = get_cpu_version()
    dirs = [".cache", "jittor", jtv, ccv, pyv, osv, cpuv]
    dirs = list(map(short, dirs))

    if "cache_name" in os.environ:
        cache_name = os.environ["cache_name"]
    else:
        # Each git branch gets its own cache directory.
        branch = _current_git_branch()
        if branch is None:
            cache_name = "default"
        else:
            cache_name = branch
            for c in " (){}": cache_name = cache_name.replace(c, "_")

    if os.environ.get("debug")=="1":
        dirs[-1] += "_debug"
    dirs.extend(os.path.normpath(cache_name).split(os.path.sep))
    os.environ["cache_name"] = cache_name
    LOG.v("cache_name: ", cache_name)
    path = os.path.join(path, *dirs)
    os.makedirs(path, exist_ok=True)
    if path not in sys.path:
        sys.path.append(path)
    return path
|
|
|
|
def get_version(output):
    """Return the version of the tool at path ``output`` as a string ``"(x.y.z)"``.

    Runs ``"<output>" --showme:version`` for ``mpicc``, runs ``cl``/``cl.exe``
    with no arguments on Windows, and ``"<output>" --version`` otherwise. It
    then extracts dotted version numbers (x.y.z, falling back to x.y).
    Normally the last match is used. For Apple clang on macOS the
    third-from-last match is used, because later numbers in its banner belong
    to the build id and target triple.

    Raises:
        AssertionError: if no version number is found in the output.
    """
    if output.endswith("mpicc"):
        version = run_cmd(f"\"{output}\" --showme:version")
    elif os.name == 'nt' and (
        output.endswith("cl") or output.endswith("cl.exe")):
        version = run_cmd(output)
    else:
        version = run_cmd(f"\"{output}\" --version")
    v = re.findall(r"[0-9]+\.[0-9]+\.[0-9]+", version)
    if len(v) == 0:
        v = re.findall(r"[0-9]+\.[0-9]+", version)
    assert len(v) != 0, f"Can not find version number from: {version}"
    if 'clang' in version and platform.system() == 'Darwin':
        # With fewer than three matches v[-3] would raise IndexError; the
        # first match is then the "clang version X.Y.Z" number.
        picked = v[-3] if len(v) >= 3 else v[0]
    else:
        picked = v[-1]
    return "("+picked+")"
|
|
|
|
def get_int_version(output):
    """Return the version of the tool at ``output`` as a tuple of ints, e.g. ``(9, 3, 0)``."""
    inner = get_version(output)[1:-1]
    return tuple(map(int, inner.split('.')))
|
|
|
|
def find_exe(name, check_version=True, silent=False):
    """Locate executable ``name`` on PATH and return its full path.

    When ``check_version`` is true, the version from get_version is included
    in the log line. Unless ``silent``, the find is logged with LOG.i.

    Raises:
        RuntimeError: if ``name`` is not found on PATH.
    """
    exe_path = shutil.which(name)
    if not exe_path:
        raise RuntimeError(f"{name} not found")
    version = get_version(name) if check_version else ""
    if not silent:
        LOG.i(f"Found {name}{version} at {exe_path}.")
    return exe_path
|
|
|
|
def env_or_find(name, bname, silent=False):
    """Return the path in environment variable ``name``, or search PATH for ``bname``.

    A missing or empty variable falls back to ``find_exe(bname)``. When the
    variable is used, its version is queried and, unless ``silent``, logged.
    """
    path = os.environ.get(name, "")
    if not path:
        return find_exe(bname, silent=silent)
    version = get_version(path)
    if not silent:
        LOG.i(f"Found {bname}{version} at {path}")
    return path
|
|
|
|
def get_cc_type(cc_path):
    """Classify a C++ compiler by its executable name.

    Returns ``"clang"``, ``"icc"``, ``"g++"`` or ``"cl"``, checked in that
    order by substring match on the basename (so ``clang`` wins over the
    ``cl`` it contains). For anything else ``LOG.f`` is called; if that
    returns, the result is None.
    """
    exe_name = os.path.basename(cc_path)
    known_types = (
        (("clang",), "clang"),
        (("icc", "icpc"), "icc"),
        (("g++",), "g++"),
        (("cl",), "cl"),
    )
    for markers, cc_kind in known_types:
        if any(marker in exe_name for marker in markers):
            return cc_kind
    LOG.f(f"Unknown cc type: {exe_name}")
|
|
|
|
def get_py3_config_path():
    """Locate (and cache) the ``python3.X-config`` script for this interpreter.

    Returns None on Windows. Elsewhere, candidates are tried in order and the
    first existing file is cached in ``_py3_config_path`` and returned. The
    ``$python_config_path`` environment variable comes first if set, then
    paths next to ``sys.executable``, then common system and Homebrew
    locations.

    This may be called from a C++ host program, where ``sys.executable`` is
    not a Python binary, so fixed system locations are searched too. Other
    setups (e.g. conda) may still fail.

    Raises:
        RuntimeError: if no candidate file exists.
    """
    global _py3_config_path
    if _py3_config_path:
        return _py3_config_path
    if os.name == 'nt':
        return None

    minor = sys.version_info.minor
    exe_dir = os.path.dirname(sys.executable)
    candidates = [
        exe_dir + f"/python3.{minor}-config",
        sys.executable + "-config",
        f"/usr/bin/python3.{minor}-config",
        f"/usr/local/bin/python3.{minor}-config",
        f'/opt/homebrew/bin/python3.{minor}-config',
        exe_dir + "/python3-config",
    ]
    if "python_config_path" in os.environ:
        candidates.insert(0, os.environ["python_config_path"])

    found = next((c for c in candidates if os.path.isfile(c)), None)
    if found is None:
        raise RuntimeError(f"python3.{sys.version_info.minor}-config "
            f"not found in {candidates}, please specify "
            f"enviroment variable 'python_config_path',"
            f" or install python3.{sys.version_info.minor}-dev")
    _py3_config_path = found
    return found
|
|
|
|
def get_py3_include_path():
    """Return (and cache) compiler flags that put Python's C headers on the include path.

    Non-Windows: the output of ``<python3.X-config> --includes``. Windows:
    ``-I"<dir of sys.executable>\\include"``.
    """
    global _py3_include_path
    if not _py3_include_path:
        if os.name != 'nt':
            _py3_include_path = run_cmd(get_py3_config_path()+" --includes")
        else:
            # NOTE(review): this lower-cases sys.executable for the whole
            # process, not just for this computation. Kept because other code
            # may rely on it; confirm before changing.
            sys.executable = sys.executable.lower()
            include_dir = os.path.join(os.path.dirname(sys.executable), "include")
            _py3_include_path = '-I"' + include_dir + '"'
    return _py3_include_path
|
|
|
|
|
|
def get_py3_extension_suffix():
    """Return (and cache) the filename suffix for compiled extension modules.

    Non-Windows: the output of ``<python3.X-config> --extension-suffix``
    (e.g. ``.cpython-39-x86_64-linux-gnu.so``). Windows: the running
    interpreter's ``EXT_SUFFIX`` (e.g. ``.cp39-win_amd64.pyd``). This matches
    32-bit and ARM64 builds too, not only x86-64, which was previously
    hard-coded and is kept as a fallback.
    """
    global _py3_extension_suffix
    if _py3_extension_suffix:
        return _py3_extension_suffix
    if os.name == 'nt':
        import sysconfig
        _py3_extension_suffix = (sysconfig.get_config_var("EXT_SUFFIX")
            or f".cp3{sys.version_info.minor}-win_amd64.pyd")
    else:
        _py3_extension_suffix = run_cmd(get_py3_config_path()+" --extension-suffix")
    return _py3_extension_suffix
|
|
|
|
def get_total_mem():
    """Return total physical memory in bytes.

    Windows: ``GlobalMemoryStatusEx`` via ctypes. Elsewhere: page size times
    the number of physical pages, from ``os.sysconf``.

    Raises:
        OSError: on Windows if ``GlobalMemoryStatusEx`` fails.
        ValueError: from ``os.sysconf`` when a name is unsupported. run_cmds
            catches this to fall back to a fixed pool size (older macOS
            Pythons lack SC_PHYS_PAGES).
    """
    if os.name == 'nt':
        from ctypes import Structure, c_int32, c_uint64, sizeof, byref, windll, WinError
        class MemoryStatusEx(Structure):
            _fields_ = [
                ('length', c_int32),
                ('memoryLoad', c_int32),
                ('totalPhys', c_uint64),
                ('availPhys', c_uint64),
                ('totalPageFile', c_uint64),
                ('availPageFile', c_uint64),
                ('totalVirtual', c_uint64),
                ('availVirtual', c_uint64),
                ('availExtendedVirtual', c_uint64)]
            def __init__(self):
                # The API requires dwLength to be set to the struct size.
                self.length = sizeof(self)
        m = MemoryStatusEx()
        # Make the call outside `assert`. Under `python -O`, asserts (and any
        # call inside them) are stripped, which would leave totalPhys at 0.
        if not windll.kernel32.GlobalMemoryStatusEx(byref(m)):
            raise WinError()
        return m.totalPhys
    return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
|
|
|
|
# Module-level state used by the helpers above.
# True when imported inside a Jupyter/IPython kernel (see in_ipynb).
is_in_ipynb = in_ipynb()
# The compiled jit_utils_core module once try_import_jit_utils_core succeeds.
# None until then, and Logwrapper falls back to print().
cc = None
LOG = Logwrapper()

# Compiler selection. On Windows with no explicit $cc_path, use the
# jittor-managed MSVC under ~/.cache/jittor/msvc (installed further below when
# check_msvc_install is set). Otherwise take $cc_path, or find g++ on PATH.
check_msvc_install = False
msvc_path = ""
if os.name == 'nt' and os.environ.get("cc_path", "")=="":
    from pathlib import Path
    msvc_path = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
    cc_path = os.path.join(msvc_path, "VC", r"_\_\_\_\_\bin", "cl.exe")
    check_msvc_install = True
else:
    cc_path = env_or_find('cc_path', 'g++', silent=True)
# Export the chosen compiler; child processes inherit the environment.
os.environ["cc_path"] = cc_path
cc_type = get_cc_type(cc_path)
# find_cache_path reads cc_type, cc_path and LOG, so it must run after them.
cache_path = find_cache_path()
|
|
|
|
# Lazily filled caches for get_py3_config_path, get_py3_include_path and
# get_py3_extension_suffix. A falsy value (None) means "not computed yet".
_py3_config_path = _py3_include_path = _py3_extension_suffix = None
|
|
|
|
# Windows-only process setup.
if os.name == 'nt':
    from pathlib import Path
    # NOTE(review): SECURITY. Replacing ssl._create_default_https_context with
    # the unverified context turns off TLS certificate verification, for the
    # whole process, in HTTPS requests that use the default context (e.g.
    # urllib in `download`). This is presumably a workaround for machines
    # missing CA certificates. Consider scoping it to the downloads that need
    # it, or to an opt-in flag.
    try:
        import ssl
        ssl._create_default_https_context = ssl._create_unverified_context
    except:
        pass
    # Install the jittor-managed MSVC if it was selected above and is missing.
    if check_msvc_install:
        if not os.path.isfile(cc_path):
            from jittor_utils import install_msvc
            install_msvc.install(msvc_path)
    # A cc_path inside the managed MSVC directory also sets msvc_path. This
    # covers a user-supplied $cc_path that points there.
    mpath = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
    if cc_path.startswith(mpath):
        msvc_path = mpath
    # os.RTLD_* are Unix-only. Define them as 0, presumably so code that
    # passes them (e.g. to import_scope) does not fail on Windows.
    os.RTLD_NOW = os.RTLD_GLOBAL = os.RTLD_DEEPBIND = 0
    # Put the compiler's directory on sys.path, PATH and the DLL search path.
    path = os.path.dirname(cc_path).replace('/', '\\')
    if path:
        sys.path.insert(0, path)
        os.environ["PATH"] = path+';'+os.environ["PATH"]
        if hasattr(os, "add_dll_directory"):
            os.add_dll_directory(path)
|