squeeze -1 dim

This commit is contained in:
Dun Liang 2020-05-06 18:58:42 +08:00
parent 1a0d7e3810
commit 49950eda35
1 changed files with 2 additions and 1 deletions

View File

@ -404,7 +404,8 @@ Var.unsqueeze = unsqueeze
def squeeze(x, dim):
shape = list(x.shape)
assert dim < len(shape)
if dim < 0: dim += len(shape)
assert dim < len(shape) and dim >= 0
assert shape[dim] == 1
return x.reshape(shape[:dim] + shape[dim+1:])
Var.squeeze = squeeze