better cache key

This commit is contained in:
Dun Liang 2021-10-13 19:16:46 +08:00
parent 5683d338e8
commit 102ffa31a5
6 changed files with 122 additions and 20 deletions

View File

@ -26,8 +26,7 @@ with lock.lock_scope():
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
if core.get_device_count() == 0:
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
if has_cuda:
from .compile_extern import cudnn, curand, cublas
from .compile_extern import cudnn, curand, cublas
from .init_cupy import numpy2cupy
import contextlib

View File

@ -397,7 +397,8 @@ def install_nccl(root_folder):
if os.path.isdir(dirname):
shutil.rmtree(dirname)
if not os.path.isfile(os.path.join(dirname, "build", "lib", "libnccl.so")):
LOG.i("Downloading nccl...")
if not os.path.isfile(os.path.join(root_folder, filename)):
LOG.i("Downloading nccl...")
download_url_to_local(url, filename, root_folder, true_md5)
if core.get_device_count() == 0:
@ -547,6 +548,7 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
except:
pass
cudnn = cublas = curand = None
setup_mpi()
in_mpi = inside_mpi()
rank = mpi.world_rank() if in_mpi else 0

View File

@ -724,7 +724,7 @@ def compile_custom_ops(
return gen_src.replace(anchor_str, anchor_str+insert_str, 1)
for name in pyjt_includes:
LOG.i("handle pyjt_include", name)
LOG.v("handle pyjt_include ", name)
bname = os.path.basename(name).split(".")[0]
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc")
pyjt_compiler.compile_single(name, gen_src_fname)
@ -870,7 +870,7 @@ def check_cache_compile():
files = [ x.replace('/', '\\') for x in files ]
global jit_utils_core_files
jit_utils_core_files = files
recompile = compile(cc_path, cc_flags+f" {opt_flags} ", files, 'jit_utils_core'+extension_suffix, True)
recompile = compile(cc_path, cc_flags+f" {opt_flags} ", files, jit_utils.cache_path+'/jit_utils_core'+extension_suffix, True)
if recompile and jit_utils.cc:
LOG.e("jit_utils updated, please restart jittor.")
sys.exit(0)
@ -879,7 +879,7 @@ def check_cache_compile():
jit_utils.try_import_jit_utils_core()
assert jit_utils.cc
# recompile, generate cache key
compile(cc_path, cc_flags+f" {opt_flags} ", files, 'jit_utils_core'+extension_suffix, True)
compile(cc_path, cc_flags+f" {opt_flags} ", files, jit_utils.cache_path+'/jit_utils_core'+extension_suffix, True)
def env_or_try_find(name, bname):
if name in os.environ:
@ -975,6 +975,24 @@ gdb_path = env_or_try_find('gdb_path', 'gdb')
addr2line_path = try_find_exe('addr2line')
has_pybt = check_pybt(gdb_path, python_path)
if nvcc_path:
# gen cuda key for cache_path
cu = "cu"
v = jit_utils.get_version(nvcc_path)[1:-1]
cu += v
try:
r, s = sp.getstatusoutput(f"{sys.executable} -m jittor_utils.query_cuda_cc")
if r==0:
s = sorted(list(set(s.strip().split())))
cu += "_sm_" + "_".join(s)
if "cuda_arch" not in os.environ:
os.environ["cuda_arch"] = " ".join(cu)
except:
pass
LOG.i("cuda key:", cu)
cache_path = os.path.join(cache_path, cu)
sys.path.append(cache_path)
def check_clang_latest_supported_cpu():
output = run_cmd('clang --print-supported-cpus')
@ -1007,10 +1025,10 @@ if platform.system() == 'Darwin':
opt_flags = ""
py_include = jit_utils.get_py3_include_path()
LOG.i(f"py_include: {py_include}")
LOG.v(f"py_include: {py_include}")
extension_suffix = jit_utils.get_py3_extension_suffix()
lib_suffix = extension_suffix.rsplit(".", 1)[0]
LOG.i(f"extension_suffix: {extension_suffix}")
LOG.v(f"extension_suffix: {extension_suffix}")
so = ".so" if os.name != 'nt' else ".dll"
@ -1191,7 +1209,7 @@ jit_src = gen_jit_op_maker(op_headers)
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
f.write(jit_src)
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" '
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" -L\"{jit_utils.cache_path}\" '
# gen pyjt
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)

View File

@ -237,12 +237,71 @@ def download(url, filename):
urllib.request.urlretrieve(url, filename)
LOG.v("Download finished")
def get_jittor_version():
path = os.path.dirname(__file__)
with open(os.path.join(path, "../jittor/__init__.py"), "r", encoding='utf8') as fh:
for line in fh:
if line.startswith('__version__'):
version = line.split("'")[1]
break
else:
raise RuntimeError("Unable to find version string.")
return version
def get_str_hash(s):
import hashlib
md5 = hashlib.md5()
md5.update(s.encode())
return md5.hexdigest()
def get_cpu_version():
v = platform.processor()
try:
if os.name == 'nt':
import winreg
key_name = r"Hardware\Description\System\CentralProcessor\0"
field_name = "ProcessorNameString"
key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_name)
value = winreg.QueryValueEx(key, field_name)[0]
winreg.CloseKey(key)
v = value
else:
with open("/proc/cpuinfo", 'r') as f:
for l in f:
if l.startswith("model name"):
v = l.split(':')[-1].strip()
break
except:
pass
return v
def short(s):
ss = ""
for c in s:
if str.isidentifier(c) or str.isnumeric(c) \
or str.isalpha(c) or c in '.-+':
ss += c
if len(ss)>14:
return ss[:14]+'x'+get_str_hash(ss)[:2]
return ss
def find_cache_path():
from pathlib import Path
path = str(Path.home())
dirs = [".cache", "jittor", os.path.basename(cc_path)]
if os.environ.get("debug")=="1":
dirs[-1] += "_debug"
# jittor version key
jtv = "jt"+get_jittor_version().rsplit('.', 1)[0]
# cc version key
ccv = cc_type+get_version(cc_path)[1:-1]
# os version key
osv = platform.platform() + platform.node()
if len(osv)>14:
osv = osv[:14] + 'x'+get_str_hash(osv)[:2]
# py version
pyv = "py"+platform.python_version()
# cpu version
cpuv = get_cpu_version()
dirs = [".cache", "jittor", jtv, ccv, pyv, osv, cpuv]
dirs = list(map(short, dirs))
cache_name = "default"
try:
if "cache_name" in os.environ:
@ -260,18 +319,14 @@ def find_cache_path():
for c in " (){}": cache_name = cache_name.replace(c, "_")
except:
pass
if os.environ.get("debug")=="1":
dirs[-1] += "_debug"
for name in os.path.normpath(cache_name).split(os.path.sep):
dirs.insert(-1, name)
os.environ["cache_name"] = cache_name
LOG.v("cache_name: ", cache_name)
for d in dirs:
path = os.path.join(path, d)
if not os.path.isdir(path):
try:
os.mkdir(path)
except:
pass
assert os.path.isdir(path)
path = os.path.join(path, *dirs)
os.makedirs(path, exist_ok=True)
if path not in sys.path:
sys.path.append(path)
return path

View File

@ -6,6 +6,7 @@
# ***************************************************************
import os, sys, shutil
from pathlib import Path
import glob
cache_path = os.path.join(str(Path.home()), ".cache", "jittor")
@ -29,6 +30,8 @@ def clean_cuda():
rmtree(cache_path+"/cutt")
rmtree(cache_path+"/cub")
rmtree(cache_path+"/nccl")
fs = glob.glob(cache_path+"/jt*")
for f in fs: rmtree(f)
def clean_dataset():
rmtree(cache_path+"/dataset")

View File

@ -0,0 +1,25 @@
import ctypes
import os
if "CUDA_VISIBLE_DEVICES" in os.environ:
del os.environ["CUDA_VISIBLE_DEVICES"]
if os.name == 'nt':
cuda_driver = ctypes.CDLL("nvcuda")
else:
cuda_driver = ctypes.CDLL("libcuda.so")
driver_version = ctypes.c_int()
r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
assert r == 0
v = driver_version.value
dcount = ctypes.c_int()
cuda_driver.cuInit(0)
r = cuda_driver.cuDeviceGetCount(ctypes.byref(dcount))
for i in range(dcount.value):
dev = ctypes.c_void_p()
major = ctypes.c_int()
minor = ctypes.c_int()
assert 0 == cuda_driver.cuDeviceGet(ctypes.byref(dev), i)
assert 0 == cuda_driver.cuDeviceGetAttribute(ctypes.byref(major), 75, dev)
assert 0 == cuda_driver.cuDeviceGetAttribute(ctypes.byref(minor), 76, dev)
print(major.value*10+minor.value)