mirror of https://github.com/Jittor/Jittor
polish test
This commit is contained in:
parent
b00bb4b39f
commit
899bc4d9e8
|
@ -7,7 +7,7 @@
|
||||||
# This file is subject to the terms and conditions defined in
|
# This file is subject to the terms and conditions defined in
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
__version__ = '1.2.2.0'
|
__version__ = '1.2.2.1'
|
||||||
from . import lock
|
from . import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -308,7 +308,7 @@ class DepthwiseConv(Function):
|
||||||
|
|
||||||
grid = dim3(ksize_width, ksize_height, output_channels);
|
grid = dim3(ksize_width, ksize_height, output_channels);
|
||||||
threads = dim3(std::min(output_width, block_size), crop_output_height, 1);
|
threads = dim3(std::min(output_width, block_size), crop_output_height, 1);
|
||||||
|
cudaMemsetAsync(filter_grad_p, 0, filter_grad->size);
|
||||||
|
|
||||||
KernelDepthwiseConvFilterGrad<
|
KernelDepthwiseConvFilterGrad<
|
||||||
input_type><<<grid, threads, 0>>>(
|
input_type><<<grid, threads, 0>>>(
|
||||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
||||||
RUN pip3 uninstall jittor -y
|
RUN pip3 uninstall jittor -y
|
||||||
|
|
||||||
COPY . jittor
|
COPY . jittor
|
||||||
RUN python3.7 -m pip install jittor
|
RUN python3.7 -m pip install ./jittor
|
||||||
RUN python3.7 -m jittor.test.test_core
|
RUN python3.7 -m jittor.test.test_core
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
||||||
RUN pip3 uninstall jittor -y
|
RUN pip3 uninstall jittor -y
|
||||||
|
|
||||||
COPY . jittor
|
COPY . jittor
|
||||||
RUN python3.7 -m pip install jittor
|
RUN python3.7 -m pip install ./jittor
|
||||||
RUN python3.7 -m jittor.test.test_core
|
RUN python3.7 -m jittor.test.test_core
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
||||||
RUN pip3 uninstall jittor -y
|
RUN pip3 uninstall jittor -y
|
||||||
|
|
||||||
COPY . jittor
|
COPY . jittor
|
||||||
RUN python3.7 -m pip install jittor
|
RUN python3.7 -m pip install ./jittor
|
||||||
RUN python3.7 -m jittor.test.test_core
|
RUN python3.7 -m jittor.test.test_core
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
||||||
RUN pip3 uninstall jittor -y
|
RUN pip3 uninstall jittor -y
|
||||||
|
|
||||||
COPY . jittor
|
COPY . jittor
|
||||||
RUN python3.7 -m pip install jittor
|
RUN python3.7 -m pip install ./jittor
|
||||||
RUN python3.7 -m jittor.test.test_core
|
RUN python3.7 -m jittor.test.test_core
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ RUN pip3 install jittor --timeout 100 && python3 -m jittor.test.test_example
|
||||||
RUN pip3 uninstall jittor -y
|
RUN pip3 uninstall jittor -y
|
||||||
|
|
||||||
COPY . jittor
|
COPY . jittor
|
||||||
RUN python3 -m pip install jittor
|
RUN python3 -m pip install ./jittor
|
||||||
RUN python3 -m jittor.test.test_core
|
RUN python3 -m jittor.test.test_core
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
||||||
RUN pip3 uninstall jittor -y
|
RUN pip3 uninstall jittor -y
|
||||||
|
|
||||||
COPY . jittor
|
COPY . jittor
|
||||||
RUN python3.7 -m pip install jittor
|
RUN python3.7 -m pip install ./jittor
|
||||||
RUN python3.7 -m jittor.test.test_core
|
RUN python3.7 -m jittor.test.test_core
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,8 @@ import numpy as np
|
||||||
import jittor.models as jtmodels
|
import jittor.models as jtmodels
|
||||||
|
|
||||||
def load_parameters(m1, m2):
|
def load_parameters(m1, m2):
|
||||||
m1.save('temp.pk')
|
m1.save('/tmp/temp.pk')
|
||||||
m2.load('temp.pk')
|
m2.load('/tmp/temp.pk')
|
||||||
|
|
||||||
def compare_parameters(m1, m2):
|
def compare_parameters(m1, m2):
|
||||||
ps1 = m1.parameters()
|
ps1 = m1.parameters()
|
||||||
|
@ -23,7 +23,7 @@ def compare_parameters(m1, m2):
|
||||||
y = ps2[i].data + 1e-8
|
y = ps2[i].data + 1e-8
|
||||||
relative_error = abs(x - y) / abs(y)
|
relative_error = abs(x - y) / abs(y)
|
||||||
diff = relative_error.mean()
|
diff = relative_error.mean()
|
||||||
assert diff < 1e-4, (diff, 'backward', ps2[i].name())
|
assert diff < 1e-4, (diff, 'backward', ps2[i].name(), ps1[i].mean(), ps1[i].std(), ps2[i].mean(), ps2[i].std())
|
||||||
|
|
||||||
class TestDepthwiseConv(unittest.TestCase):
|
class TestDepthwiseConv(unittest.TestCase):
|
||||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||||
|
@ -62,15 +62,13 @@ class TestDepthwiseConv(unittest.TestCase):
|
||||||
jt_optimizer2 = jt.nn.SGD(jittor_model2.parameters(), lr = lr)
|
jt_optimizer2 = jt.nn.SGD(jittor_model2.parameters(), lr = lr)
|
||||||
|
|
||||||
jittor_result = jittor_model(jittor_test_img)
|
jittor_result = jittor_model(jittor_test_img)
|
||||||
loss = jittor_result.sum()
|
mask = jt.random(jittor_result.shape, jittor_result.dtype)
|
||||||
|
loss = jittor_result * mask
|
||||||
jt_optimizer.step(loss)
|
jt_optimizer.step(loss)
|
||||||
jt.sync_all(True)
|
jt.sync_all(True)
|
||||||
|
|
||||||
jittor_result2 = jittor_model2(jittor_test_img)
|
jittor_result2 = jittor_model2(jittor_test_img)
|
||||||
loss = jittor_result2.sum()
|
loss = jittor_result2 * mask
|
||||||
jt_optimizer2.step(loss)
|
|
||||||
jt.sync_all(True)
|
|
||||||
compare_parameters(jittor_model, jittor_model2)
|
|
||||||
|
|
||||||
x = jittor_result2.data + 1e-8
|
x = jittor_result2.data + 1e-8
|
||||||
y = jittor_result.data + 1e-8
|
y = jittor_result.data + 1e-8
|
||||||
|
@ -78,6 +76,11 @@ class TestDepthwiseConv(unittest.TestCase):
|
||||||
diff = relative_error.mean()
|
diff = relative_error.mean()
|
||||||
assert diff < 1e-4, (diff, 'forword')
|
assert diff < 1e-4, (diff, 'forword')
|
||||||
|
|
||||||
|
jt_optimizer2.step(loss)
|
||||||
|
jt.sync_all(True)
|
||||||
|
compare_parameters(jittor_model, jittor_model2)
|
||||||
|
|
||||||
|
|
||||||
jt.clean()
|
jt.clean()
|
||||||
jt.gc()
|
jt.gc()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue