mirror of https://github.com/Jittor/Jittor
fix model test without pth
This commit is contained in:
parent
8ffad0230d
commit
51253ac14a
|
@ -61,8 +61,11 @@ class test_models(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
@torch.no_grad()
|
||||
def test_models(self):
|
||||
with torch.no_grad():
|
||||
self.run_models()
|
||||
|
||||
def run_models(self):
|
||||
def to_cuda(x):
|
||||
if jt.has_cuda:
|
||||
return x.cuda()
|
||||
|
|
Loading…
Reference in New Issue