fix model test without pth

This commit is contained in:
Dun Liang 2020-06-07 12:24:46 +08:00
parent 8ffad0230d
commit 51253ac14a
1 changed files with 4 additions and 1 deletions

View File

@ -61,8 +61,11 @@ class test_models(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "Cuda not found") @unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1) @jt.flag_scope(use_cuda=1)
@torch.no_grad()
def test_models(self): def test_models(self):
with torch.no_grad():
self.run_models()
def run_models(self):
def to_cuda(x): def to_cuda(x):
if jt.has_cuda: if jt.has_cuda:
return x.cuda() return x.cuda()