JittorMirror/python/jittor_utils/__init__.py

239 lines
7.3 KiB
Python

# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from multiprocessing import Pool
import subprocess as sp
import os
import re
import sys
import inspect
import datetime
import contextlib
import threading
import time
from ctypes import cdll
class LogWarper:
    """Logging front end that prefers the compiled core and falls back to print.

    Every message is tagged with the ``file:line`` of the code that called
    one of the public helpers (``i``, ``w``, ``e``, ``f``, ``v`` ...).  When
    the module-level ``cc`` object is set and exposes ``log``, the message is
    handed to it; otherwise the message is printed from Python, subject to
    the ``log_silent`` and ``log_v`` settings read from the environment.
    """

    def __init__(self):
        # Both settings come from environment variables; unset means "0".
        self.log_silent = int(os.environ.get("log_silent", "0"))
        self.log_v = int(os.environ.get("log_v", "0"))

    def log_capture_start(self):
        """Forward to ``cc.log_capture_start()``."""
        cc.log_capture_start()

    def log_capture_stop(self):
        """Forward to ``cc.log_capture_stop()``."""
        cc.log_capture_stop()

    def log_capture_read(self):
        """Return whatever ``cc.log_capture_read()`` returns."""
        return cc.log_capture_read()

    def _log(self, level, verbose, *msg):
        """Emit one log record.

        level: single-character severity tag placed in the record header.
        verbose: verbosity of the record; in the Python fallback the record
            is dropped when it exceeds ``self.log_v``.
        msg: message parts, converted with ``str`` and joined by spaces.
        """
        # Joining an empty tuple yields "", so a call without arguments logs
        # an empty message instead of the text "()".
        msg = " ".join(str(m) for m in msg)
        # Two frames up: skip _log itself and the one-line public wrapper.
        caller = inspect.currentframe().f_back.f_back
        info = inspect.getframeinfo(caller)
        fileline = f"{os.path.basename(info.filename)}:{info.lineno}"
        if cc and hasattr(cc, "log"):
            cc.log(fileline, level, verbose, msg)
            return
        if self.log_silent or verbose > self.log_v:
            return
        # Named "stamp" so the imported time module is not shadowed.
        stamp = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
        tid = threading.get_ident()%100
        v = f" v{verbose}" if verbose else ""
        print(f"[{level} {stamp} {tid:02}{v} {fileline}] {msg}")

    # Info-level helpers with increasing verbosity.
    def V(self, verbose, *msg): self._log('i', verbose, *msg)
    def v(self, *msg): self._log('i', 1, *msg)
    def vv(self, *msg): self._log('i', 10, *msg)
    def vvv(self, *msg): self._log('i', 100, *msg)
    def vvvv(self, *msg): self._log('i', 1000, *msg)
    # Verbosity-0 helpers, one per level tag.
    def i(self, *msg): self._log('i', 0, *msg)
    def w(self, *msg): self._log('w', 0, *msg)
    def e(self, *msg): self._log('e', 0, *msg)
    def f(self, *msg): self._log('f', 0, *msg)
def in_ipynb():
    """Return True when the IPython config contains an ``IPKernelApp`` entry.

    ``get_ipython`` only exists when running under IPython, so any failure
    while looking it up or reading its config means "not in a notebook".
    """
    try:
        return 'IPKernelApp' in get_ipython().config
    except Exception:
        # Typically NameError (no IPython).  Exception rather than a bare
        # except so that KeyboardInterrupt/SystemExit still propagate.
        return False
@contextlib.contextmanager
def simple_timer(name):
    """Context manager that logs the wall-clock time spent in its body.

    name: label included in both the start and the stop log lines.

    The stop line is logged even when the body raises; the exception then
    propagates unchanged.
    """
    LOG.i("Timer start", name)
    now = time.time()
    try:
        yield
    finally:
        LOG.i("Time stop", name, time.time()-now)
@contextlib.contextmanager
def import_scope(flags):
    """Temporarily set the interpreter's dlopen flags to ``flags``.

    flags: value passed to ``sys.setdlopenflags`` for the duration of the
        ``with`` body.

    The previous flags are restored on exit, including when the body raises,
    so a failed import cannot leave the altered flags in place.
    """
    prev = sys.getdlopenflags()
    sys.setdlopenflags(flags)
    try:
        yield
    finally:
        sys.setdlopenflags(prev)
def try_import_jit_utils_core(silent=None):
    """Best-effort import of ``jit_utils_core`` into the module global ``cc``.

    silent: when not None, the ``log_silent`` environment variable is set to
        ``str(int(silent))`` for the duration of the import attempt and put
        back to its previous value (``"0"`` if it was unset) afterwards.

    Does nothing if ``cc`` is already set.  A failed import is swallowed on
    purpose, leaving ``cc`` unchanged.
    """
    global cc
    if cc: return
    if not (silent is None):
        prev = os.environ.get("log_silent", "0")
        os.environ["log_silent"] = str(int(silent))
    try:
        try:
            # if is in notebook, must log sync, and we redirect the log
            if is_in_ipynb: os.environ["log_sync"] = "1"
            import jit_utils_core as cc
            if is_in_ipynb:
                global redirector
                redirector = cc.ostream_redirect(stdout=True, stderr=True)
                redirector.__enter__()
        except Exception as _:
            # Deliberate best effort: callers check ``cc`` afterwards.
            pass
    finally:
        # Restore in a finally block so the override cannot outlive this call
        # when the import is interrupted by a non-Exception error such as
        # KeyboardInterrupt.
        if not (silent is None):
            os.environ["log_silent"] = prev
def run_cmd(cmd, cwd=None, err_msg=None, print_error=True):
    """Run ``cmd`` through the shell and return its combined output.

    cmd: command line string, executed with ``shell=True``.
    cwd: working directory for the command; a falsy value means "inherit".
    err_msg: message for the raised exception; defaults to
        ``"Run cmd failed: <cmd>"``.
    print_error: when True the captured output is written to stderr on
        failure; when False it is appended to the default error message.

    Returns the command's stdout and stderr merged into one string, with a
    single trailing newline removed.
    Raises Exception when the command exits with a non-zero status.
    """
    LOG.v(f"Run cmd: {cmd}")
    # ``cwd or None`` keeps the original rule that any falsy cwd is ignored.
    r = sp.run(cmd, cwd=cwd or None, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT)
    # Tools may emit bytes that are not valid UTF-8 (for example localized
    # compiler diagnostics); replace them instead of raising
    # UnicodeDecodeError and hiding the real result.
    s = r.stdout.decode('utf8', errors='replace')
    if r.returncode != 0:
        if print_error:
            sys.stderr.write(s)
        if err_msg is None:
            err_msg = f"Run cmd failed: {cmd}"
            if not print_error:
                err_msg += "\n"+s
        raise Exception(err_msg)
    if s.endswith('\n'): s = s[:-1]
    return s
def do_compile(args):
    """Compile one command; used as the worker function of the process pool.

    args: sequence ``(cmd, cache_path, jittor_path)``.

    Returns the result of ``cc.cache_compile`` when the core module could be
    imported, otherwise runs the command directly and returns True.
    """
    cmd, cache_path, jittor_path = args
    try_import_jit_utils_core(True)
    if not cc:
        # Core not available: fall back to running the command as-is.
        run_cmd(cmd)
        return True
    return cc.cache_compile(cmd, cache_path, jittor_path)
# Number of worker processes used by run_cmds; 0 means "not computed yet".
pool_size = 0


def run_cmds(cmds, cache_path, jittor_path):
    """Run every command in ``cmds`` through ``do_compile`` in a process pool.

    cmds: iterable of command strings.
    cache_path, jittor_path: passed through unchanged to each ``do_compile``
        call.

    On the first call the pool size is derived from physical memory: one
    process per 3 GiB, at least 1 and at most 8.
    """
    global pool_size
    if pool_size == 0:
        page_bytes = os.sysconf('SC_PAGE_SIZE')
        page_count = os.sysconf('SC_PHYS_PAGES')
        mem_gib = (page_bytes * page_count)/(1024.**3)
        wanted = max(int(mem_gib // 3), 1)
        pool_size = min(8, wanted)
        LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
    jobs = []
    for cmd in cmds:
        jobs.append([cmd, cache_path, jittor_path])
    with Pool(pool_size) as p:
        p.map(do_compile, jobs)
def download(url, filename):
    """Download ``url`` to ``filename`` unless a usable copy already exists.

    url: address to fetch.
    filename: destination path.

    An existing file larger than 100 bytes is kept as-is; a smaller one is
    fetched again and overwritten.
    """
    # This file already requires Python 3 (f-strings), so the standard
    # library module is used directly instead of going through six.moves.
    import urllib.request
    if os.path.isfile(filename) and os.path.getsize(filename) > 100:
        return
    LOG.v("Downloading", url)
    urllib.request.urlretrieve(url, filename)
    LOG.v("Download finished")
def find_cache_path():
    """Create the cache directory if needed and return its path.

    The directory is ``~/.cache/jittor/<cache_name parts>/<compiler name>``,
    where the compiler name is the basename of the module-level ``cc_path``
    (with ``_debug`` appended when the ``debug`` environment variable is
    ``"1"``).  ``cache_name`` comes from the ``cache_name`` environment
    variable when set, otherwise from the current git branch of this file's
    directory, otherwise ``"default"``.

    Side effects: sets ``os.environ["cache_name"]``, creates missing
    directories, and appends the resulting path to ``sys.path``.

    Raises AssertionError when a directory on the path cannot be created.
    """
    from pathlib import Path
    path = str(Path.home())
    dirs = [".cache", "jittor", os.path.basename(cc_path)]
    if os.environ.get("debug")=="1":
        dirs[-1] += "_debug"
    cache_name = "default"
    if "cache_name" in os.environ:
        cache_name = os.environ["cache_name"]
    else:
        # try to get branch name from git
        try:
            r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
                stderr=sp.PIPE)
            if r.returncode == 0:
                for b in r.stdout.decode().splitlines():
                    # The current branch is the line marked with "* ".  With
                    # no such line the name stays "default" rather than
                    # picking up an unrelated branch.
                    if b.startswith("* "):
                        cache_name = b[2:]
                        for c in " (){}": cache_name = cache_name.replace(c, "_")
                        break
        except Exception:
            # Best effort: git missing, not a repository, undecodable output.
            pass
    # A branch name such as "a/b" becomes nested directories a/b, inserted
    # just before the compiler-name component.
    for name in cache_name.split("/"):
        dirs.insert(-1, name)
    os.environ["cache_name"] = cache_name
    LOG.v("cache_name", cache_name)
    for d in dirs:
        path = os.path.join(path, d)
        if not os.path.isdir(path):
            try:
                os.mkdir(path)
            except OSError:
                # Another process may have created it first; checked below.
                pass
            # Explicit raise instead of ``assert`` so the check survives -O;
            # the exception type is the one the assert used to raise.
            if not os.path.isdir(path):
                raise AssertionError(f"Can not create cache directory: {path}")
    if path not in sys.path:
        sys.path.append(path)
    return path
def get_version(output):
    """Return the version reported by an executable, wrapped in parentheses.

    output: path or name of the executable.  Names ending in ``mpicc`` are
        queried with ``--showme:version``, everything else with ``--version``.

    The last ``x.y.z`` number in the output is used; when there is none, the
    last ``x.y`` number.  Fails an assertion when neither is found.
    """
    flag = " --showme:version" if output.endswith("mpicc") else " --version"
    version = run_cmd(output+flag)
    v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
    if not v:
        # No three-part version present; accept a two-part one.
        v = re.findall("[0-9]+\\.[0-9]+", version)
    assert len(v) != 0, f"Can not find version number from: {version}"
    return "("+v[-1]+")"
def find_exe(name, check_version=True):
    """Locate ``name`` with ``which`` and return the reported path.

    name: executable name to look up.
    check_version: when True the version is queried via ``get_version`` and
        included in the log line.

    Raises the exception from ``run_cmd`` (message ``"<name> not found"``)
    when ``which`` fails.
    """
    output = run_cmd(f'which {name}', err_msg=f'{name} not found')
    version = get_version(name) if check_version else ""
    LOG.i(f"Found {name}{version} at {output}.")
    return output
def env_or_find(name, bname):
    """Return an executable path from the environment, else search for it.

    name: environment variable that may hold the path.
    bname: executable name passed to ``find_exe`` when the variable is unset
        or empty.
    """
    path = os.environ.get(name, "")
    if path == "":
        # Unset and empty are treated the same: look the executable up.
        return find_exe(bname)
    version = get_version(path)
    LOG.i(f"Found {bname}{version} at {path}")
    return path
def get_cc_type(cc_path):
    """Classify a compiler by the basename of its path.

    Returns ``"clang"``, ``"icc"`` or ``"g++"``.  For any other name a
    fatal-level message is logged through ``LOG.f`` and nothing is returned
    explicitly.
    """
    bname = os.path.basename(cc_path)
    # Order matters: "clang++" also contains "g++", so clang is tested first.
    families = (
        ("clang", ("clang",)),
        ("icc", ("icc", "icpc")),
        ("g++", ("g++",)),
    )
    for kind, markers in families:
        for marker in markers:
            if marker in bname:
                return kind
    LOG.f(f"Unknown cc type: {bname}")
# Module initialisation. The order matters: ``cc`` and ``LOG`` must exist
# before env_or_find / get_cc_type / find_cache_path run, because those log
# through LOG, whose _log method reads ``cc``.

# Evaluated once at import time; read by try_import_jit_utils_core.
is_in_ipynb = in_ipynb()
# Handle to the jit_utils_core module; stays None until
# try_import_jit_utils_core imports it successfully.
cc = None
LOG = LogWarper()
# Compiler path: taken from the ``cc_path`` environment variable when set and
# non-empty, otherwise found by looking up ``g++``.
cc_path = env_or_find('cc_path', 'g++')
# Written back to the environment so the resolved value is visible there.
os.environ["cc_path"] = cc_path
cc_type = get_cc_type(cc_path)
# Also creates the directory and appends it to sys.path (see find_cache_path).
cache_path = find_cache_path()