From 8022cd9c3d12c583f8bdd33d082ccfc6ac817810 Mon Sep 17 00:00:00 2001 From: zhouwy19 Date: Mon, 6 Jul 2020 22:15:42 +0800 Subject: [PATCH] fix unsqueeze --- python/jittor/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 157895bd..8127b138 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -273,6 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace def unsqueeze(x, dim): shape = list(x.shape) + if dim < 0: dim += len(shape) assert dim <= len(shape) return x.reshape(shape[:dim] + [1] + shape[dim:]) Var.unsqueeze = unsqueeze