fix unsqueeze

This commit is contained in:
zhouwy19 2020-07-06 22:15:42 +08:00
parent 0a5bd61bf4
commit 8022cd9c3d
1 changed files with 1 additions and 0 deletions

View File

@ -273,6 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace
def unsqueeze(x, dim): def unsqueeze(x, dim):
shape = list(x.shape) shape = list(x.shape)
if dim < 0: dim += len(shape)
assert dim <= len(shape) assert dim <= len(shape)
return x.reshape(shape[:dim] + [1] + shape[dim:]) return x.reshape(shape[:dim] + [1] + shape[dim:])
Var.unsqueeze = unsqueeze Var.unsqueeze = unsqueeze