fix unsqueeze

This commit is contained in:
zhouwy19 2020-07-06 22:25:03 +08:00
parent 8022cd9c3d
commit 0f0b08e30d
1 changed files with 1 additions and 1 deletions

View File

@ -273,7 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace
def unsqueeze(x, dim):
shape = list(x.shape)
if dim < 0: dim += len(shape)
if dim < 0: dim += len(shape) + 1
assert dim <= len(shape)
return x.reshape(shape[:dim] + [1] + shape[dim:])
Var.unsqueeze = unsqueeze