mirror of https://github.com/Jittor/Jittor
polish conda lib conflict
This commit is contained in:
parent
20da1fe7ac
commit
19ec9d0a4e
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.3'
|
||||
__version__ = '1.3.5.4'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -199,7 +199,8 @@ def setup_cuda_extern():
|
|||
"jtcuda" not in cp:
|
||||
LOG.w(f"CUDA related path found in LD_LIBRARY_PATH or PATH({check_ld_path}), "
|
||||
"This path may cause jittor found the wrong libs, "
|
||||
"please unset LD_LIBRARY_PATH and remove cuda lib path in Path. ")
|
||||
"please unset LD_LIBRARY_PATH and remove cuda lib path in Path. \n"
|
||||
"Or you can let jittor install cuda for you: `python3.x -m jittor_utils.install_cuda`")
|
||||
LOG.vv("setup cuda extern...")
|
||||
cache_path_cuda = os.path.join(cache_path, "cuda")
|
||||
cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc")
|
||||
|
|
|
@ -44,6 +44,41 @@ def has_installation():
|
|||
jtcuda_path = os.path.join(jit_utils.home(), ".cache", "jittor", "jtcuda")
|
||||
return os.path.isdir(jtcuda_path)
|
||||
|
||||
def check_cuda_env():
|
||||
if not has_installation():
|
||||
return
|
||||
if os.name == "nt":
|
||||
return
|
||||
def fix_env(key):
|
||||
env = os.environ.get(key, "")
|
||||
env = env.replace(";",":").split(":")
|
||||
new_env = []
|
||||
changed = False
|
||||
for cp in env:
|
||||
x = cp.lower()
|
||||
if "cuda" in x and "jtcuda" not in x:
|
||||
changed = True
|
||||
continue
|
||||
if "jtcuda" in x:
|
||||
new_env.insert(0, x)
|
||||
else:
|
||||
new_env.append(x)
|
||||
os.environ[key] = ":".join(new_env)
|
||||
return changed
|
||||
changed = fix_env("PATH") \
|
||||
or fix_env("LD_LIBRARY_PATH") \
|
||||
or fix_env("CUDA_HOME")
|
||||
if changed:
|
||||
try:
|
||||
with open("/proc/self/cmdline", "r") as f:
|
||||
argv = f.read().split("\x00")
|
||||
if len(argv[-1]) == 0: del argv[-1]
|
||||
LOG.i(f"restart {sys.executable} {argv[1:]}")
|
||||
os.execl(sys.executable, sys.executable, *argv[1:])
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def install_cuda():
|
||||
if "nvcc_path" in os.environ and os.environ["nvcc_path"] == "":
|
||||
return None
|
||||
|
@ -94,6 +129,7 @@ def install_cuda():
|
|||
sys.path.append(nvcc_lib_path)
|
||||
new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
|
||||
os.environ["LD_LIBRARY_PATH"] = new_ld_path
|
||||
check_cuda_env()
|
||||
|
||||
if os.path.isfile(nvcc_path):
|
||||
return nvcc_path
|
||||
|
|
Loading…
Reference in New Issue