mirror of https://github.com/Jittor/Jittor
polish initialize order
This commit is contained in:
parent
40cb853e21
commit
d85af13024
|
@ -1111,6 +1111,7 @@ at_beginning = [
|
|||
"src/event_queue.cc",
|
||||
"src/mem/allocator/sfrl_allocator.cc",
|
||||
"src/mem/allocator.cc",
|
||||
"src/misc/nano_string.cc",
|
||||
]
|
||||
at_last = [
|
||||
"src/profiler/profiler.cc",
|
||||
|
|
|
@ -138,6 +138,7 @@ class TestCudnnConvOp(unittest.TestCase):
|
|||
y = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
|
||||
masky = jt.rand_like(y)
|
||||
dx, dw = jt.grad(masky*y, [x, w])
|
||||
jt.sync_all()
|
||||
|
||||
y2 = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
|
||||
dx2, dw2 = jt.grad(masky*y2, [x, w])
|
||||
|
@ -159,19 +160,23 @@ class TestCudnnConvOp(unittest.TestCase):
|
|||
with jt.flag_scope(use_cuda=1):
|
||||
x = jt.random(xshape)
|
||||
w = jt.random(wshape)
|
||||
jt.sync_all()
|
||||
|
||||
y2 = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)
|
||||
jt.sync_all()
|
||||
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
# y = jt.cudnn.ops.cudnn_conv3d_backward_x(w, x, *y2.shape[2:], *stride, *padding, *dilation, group)
|
||||
y = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)
|
||||
masky = jt.rand_like(y)
|
||||
dx, dw = jt.grad(masky*y, [x, w])
|
||||
jt.sync_all()
|
||||
|
||||
dx2, dw2 = jt.grad(masky*y2, [x, w])
|
||||
np.testing.assert_allclose(y.data, y2.data, rtol=1e-6, atol=1e-4)
|
||||
np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-6, atol=1e-4)
|
||||
np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
|
||||
jt.sync_all()
|
||||
np.testing.assert_allclose(y.numpy(), y2.numpy(), rtol=1e-6, atol=1e-4)
|
||||
np.testing.assert_allclose(dx.numpy(), dx2.numpy(), rtol=1e-6, atol=1e-4)
|
||||
np.testing.assert_allclose(dw.numpy(), dw2.numpy(), rtol=1e-5, atol=1e-3)
|
||||
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
|
||||
|
|
|
@ -8,7 +8,7 @@ def install(path):
|
|||
LOG.i("Installing MSVC...")
|
||||
filename = "msvc.zip"
|
||||
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
|
||||
md5sum = "13d420e5919e5ec81155fe923b3d1a07"
|
||||
md5sum = "0fd71436c034808649b24baf28998ccc"
|
||||
download_url_to_local(url, filename, path, md5sum)
|
||||
fullname = os.path.join(path, filename)
|
||||
import zipfile
|
||||
|
|
Loading…
Reference in New Issue