diff --git a/python/jittor/test/test_models.py b/python/jittor/test/test_models.py index cc42dfc3..55d61aa8 100644 --- a/python/jittor/test/test_models.py +++ b/python/jittor/test/test_models.py @@ -61,8 +61,11 @@ class test_models(unittest.TestCase): @unittest.skipIf(not jt.has_cuda, "Cuda not found") @jt.flag_scope(use_cuda=1) - @torch.no_grad() def test_models(self): + with torch.no_grad(): + self.run_models() + + def run_models(self): def to_cuda(x): if jt.has_cuda: return x.cuda()