mirror of https://github.com/Jittor/Jittor
better cache key
This commit is contained in:
parent
5683d338e8
commit
102ffa31a5
|
@ -26,8 +26,7 @@ with lock.lock_scope():
|
|||
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
|
||||
if core.get_device_count() == 0:
|
||||
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
|
||||
if has_cuda:
|
||||
from .compile_extern import cudnn, curand, cublas
|
||||
from .compile_extern import cudnn, curand, cublas
|
||||
from .init_cupy import numpy2cupy
|
||||
|
||||
import contextlib
|
||||
|
|
|
@ -397,7 +397,8 @@ def install_nccl(root_folder):
|
|||
if os.path.isdir(dirname):
|
||||
shutil.rmtree(dirname)
|
||||
if not os.path.isfile(os.path.join(dirname, "build", "lib", "libnccl.so")):
|
||||
LOG.i("Downloading nccl...")
|
||||
if not os.path.isfile(os.path.join(root_folder, filename)):
|
||||
LOG.i("Downloading nccl...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
|
||||
if core.get_device_count() == 0:
|
||||
|
@ -547,6 +548,7 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
|||
except:
|
||||
pass
|
||||
|
||||
cudnn = cublas = curand = None
|
||||
setup_mpi()
|
||||
in_mpi = inside_mpi()
|
||||
rank = mpi.world_rank() if in_mpi else 0
|
||||
|
|
|
@ -724,7 +724,7 @@ def compile_custom_ops(
|
|||
return gen_src.replace(anchor_str, anchor_str+insert_str, 1)
|
||||
|
||||
for name in pyjt_includes:
|
||||
LOG.i("handle pyjt_include", name)
|
||||
LOG.v("handle pyjt_include ", name)
|
||||
bname = os.path.basename(name).split(".")[0]
|
||||
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc")
|
||||
pyjt_compiler.compile_single(name, gen_src_fname)
|
||||
|
@ -870,7 +870,7 @@ def check_cache_compile():
|
|||
files = [ x.replace('/', '\\') for x in files ]
|
||||
global jit_utils_core_files
|
||||
jit_utils_core_files = files
|
||||
recompile = compile(cc_path, cc_flags+f" {opt_flags} ", files, 'jit_utils_core'+extension_suffix, True)
|
||||
recompile = compile(cc_path, cc_flags+f" {opt_flags} ", files, jit_utils.cache_path+'/jit_utils_core'+extension_suffix, True)
|
||||
if recompile and jit_utils.cc:
|
||||
LOG.e("jit_utils updated, please restart jittor.")
|
||||
sys.exit(0)
|
||||
|
@ -879,7 +879,7 @@ def check_cache_compile():
|
|||
jit_utils.try_import_jit_utils_core()
|
||||
assert jit_utils.cc
|
||||
# recompile, generate cache key
|
||||
compile(cc_path, cc_flags+f" {opt_flags} ", files, 'jit_utils_core'+extension_suffix, True)
|
||||
compile(cc_path, cc_flags+f" {opt_flags} ", files, jit_utils.cache_path+'/jit_utils_core'+extension_suffix, True)
|
||||
|
||||
def env_or_try_find(name, bname):
|
||||
if name in os.environ:
|
||||
|
@ -975,6 +975,24 @@ gdb_path = env_or_try_find('gdb_path', 'gdb')
|
|||
addr2line_path = try_find_exe('addr2line')
|
||||
has_pybt = check_pybt(gdb_path, python_path)
|
||||
|
||||
if nvcc_path:
    # Generate a CUDA-specific key and append it to cache_path, so that
    # caches built for different nvcc versions / GPU architectures do not
    # share a directory.
    cu = "cu"
    # get_version() output has its first and last character stripped here
    # (presumably surrounding brackets -- confirm against jit_utils).
    v = jit_utils.get_version(nvcc_path)[1:-1]
    cu += v
    try:
        # Ask a helper process for the compute capability of every device;
        # it prints one number per device on stdout.
        r, s = sp.getstatusoutput(f"{sys.executable} -m jittor_utils.query_cuda_cc")
        if r == 0:
            # De-duplicate and sort so the key is stable across runs.
            s = sorted(set(s.strip().split()))
            cu += "_sm_" + "_".join(s)
            if "cuda_arch" not in os.environ:
                # Export the detected architectures, space separated.
                # (The previous code joined the characters of ``cu``
                # instead of the architecture list ``s``.)
                os.environ["cuda_arch"] = " ".join(s)
    except Exception:
        # Best effort: without a working query the key only carries the
        # nvcc version.
        pass
    LOG.i("cuda key:", cu)
    cache_path = os.path.join(cache_path, cu)
    sys.path.append(cache_path)
|
||||
|
||||
|
||||
def check_clang_latest_supported_cpu():
|
||||
output = run_cmd('clang --print-supported-cpus')
|
||||
|
@ -1007,10 +1025,10 @@ if platform.system() == 'Darwin':
|
|||
opt_flags = ""
|
||||
|
||||
py_include = jit_utils.get_py3_include_path()
|
||||
LOG.i(f"py_include: {py_include}")
|
||||
LOG.v(f"py_include: {py_include}")
|
||||
extension_suffix = jit_utils.get_py3_extension_suffix()
|
||||
lib_suffix = extension_suffix.rsplit(".", 1)[0]
|
||||
LOG.i(f"extension_suffix: {extension_suffix}")
|
||||
LOG.v(f"extension_suffix: {extension_suffix}")
|
||||
so = ".so" if os.name != 'nt' else ".dll"
|
||||
|
||||
|
||||
|
@ -1191,7 +1209,7 @@ jit_src = gen_jit_op_maker(op_headers)
|
|||
LOG.vvvv(jit_src)
|
||||
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
|
||||
f.write(jit_src)
|
||||
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" '
|
||||
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" -L\"{jit_utils.cache_path}\" '
|
||||
# gen pyjt
|
||||
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
|
||||
|
||||
|
|
|
@ -237,12 +237,71 @@ def download(url, filename):
|
|||
urllib.request.urlretrieve(url, filename)
|
||||
LOG.v("Download finished")
|
||||
|
||||
def get_jittor_version():
    """Read the jittor version string from ``../jittor/__init__.py``.

    The file is located relative to this module's directory. The first line
    starting with ``__version__`` is used, and the text between its first
    pair of single quotes is returned.

    Raises:
        RuntimeError: if no line starts with ``__version__``.
    """
    here = os.path.dirname(__file__)
    init_file = os.path.join(here, "../jittor/__init__.py")
    with open(init_file, "r", encoding='utf8') as fh:
        for text in fh:
            if not text.startswith('__version__'):
                continue
            # e.g. __version__ = '1.2.3' -> 1.2.3
            return text.split("'")[1]
    raise RuntimeError("Unable to find version string.")
|
||||
|
||||
def get_str_hash(s):
    """Return the hexadecimal MD5 digest of the string ``s``.

    ``s`` is encoded with ``str.encode()`` defaults before hashing.
    """
    import hashlib
    return hashlib.md5(s.encode()).hexdigest()
|
||||
|
||||
def get_cpu_version():
    """Return a best-effort, human-readable name of the host CPU.

    Starts from ``platform.processor()`` and tries to replace it with a more
    specific model name:

    * on Windows (``os.name == 'nt'``) the ``ProcessorNameString`` value of
      the first CPU's registry key;
    * otherwise the first ``model name`` entry of ``/proc/cpuinfo``.

    Any failure to read those sources (missing file or key, permission or
    decoding problems) is ignored and the ``platform.processor()`` value is
    returned instead.

    Returns:
        str: the CPU description.
    """
    v = platform.processor()
    try:
        if os.name == 'nt':
            import winreg
            key_name = r"Hardware\Description\System\CentralProcessor\0"
            field_name = "ProcessorNameString"
            # The handle is a context manager: it is closed even when the
            # value lookup raises.
            with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_name) as key:
                v = winreg.QueryValueEx(key, field_name)[0]
        else:
            with open("/proc/cpuinfo", 'r') as f:
                for line in f:
                    if line.startswith("model name"):
                        v = line.split(':')[-1].strip()
                        break
    except (OSError, ValueError):
        # Best effort only: keep the platform.processor() fallback.
        # ValueError covers UnicodeDecodeError while reading the file.
        pass
    return v
|
||||
|
||||
def short(s):
    """Reduce ``s`` to a short token for use as a cache directory name.

    Only characters that are identifier characters, numeric, alphabetic or
    one of ``.``, ``-``, ``+`` are kept. If more than 14 characters remain,
    the result is the first 14 of them, followed by ``'x'`` and the first
    two hex digits of ``get_str_hash`` of the full filtered string.
    """
    kept = "".join(
        ch for ch in s
        if ch in '.-+' or ch.isidentifier() or ch.isnumeric() or ch.isalpha()
    )
    if len(kept) <= 14:
        return kept
    # Too long: truncate and add a short hash to keep distinct inputs apart.
    return kept[:14] + 'x' + get_str_hash(kept)[:2]
|
||||
|
||||
def find_cache_path():
|
||||
from pathlib import Path
|
||||
path = str(Path.home())
|
||||
dirs = [".cache", "jittor", os.path.basename(cc_path)]
|
||||
if os.environ.get("debug")=="1":
|
||||
dirs[-1] += "_debug"
|
||||
# jittor version key
|
||||
jtv = "jt"+get_jittor_version().rsplit('.', 1)[0]
|
||||
# cc version key
|
||||
ccv = cc_type+get_version(cc_path)[1:-1]
|
||||
# os version key
|
||||
osv = platform.platform() + platform.node()
|
||||
if len(osv)>14:
|
||||
osv = osv[:14] + 'x'+get_str_hash(osv)[:2]
|
||||
# py version
|
||||
pyv = "py"+platform.python_version()
|
||||
# cpu version
|
||||
cpuv = get_cpu_version()
|
||||
dirs = [".cache", "jittor", jtv, ccv, pyv, osv, cpuv]
|
||||
dirs = list(map(short, dirs))
|
||||
cache_name = "default"
|
||||
try:
|
||||
if "cache_name" in os.environ:
|
||||
|
@ -260,18 +319,14 @@ def find_cache_path():
|
|||
for c in " (){}": cache_name = cache_name.replace(c, "_")
|
||||
except:
|
||||
pass
|
||||
if os.environ.get("debug")=="1":
|
||||
dirs[-1] += "_debug"
|
||||
for name in os.path.normpath(cache_name).split(os.path.sep):
|
||||
dirs.insert(-1, name)
|
||||
os.environ["cache_name"] = cache_name
|
||||
LOG.v("cache_name: ", cache_name)
|
||||
for d in dirs:
|
||||
path = os.path.join(path, d)
|
||||
if not os.path.isdir(path):
|
||||
try:
|
||||
os.mkdir(path)
|
||||
except:
|
||||
pass
|
||||
assert os.path.isdir(path)
|
||||
path = os.path.join(path, *dirs)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
if path not in sys.path:
|
||||
sys.path.append(path)
|
||||
return path
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
# ***************************************************************
|
||||
import os, sys, shutil
|
||||
from pathlib import Path
|
||||
import glob
|
||||
|
||||
cache_path = os.path.join(str(Path.home()), ".cache", "jittor")
|
||||
|
||||
|
@ -29,6 +30,8 @@ def clean_cuda():
|
|||
rmtree(cache_path+"/cutt")
|
||||
rmtree(cache_path+"/cub")
|
||||
rmtree(cache_path+"/nccl")
|
||||
fs = glob.glob(cache_path+"/jt*")
|
||||
for f in fs: rmtree(f)
|
||||
|
||||
def clean_dataset():
|
||||
rmtree(cache_path+"/dataset")
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
"""Print the compute capability of every CUDA device, one per line.

Each line is ``major * 10 + minor`` (for example ``75`` for 7.5). The
values are queried through the CUDA driver library via ctypes. The script
exits with an error if the driver library cannot be loaded or a per-device
query fails.
"""
import ctypes
import os

# CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR / _MINOR from the driver API.
_ATTR_CC_MAJOR = 75
_ATTR_CC_MINOR = 76

# Drop any device mask before loading the driver (presumably so that every
# installed device is enumerated, not only the visible ones).
os.environ.pop("CUDA_VISIBLE_DEVICES", None)

if os.name == 'nt':
    cuda_driver = ctypes.CDLL("nvcuda")
else:
    cuda_driver = ctypes.CDLL("libcuda.so")


def _check(status, what):
    """Raise RuntimeError when a driver call returns a non-zero status."""
    if status != 0:
        raise RuntimeError(f"{what} failed with CUDA error code {status}")


# The driver calls are made outside of ``assert`` statements: asserts are
# removed under ``python -O`` / PYTHONOPTIMIZE, which would silently skip the
# calls themselves and print 0 for every device.
driver_version = ctypes.c_int()
_check(cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version)),
       "cuDriverGetVersion")

dcount = ctypes.c_int()
# The results of cuInit and cuDeviceGetCount are deliberately not checked
# (as before): if they fail, dcount stays 0 and nothing is printed.
cuda_driver.cuInit(0)
cuda_driver.cuDeviceGetCount(ctypes.byref(dcount))

for i in range(dcount.value):
    dev = ctypes.c_int()  # CUdevice is a plain int in the driver API
    major = ctypes.c_int()
    minor = ctypes.c_int()
    _check(cuda_driver.cuDeviceGet(ctypes.byref(dev), i), "cuDeviceGet")
    _check(cuda_driver.cuDeviceGetAttribute(
        ctypes.byref(major), _ATTR_CC_MAJOR, dev), "cuDeviceGetAttribute")
    _check(cuda_driver.cuDeviceGetAttribute(
        ctypes.byref(minor), _ATTR_CC_MINOR, dev), "cuDeviceGetAttribute")
    print(major.value*10+minor.value)
|
Loading…
Reference in New Issue