mirror of https://github.com/Jittor/Jittor
fix code op
This commit is contained in:
parent
18fc4baa3d
commit
19f827d83a
|
@ -723,7 +723,10 @@ def check_cache_compile():
|
|||
|
||||
def env_or_try_find(name, bname):
|
||||
if name in os.environ:
|
||||
return os.environ[name]
|
||||
path = os.environ[name]
|
||||
version = jit_utils.get_version(path)
|
||||
LOG.i(f"Found {bname}{version} at {path}")
|
||||
return path
|
||||
return try_find_exe(bname)
|
||||
|
||||
def try_find_exe(*args):
|
||||
|
|
|
@ -172,15 +172,19 @@ def find_cache_path():
|
|||
sys.path.append(path)
|
||||
return path
|
||||
|
||||
def get_version(output):
|
||||
version = run_cmd(output+" --version")
|
||||
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
||||
if len(v) == 0:
|
||||
v = re.findall("[0-9]+\\.[0-9]+", version)
|
||||
assert len(v) != 0, f"Can not find version number from: {version}"
|
||||
version = "("+v[-1]+")"
|
||||
return version
|
||||
|
||||
def find_exe(name, check_version=True):
|
||||
output = run_cmd(f'which {name}', err_msg=f'{name} not found')
|
||||
if check_version:
|
||||
version = run_cmd(output+" --version")
|
||||
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
||||
if len(v) == 0:
|
||||
v = re.findall("[0-9]+\\.[0-9]+", version)
|
||||
assert len(v) != 0, f"Can not find version number from: {version}"
|
||||
version = "("+v[-1]+")"
|
||||
version = get_version(name)
|
||||
else:
|
||||
version = ""
|
||||
LOG.i(f"Found {name}{version} at {output}.")
|
||||
|
@ -188,7 +192,10 @@ def find_exe(name, check_version=True):
|
|||
|
||||
def env_or_find(name, bname):
|
||||
if name in os.environ:
|
||||
return os.environ[name]
|
||||
path = os.environ[name]
|
||||
version = get_version(path)
|
||||
LOG.i(f"Found {bname}{version} at {path}")
|
||||
return path
|
||||
return find_exe(bname)
|
||||
|
||||
def get_cc_type(cc_path):
|
||||
|
|
|
@ -61,7 +61,7 @@ void CodeOp::jit_prepare() {
|
|||
add_jit_define("INDIM", JK::hex1(i), JK::hex1(in[i]->shape.size()));
|
||||
add_jit_define("Tin", JK::hex1(i), in[i]->dtype());
|
||||
}
|
||||
if (use_cuda) {
|
||||
if (flags.get(NodeFlags::_cuda)) {
|
||||
jk << JK::key << "HEADER" << JK::val << cuda_header;
|
||||
ASSERT(cuda_src.size());
|
||||
jk << "\nnamespace jittor {\n";
|
||||
|
|
Loading…
Reference in New Issue