JittorMirror/python/jittor_utils/__init__.py

525 lines
17 KiB
Python

# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from multiprocessing import Pool
import multiprocessing as mp
import subprocess as sp
import os
import re
import sys
import inspect
import datetime
import contextlib
import platform
import threading
import time
from ctypes import cdll
import shutil
import urllib.request
if platform.system() == 'Darwin':
mp.set_start_method('fork')
class Logwrapper:
def __init__(self):
self.log_silent = int(os.environ.get("log_silent", "0"))
self.log_v = int(os.environ.get("log_v", "0"))
def log_capture_start(self):
cc.log_capture_start()
def log_capture_stop(self):
cc.log_capture_stop()
def log_capture_read(self):
return cc.log_capture_read()
def _log(self, level, verbose, *msg):
if self.log_silent or verbose > self.log_v:
return
ss = ""
for m in msg:
if callable(m):
m = m()
ss += str(m)
msg = ss
f = inspect.currentframe()
fileline = inspect.getframeinfo(f.f_back.f_back)
fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
if cc and hasattr(cc, "log"):
cc.log(fileline, level, verbose, msg)
else:
time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
tid = threading.get_ident()%100
v = f" v{verbose}" if verbose else ""
print(f"[{level} {time} {tid:02}{v} {fileline}] {msg}")
def V(self, verbose, *msg): self._log('i', verbose, *msg)
def v(self, *msg): self._log('i', 1, *msg)
def vv(self, *msg): self._log('i', 10, *msg)
def vvv(self, *msg): self._log('i', 100, *msg)
def vvvv(self, *msg): self._log('i', 1000, *msg)
def i(self, *msg): self._log('i', 0, *msg)
def w(self, *msg): self._log('w', 0, *msg)
def e(self, *msg): self._log('e', 0, *msg)
def f(self, *msg): self._log('f', 0, *msg)
class DelayProgress:
def __init__(self, msg, n):
self.msg = msg
self.n = n
self.time = time.time()
def update(self, i):
if LOG.log_silent:
return
used = time.time() - self.time
if used > 2:
eta = used / (i+1) * (self.n-i-1)
print(f"{self.msg}({i+1}/{self.n}) used: {used:.3f}s eta: {eta:.3f}s", end='\r')
if i==self.n-1: print()
# check is in jupyter notebook
def in_ipynb():
try:
cfg = get_ipython().config
if 'IPKernelApp' in cfg:
return True
else:
return False
except:
return False
@contextlib.contextmanager
def simple_timer(name):
print("Timer start", name)
now = time.time()
yield
print("Time stop", name, time.time()-now)
@contextlib.contextmanager
def import_scope(flags):
if os.name != 'nt':
prev = sys.getdlopenflags()
sys.setdlopenflags(flags)
yield
if os.name != 'nt':
sys.setdlopenflags(prev)
def try_import_jit_utils_core(silent=None):
global cc
if cc: return
if not (silent is None):
prev = os.environ.get("log_silent", "0")
os.environ["log_silent"] = str(int(silent))
try:
# if is in notebook, must log sync, and we redirect the log
if is_in_ipynb: os.environ["log_sync"] = "1"
import jit_utils_core as cc
if is_in_ipynb:
cc.ostream_redirect(True, True)
except Exception as _:
if int(os.environ.get("log_v", "0")) > 0:
print(_)
pass
if not (silent is None):
os.environ["log_silent"] = prev
def run_cmd(cmd, cwd=None, err_msg=None, print_error=True):
LOG.v(f"Run cmd: {cmd}")
if cwd:
r = sp.run(cmd, cwd=cwd, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT)
else:
r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT)
try:
s = r.stdout.decode('utf8')
except:
s = r.stdout.decode('gbk')
if r.returncode != 0:
if print_error:
sys.stderr.write(s)
if err_msg is None:
err_msg = f"Run cmd failed: {cmd}"
if not print_error:
err_msg += "\n"+s
raise Exception(err_msg)
if len(s) and s[-1] == '\n': s = s[:-1]
return s
def do_compile(args):
cmd, cache_path, jittor_path = args
try_import_jit_utils_core(True)
if cc:
return cc.cache_compile(cmd, cache_path, jittor_path)
else:
run_cmd(cmd)
return True
pool_size = 0
def pool_cleanup():
global p
p.__exit__(None, None, None)
del p
def pool_initializer():
if os.name == 'nt':
os.environ['log_silent'] = '1'
os.environ['gdb_path'] = ""
if cc is None:
try_import_jit_utils_core()
if cc:
cc.init_subprocess()
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
global pool_size, p
bk = mp.current_process()._config.get('daemon')
mp.current_process()._config['daemon'] = False
if pool_size == 0:
try:
mem_bytes = get_total_mem()
mem_gib = mem_bytes/(1024.**3)
pool_size = min(16,max(int(mem_gib // 3), 1))
LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
except ValueError:
# On macOS, python with version lower than 3.9 do not support SC_PHYS_PAGES.
# Use hard coded pool size instead.
pool_size = 4
LOG.i(f"using {pool_size} procs for compiling.")
if os.name == 'nt':
# a hack way to by pass windows
# multiprocess spawn init_main_from_path.
# check spawn.py:get_preparation_data
spec_bk = sys.modules['__main__'].__spec__
tmp = lambda x:x
tmp.name = '__main__'
sys.modules['__main__'].__spec__ = tmp
p = Pool(pool_size, initializer=pool_initializer)
p.__enter__()
if os.name == 'nt':
sys.modules['__main__'].__spec__ = spec_bk
import atexit
atexit.register(pool_cleanup)
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
try:
n = len(cmds)
dp = DelayProgress(msg, n)
for i,_ in enumerate(p.imap_unordered(do_compile, cmds)):
dp.update(i)
finally:
mp.current_process()._config['daemon'] = bk
if os.name=='nt' and getattr(mp.current_process(), '_inheriting', False):
# when windows spawn multiprocess, disable sub-subprocess
os.environ["DISABLE_MULTIPROCESSING"] = '1'
os.environ["log_silent"] = '1'
if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
os.environ["use_parallel_op_compiler"] = '0'
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
n = len(cmds)
dp = DelayProgress(msg, n)
for i,cmd in enumerate(cmds):
dp.update(i)
do_compile(cmd)
def download(url, filename):
if os.path.isfile(filename):
if os.path.getsize(filename) > 100:
return
LOG.v("Downloading", url)
urllib.request.urlretrieve(url, filename)
LOG.v("Download finished")
def get_jittor_version():
path = os.path.dirname(__file__)
with open(os.path.join(path, "../jittor/__init__.py"), "r", encoding='utf8') as fh:
for line in fh:
if line.startswith('__version__'):
version = line.split("'")[1]
break
else:
raise RuntimeError("Unable to find version string.")
return version
def get_str_hash(s):
import hashlib
md5 = hashlib.md5()
md5.update(s.encode())
return md5.hexdigest()
def get_cpu_version():
v = platform.processor()
try:
if os.name == 'nt':
import winreg
key_name = r"Hardware\Description\System\CentralProcessor\0"
field_name = "ProcessorNameString"
key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_name)
value = winreg.QueryValueEx(key, field_name)[0]
winreg.CloseKey(key)
v = value
elif platform.system() == "Darwin":
r, s = sp.getstatusoutput("sysctl -a sysctl machdep.cpu.brand_string")
if r==0:
v = s.split(":")[-1].strip()
else:
with open("/proc/cpuinfo", 'r') as f:
for l in f:
if l.startswith("model name"):
v = l.split(':')[-1].strip()
break
except:
pass
return v
def short(s):
ss = ""
for c in s:
if str.isidentifier(c) or str.isnumeric(c) \
or str.isalpha(c) or c in '.-+':
ss += c
if len(ss)>14:
return ss[:14]+'x'+get_str_hash(ss)[:2]
return ss
def find_cache_path():
from pathlib import Path
path = str(Path.home())
# jittor version key
jtv = "jt"+get_jittor_version().rsplit('.', 1)[0]
# cc version key
ccv = cc_type+get_version(cc_path)[1:-1] \
if cc_type != "cl" else cc_type
# os version key
osv = platform.platform() + platform.node()
if len(osv)>14:
osv = osv[:14] + 'x'+get_str_hash(osv)[:2]
# py version
pyv = "py"+platform.python_version()
# cpu version
cpuv = get_cpu_version()
dirs = [".cache", "jittor", jtv, ccv, pyv, osv, cpuv]
dirs = list(map(short, dirs))
cache_name = "default"
try:
if "cache_name" in os.environ:
cache_name = os.environ["cache_name"]
else:
# try to get branch name from git
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
stderr=sp.PIPE)
assert r.returncode == 0
bs = r.stdout.decode().splitlines()
for b in bs:
if b.startswith("* "): break
cache_name = b[2:]
for c in " (){}": cache_name = cache_name.replace(c, "_")
except:
pass
if os.environ.get("debug")=="1":
dirs[-1] += "_debug"
for name in os.path.normpath(cache_name).split(os.path.sep):
dirs.append(name)
os.environ["cache_name"] = cache_name
LOG.v("cache_name: ", cache_name)
path = os.path.join(path, *dirs)
os.makedirs(path, exist_ok=True)
if path not in sys.path:
sys.path.append(path)
return path
def get_version(output):
if output.endswith("mpicc"):
version = run_cmd(f"\"{output}\" --showme:version")
elif os.name == 'nt' and (
output.endswith("cl") or output.endswith("cl.exe")):
version = run_cmd(output)
else:
version = run_cmd(f"\"{output}\" --version")
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
if len(v) == 0:
v = re.findall("[0-9]+\\.[0-9]+", version)
assert len(v) != 0, f"Can not find version number from: {version}"
if 'clang' in version and platform.system() == 'Darwin':
version = "("+v[-3]+")"
else:
version = "("+v[-1]+")"
return version
def get_int_version(output):
ver = get_version(output)
ver = ver[1:-1].split('.')
ver = tuple(( int(v) for v in ver ))
return ver
def find_exe(name, check_version=True, silent=False):
output = shutil.which(name)
if not output:
raise RuntimeError(f"{name} not found")
if check_version:
version = get_version(name)
else:
version = ""
if not silent:
LOG.i(f"Found {name}{version} at {output}.")
return output
def env_or_find(name, bname, silent=False):
if name in os.environ:
path = os.environ[name]
if path != "":
version = get_version(path)
if not silent:
LOG.i(f"Found {bname}{version} at {path}")
return path
return find_exe(bname, silent=silent)
def get_cc_type(cc_path):
bname = os.path.basename(cc_path)
if "clang" in bname: return "clang"
if "icc" in bname or "icpc" in bname: return "icc"
if "g++" in bname: return "g++"
if "cl" in bname: return "cl"
LOG.f(f"Unknown cc type: {bname}")
def get_py3_config_path():
global _py3_config_path
if _py3_config_path:
return _py3_config_path
if os.name == 'nt':
return None
else:
# Search python3.x-config
# Note:
# This may be called via c++ console. In that case, sys.executable will
# be a path to the executable file, rather than python. So, we cannot infer
# python-config path only from sys.executable.
# To address this issue, we add predefined paths to search,
# - Linux: /usr/bin/python3.x-config
# - macOS (installed via homebrew): /usr/local/bin/python3.x-config
# There may be issues under other cases, e.g., installed via conda.
py3_config_paths = [
os.path.dirname(sys.executable) + f"/python3.{sys.version_info.minor}-config",
sys.executable + "-config",
f"/usr/bin/python3.{sys.version_info.minor}-config",
f"/usr/local/bin/python3.{sys.version_info.minor}-config",
f'/opt/homebrew/bin/python3.{sys.version_info.minor}-config',
os.path.dirname(sys.executable) + "/python3-config",
]
if "python_config_path" in os.environ:
py3_config_paths.insert(0, os.environ["python_config_path"])
for py3_config_path in py3_config_paths:
if os.path.isfile(py3_config_path):
break
else:
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
f"not found in {py3_config_paths}, please specify "
f"enviroment variable 'python_config_path',"
f" or install python3.{sys.version_info.minor}-dev")
_py3_config_path = py3_config_path
return py3_config_path
def get_py3_include_path():
global _py3_include_path
if _py3_include_path:
return _py3_include_path
if os.name == 'nt':
# Windows
sys.executable = sys.executable.lower()
_py3_include_path = '-I"' + os.path.join(
os.path.dirname(sys.executable),
"include"
) + '"'
else:
_py3_include_path = run_cmd(get_py3_config_path()+" --includes")
return _py3_include_path
def get_py3_extension_suffix():
global _py3_extension_suffix
if _py3_extension_suffix:
return _py3_extension_suffix
if os.name == 'nt':
# Windows
_py3_extension_suffix = f".cp3{sys.version_info.minor}-win_amd64.pyd"
else:
_py3_extension_suffix = run_cmd(get_py3_config_path()+" --extension-suffix")
return _py3_extension_suffix
def get_total_mem():
if os.name == 'nt':
from ctypes import Structure, c_int32, c_uint64, sizeof, byref, windll
class MemoryStatusEx(Structure):
_fields_ = [
('length', c_int32),
('memoryLoad', c_int32),
('totalPhys', c_uint64),
('availPhys', c_uint64),
('totalPageFile', c_uint64),
('availPageFile', c_uint64),
('totalVirtual', c_uint64),
('availVirtual', c_uint64),
('availExtendedVirtual', c_uint64)]
def __init__(self):
self.length = sizeof(self)
m = MemoryStatusEx()
assert windll.kernel32.GlobalMemoryStatusEx(byref(m))
return m.totalPhys
else:
return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
is_in_ipynb = in_ipynb()
cc = None
LOG = Logwrapper()
check_msvc_install = False
msvc_path = ""
if os.name == 'nt' and os.environ.get("cc_path", "")=="":
from pathlib import Path
msvc_path = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
cc_path = os.path.join(msvc_path, "VC", r"_\_\_\_\_\bin", "cl.exe")
check_msvc_install = True
else:
cc_path = env_or_find('cc_path', 'g++', silent=True)
os.environ["cc_path"] = cc_path
cc_type = get_cc_type(cc_path)
cache_path = find_cache_path()
_py3_config_path = None
_py3_include_path = None
_py3_extension_suffix = None
if os.name == 'nt':
from pathlib import Path
try:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
except:
pass
if check_msvc_install:
if not os.path.isfile(cc_path):
from jittor_utils import install_msvc
install_msvc.install(msvc_path)
mpath = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
if cc_path.startswith(mpath):
msvc_path = mpath
os.RTLD_NOW = os.RTLD_GLOBAL = os.RTLD_DEEPBIND = 0
path = os.path.dirname(cc_path).replace('/', '\\')
if path:
sys.path.insert(0, path)
os.environ["PATH"] = path+';'+os.environ["PATH"]
if hasattr(os, "add_dll_directory"):
os.add_dll_directory(path)