Merge pull request #300 from Jittor/jittor_home

add jittor_home
This commit is contained in:
Xiang-Li Li 2022-03-22 21:12:39 +08:00 committed by GitHub
commit 1987728950
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 66 additions and 36 deletions

View File

@ -9,6 +9,7 @@ import platform
from .compiler import *
from jittor_utils import run_cmd, get_version, get_int_version
from jittor_utils.misc import download_url_to_local
import jittor_utils as jit_utils
def search_file(dirs, name, prefer_version=()):
if os.name == 'nt':
@ -110,8 +111,7 @@ def setup_mkl():
LOG.v("setup mkl...")
# mkl_path = os.path.join(cache_path, "mkl")
# mkl_path decouple with cc_path
from pathlib import Path
mkl_path = os.path.join(str(Path.home()), ".cache", "jittor", "mkl")
mkl_path = os.path.join(jit_utils.home(), ".cache", "jittor", "mkl")
make_cache_dir(mkl_path)
install_mkl(mkl_path)
@ -178,8 +178,7 @@ def install_cub(root_folder):
def setup_cub():
global cub_home
cub_home = ""
from pathlib import Path
cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
cub_path = os.path.join(jit_utils.home(), ".cache", "jittor", "cub")
cuda_version = int(get_version(nvcc_path)[1:-1].split('.')[0])
extra_flags = ""
if cuda_version < 11:
@ -376,8 +375,7 @@ def setup_cutt():
if cutt_lib_path is None or cutt_include_path is None:
LOG.v("setup cutt...")
# cutt_path decouple with cc_path
from pathlib import Path
cutt_path = os.path.join(str(Path.home()), ".cache", "jittor", "cutt")
cutt_path = os.path.join(jit_utils.home(), ".cache", "jittor", "cutt")
make_cache_dir(cutt_path)
install_cutt(cutt_path)
@ -453,8 +451,7 @@ def setup_nccl():
if nccl_lib_path is None or nccl_include_path is None:
LOG.v("setup nccl...")
# nccl_path decouple with cc_path
from pathlib import Path
nccl_path = os.path.join(str(Path.home()), ".cache", "jittor", "nccl")
nccl_path = os.path.join(jit_utils.home(), ".cache", "jittor", "nccl")
make_cache_dir(nccl_path)
nccl_home = install_nccl(nccl_path)

View File

@ -19,7 +19,7 @@ from ctypes import cdll
from ctypes.util import find_library
import jittor_utils as jit_utils
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
from jittor_utils import LOG, run_cmd, find_exe, cc_path, cc_type, cache_path
from . import pyjt_compiler
from jittor_utils import lock
from jittor_utils import install_cuda

View File

@ -21,8 +21,9 @@ import signal
from jittor_utils import LOG
import jittor as jt
import time
import jittor_utils as jit_utils
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
dataset_root = os.path.join(jit_utils.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
mpi = jt.mpi
img_open_hook = HookTimer(Image, "open")

View File

@ -1,7 +1,7 @@
#!python3
import os, json
from pathlib import Path
notebook_dir = os.path.join(str(Path.home()), ".cache","jittor","notebook")
import jittor_utils as jit_utils
notebook_dir = os.path.join(jit_utils.home(), ".cache","jittor","notebook")
if not os.path.isdir(notebook_dir):
os.mkdir(notebook_dir)
dirname = os.path.dirname(__file__)

View File

@ -4,8 +4,8 @@ suffix = ""
import jittor as jt
import time
from pathlib import Path
home_path = str(Path.home())
import jittor_utils as jit_utils
home_path = jit_utils.home()
perf_path = os.path.join(home_path, ".cache", "jittor_perf")
def main():

View File

@ -10,12 +10,12 @@
import unittest
import os, sys
import jittor as jt
from pathlib import Path
import jittor_utils as jit_utils
class TestLock(unittest.TestCase):
def test(self):
if os.environ.get('lock_full_test', '0') == '1':
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
cache_path = os.path.join(jit_utils.home(), ".cache", "jittor", "lock")
assert os.system(f"rm -rf {cache_path}") == 0
cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example"
else:

View File

@ -8,10 +8,10 @@ import unittest, os
import jittor as jt
from jittor import LOG
import sys
from pathlib import Path
import jittor_utils as jit_utils
dirname = os.path.join(jt.flags.jittor_path, "notebook")
notebook_dir = os.path.join(str(Path.home()), ".cache","jittor","notebook")
notebook_dir = os.path.join(jit_utils.home(), ".cache","jittor","notebook")
tests = []
for mdname in os.listdir(dirname):
if not mdname.endswith(".src.md"): continue

View File

@ -15,8 +15,8 @@ def find_jittor_path():
return path[:-len(suffix)] + ".."
def find_cache_path():
from pathlib import Path
path = str(Path.home())
import jittor_utils as jit_utils
path = jit_utils.home()
dirs = [".cache", "jittor"]
for d in dirs:
path = os.path.join(path, d)

View File

@ -49,8 +49,8 @@ data_files = [ name for name in files
LOG.i("data_files", data_files)
# compile data files
from pathlib import Path
home = str(Path.home())
import jittor_utils as jit_utils
home = jit_utils.home()
# for cc_type in ["g++", "clang"]:
# for device in ["cpu", "cuda"]:

View File

@ -7,8 +7,8 @@
# ***************************************************************
import jittor as jt
import os
from pathlib import Path
home_path = str(Path.home())
import jittor_utils as jit_utils
home_path = jit_utils.home()
def run_cmd(cmd):
print("RUN CMD:", cmd)

View File

@ -24,6 +24,41 @@ import ctypes
if platform.system() == 'Darwin':
mp.set_start_method('fork')
from pathlib import Path
import json
_jittor_home = None
def home():
global _jittor_home
if _jittor_home is not None:
return _jittor_home
src_path = os.path.join(str(Path.home()),".cache","jittor")
os.makedirs(src_path,exist_ok=True)
src_path_file = os.path.join(src_path,"config.json")
data = {}
if os.path.exists(src_path_file):
with open(src_path_file,"r") as f:
data = json.load(f)
default_path = data.get("JITTOR_HOME",str(Path.home()))
_home_path = os.environ.get("JITTOR_HOME",default_path)
if not os.path.exists(_home_path):
_home_path = default_path
_home_path = os.path.abspath(_home_path)
# LOG.i(f"Use {_home_path} as Jittor Home")
with open(src_path_file,"w") as f:
data['JITTOR_HOME'] = _home_path
json.dump(data,f)
_jittor_home = _home_path
return _home_path
class Logwrapper:
def __init__(self):
self.log_silent = int(os.environ.get("log_silent", "0"))
@ -295,8 +330,7 @@ def short(s):
return ss
def find_cache_path():
from pathlib import Path
path = str(Path.home())
path = home()
# jittor version key
jtv = "jt"+get_jittor_version().rsplit('.', 1)[0]
# cc version key
@ -508,8 +542,7 @@ LOG = Logwrapper()
check_msvc_install = False
msvc_path = ""
if os.name == 'nt' and os.environ.get("cc_path", "")=="":
from pathlib import Path
msvc_path = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
msvc_path = os.path.join(home(), ".cache", "jittor", "msvc")
cc_path = os.path.join(msvc_path, "VC", r"_\_\_\_\_\bin", "cl.exe")
check_msvc_install = True
else:
@ -523,7 +556,6 @@ _py3_include_path = None
_py3_extension_suffix = None
if os.name == 'nt':
from pathlib import Path
try:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
@ -533,7 +565,7 @@ if os.name == 'nt':
if not os.path.isfile(cc_path):
from jittor_utils import install_msvc
install_msvc.install(msvc_path)
mpath = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
mpath = os.path.join(home(), ".cache", "jittor", "msvc")
if cc_path.startswith(mpath):
msvc_path = mpath
os.RTLD_NOW = os.RTLD_GLOBAL = os.RTLD_DEEPBIND = 0

View File

@ -1,9 +1,9 @@
import os
from pathlib import Path
from collections import defaultdict
import pickle
import numpy as np
import jittor_utils
import jittor_utils as jit_utils
from jittor_utils import LOG
import sys
@ -96,7 +96,7 @@ class Hook:
hook_rand()
self.rid = 0
self.base_name = base_name
self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
self.base_path = os.path.join(jit_utils.home(), ".cache", "jittor", "auto_diff", base_name)
if not os.path.exists(self.base_path):
os.makedirs(self.base_path, exist_ok=True)
self.mode = 'save'

View File

@ -5,10 +5,10 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os, sys, shutil
from pathlib import Path
import glob
import jittor_utils as jit_utils
cache_path = os.path.join(str(Path.home()), ".cache", "jittor")
cache_path = os.path.join(jit_utils.home(), ".cache", "jittor")
def callback(func, path, exc_info):
print(f"remove \"{path}\" failed.")

View File

@ -41,7 +41,7 @@ def get_cuda_driver():
return None
def has_installation():
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
jtcuda_path = os.path.join(jit_utils.home(), ".cache", "jittor", "jtcuda")
return os.path.isdir(jtcuda_path)
def install_cuda():
@ -85,7 +85,7 @@ def install_cuda():
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
else:
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.0")
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
jtcuda_path = os.path.join(jit_utils.home(), ".cache", "jittor", "jtcuda")
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
if os.name=='nt': nvcc_path += '.exe'
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")