diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 157895bd..8127b138 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -273,6 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace def unsqueeze(x, dim): shape = list(x.shape) + if dim < 0: dim += len(shape) assert dim <= len(shape) return x.reshape(shape[:dim] + [1] + shape[dim:]) Var.unsqueeze = unsqueeze