mirror of https://github.com/Jittor/Jittor
improve cuda restart issue
This commit is contained in:
parent
c12549020f
commit
042c3610a3
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.18'
|
||||
__version__ = '1.3.5.19'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -170,6 +170,9 @@ string process_acl(const string& src, const string& name, const map<string,strin
|
|||
}
|
||||
})");
|
||||
}
|
||||
if (name == "profiler.cc") {
|
||||
new_src = token_replace_all(new_src, ".cc", ".tikcc");
|
||||
}
|
||||
return new_src;
|
||||
}
|
||||
|
||||
|
|
|
@ -59,11 +59,24 @@ class TestACL(unittest.TestCase):
|
|||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_conv(self):
|
||||
x = jt.rand(10, 3, 50, 50)
|
||||
w = jt.rand(4,3,3,3)
|
||||
# x = jt.rand(10, 3, 50, 50)
|
||||
# w = jt.rand(4,3,3,3)
|
||||
x = jt.rand(2, 2, 1, 1)
|
||||
w = jt.rand(2,2,1,1)
|
||||
y = jt.nn.conv2d(x, w)
|
||||
y.sync(True)
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_matmul(self):
|
||||
# x = jt.rand(10, 3, 50, 50)
|
||||
# w = jt.rand(4,3,3,3)
|
||||
x = jt.rand(10,10)
|
||||
w = jt.rand(10,10)
|
||||
y = jt.matmul(x, w)
|
||||
ny = np.matmul(x.numpy(), w.numpy())
|
||||
np.testing.assert_allclose(y.numpy(), ny)
|
||||
# y.sync(True)
|
||||
|
||||
|
||||
|
||||
def matmul(a, b):
|
||||
|
|
|
@ -66,15 +66,18 @@ def check_cuda_env():
|
|||
os.environ[key] = ":".join(new_env)
|
||||
return changed
|
||||
changed = fix_env("PATH") \
|
||||
or fix_env("LD_LIBRARY_PATH") \
|
||||
or fix_env("CUDA_HOME")
|
||||
+ fix_env("LD_LIBRARY_PATH") \
|
||||
+ fix_env("CUDA_HOME")
|
||||
if changed:
|
||||
try:
|
||||
with open("/proc/self/cmdline", "r") as f:
|
||||
argv = f.read().split("\x00")
|
||||
if len(argv[-1]) == 0: del argv[-1]
|
||||
LOG.i(f"restart {sys.executable} {argv[1:]}")
|
||||
os.execl(sys.executable, sys.executable, *argv[1:])
|
||||
with open("/proc/self/maps", "r") as f:
|
||||
cudart_loaded = "cudart" in f.read().lower()
|
||||
if cudart_loaded:
|
||||
with open("/proc/self/cmdline", "r") as f:
|
||||
argv = f.read().split("\x00")
|
||||
if len(argv[-1]) == 0: del argv[-1]
|
||||
LOG.i(f"restart {sys.executable} {argv[1:]}")
|
||||
os.execl(sys.executable, sys.executable, *argv[1:])
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue