mirror of https://github.com/Jittor/Jittor
polish cuda env setup
This commit is contained in:
parent
607d13079f
commit
e014f4f25c
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.38'
|
||||
__version__ = '1.3.5.39'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -143,9 +143,9 @@ class TestCudnnConvOp(unittest.TestCase):
|
|||
|
||||
y2 = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
|
||||
dx2, dw2 = jt.grad(masky*y2, [x, w])
|
||||
np.testing.assert_allclose(y.data, y2.data, rtol=1e-5, atol=1e-3)
|
||||
np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-5, atol=1e-3)
|
||||
np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
|
||||
np.testing.assert_allclose(y.data, y2.data, rtol=1e-3, atol=1e-3)
|
||||
np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-3, atol=1e-3)
|
||||
np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-3, atol=1e-3)
|
||||
|
||||
check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
|
||||
check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
|
||||
|
@ -177,9 +177,9 @@ class TestCudnnConvOp(unittest.TestCase):
|
|||
|
||||
dx2, dw2 = jt.grad(masky*y2, [x, w])
|
||||
jt.sync_all()
|
||||
np.testing.assert_allclose(y.numpy(), y2.numpy(), rtol=1e-6, atol=1e-4)
|
||||
np.testing.assert_allclose(dx.numpy(), dx2.numpy(), rtol=1e-6, atol=1e-4)
|
||||
np.testing.assert_allclose(dw.numpy(), dw2.numpy(), rtol=1e-5, atol=1e-3)
|
||||
np.testing.assert_allclose(y.numpy(), y2.numpy(), rtol=1e-3, atol=1e-4)
|
||||
np.testing.assert_allclose(dx.numpy(), dx2.numpy(), rtol=1e-3, atol=1e-4)
|
||||
np.testing.assert_allclose(dw.numpy(), dw2.numpy(), rtol=1e-3, atol=1e-3)
|
||||
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
|
||||
|
|
|
@ -70,9 +70,11 @@ def check_cuda_env():
|
|||
+ fix_env("CUDA_HOME")
|
||||
if changed:
|
||||
try:
|
||||
with open("/proc/self/maps", "r") as f:
|
||||
cudart_loaded = "cudart" in f.read().lower()
|
||||
if cudart_loaded:
|
||||
# LD_LIBRARY_PATH change must triggle restart
|
||||
# because dyloader already setup
|
||||
# with open("/proc/self/maps", "r") as f:
|
||||
# cudart_loaded = "cudart" in f.read().lower()
|
||||
# if cudart_loaded:
|
||||
with open("/proc/self/cmdline", "r") as f:
|
||||
argv = f.read().split("\x00")
|
||||
if len(argv[-1]) == 0: del argv[-1]
|
||||
|
|
Loading…
Reference in New Issue