mirror of https://github.com/Jittor/Jittor
fix model test without pth
This commit is contained in:
parent
8ffad0230d
commit
51253ac14a
|
@ -61,8 +61,11 @@ class test_models(unittest.TestCase):
|
||||||
|
|
||||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||||
@jt.flag_scope(use_cuda=1)
|
@jt.flag_scope(use_cuda=1)
|
||||||
@torch.no_grad()
|
|
||||||
def test_models(self):
|
def test_models(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
self.run_models()
|
||||||
|
|
||||||
|
def run_models(self):
|
||||||
def to_cuda(x):
|
def to_cuda(x):
|
||||||
if jt.has_cuda:
|
if jt.has_cuda:
|
||||||
return x.cuda()
|
return x.cuda()
|
||||||
|
|
Loading…
Reference in New Issue