mirror of https://github.com/Jittor/Jittor
cutt for seperate cache
This commit is contained in:
parent
45481adb3b
commit
c05193408a
|
@ -309,7 +309,7 @@ def install_cutt(root_folder):
|
|||
if md5 != true_md5:
|
||||
os.remove(fullname)
|
||||
shutil.rmtree(dirname)
|
||||
if not os.path.isfile(os.path.join(dirname, "lib/libcutt"+so)):
|
||||
if not os.path.isfile(os.path.join(cache_path, "libcutt"+so)):
|
||||
LOG.i("Downloading cutt...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
|
||||
|
@ -337,8 +337,7 @@ def install_cutt(root_folder):
|
|||
continue
|
||||
files2.append(f)
|
||||
cutt_flags = cc_flags+opt_flags+cutt_include
|
||||
os.makedirs(dirname+"/lib", exist_ok=True)
|
||||
compile(cc_path, cutt_flags, files2, dirname+"/lib/libcutt"+so, cuda_flags=arch_flag)
|
||||
compile(cc_path, cutt_flags, files2, cache_path+"/libcutt"+so, cuda_flags=arch_flag)
|
||||
return dirname
|
||||
|
||||
def setup_cutt():
|
||||
|
@ -362,7 +361,7 @@ def setup_cutt():
|
|||
install_cutt(cutt_path)
|
||||
cutt_home = os.path.join(cutt_path, "cutt-1.2")
|
||||
cutt_include_path = os.path.join(cutt_home, "src")
|
||||
cutt_lib_path = os.path.join(cutt_home, "lib")
|
||||
cutt_lib_path = cache_path
|
||||
|
||||
cutt_lib_name = os.path.join(cutt_lib_path, "libcutt"+so)
|
||||
assert os.path.isdir(cutt_include_path)
|
||||
|
|
|
@ -697,8 +697,8 @@ def compile_custom_ops(
|
|||
gen_name = "gen_ops_" + "_".join(headers.keys())
|
||||
if gen_name_ != "":
|
||||
gen_name = gen_name_
|
||||
if len(gen_name) > 100:
|
||||
gen_name = gen_name[:80] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest()
|
||||
if len(gen_name) > 50:
|
||||
gen_name = gen_name[:50] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest()[:6]
|
||||
|
||||
includes = sorted(list(set(includes)))
|
||||
includes = "".join(map(lambda x: f" -I\"{x}\" ", includes))
|
||||
|
|
|
@ -157,6 +157,12 @@ class TestBinaryOp(unittest.TestCase):
|
|||
c = a % b
|
||||
nc = a.data % b.data
|
||||
np.testing.assert_allclose(c.data, nc.data, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_pow(self):
|
||||
# win cuda 10.2 cannot pass
|
||||
a = jt.random((100,))
|
||||
b = a**3
|
||||
b.sync()
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -2,9 +2,14 @@ try:
|
|||
import fcntl
|
||||
except ImportError:
|
||||
fcntl = None
|
||||
import win32file
|
||||
import pywintypes
|
||||
_OVERLAPPED = pywintypes.OVERLAPPED()
|
||||
try:
|
||||
import win32file
|
||||
import pywintypes
|
||||
_OVERLAPPED = pywintypes.OVERLAPPED()
|
||||
except:
|
||||
LOG.f("""pywin32 package not found, please install it.
|
||||
If conda is used, please install with command:
|
||||
>>> conda install pywin32.""")
|
||||
|
||||
import os
|
||||
from jittor_utils import cache_path, LOG
|
||||
|
|
Loading…
Reference in New Issue