mirror of https://github.com/Jittor/Jittor
squeeze -1 dim
This commit is contained in:
parent
1a0d7e3810
commit
49950eda35
|
@ -404,7 +404,8 @@ Var.unsqueeze = unsqueeze
|
|||
|
||||
def squeeze(x, dim):
|
||||
shape = list(x.shape)
|
||||
assert dim < len(shape)
|
||||
if dim < 0: dim += len(shape)
|
||||
assert dim < len(shape) and dim >= 0
|
||||
assert shape[dim] == 1
|
||||
return x.reshape(shape[:dim] + shape[dim+1:])
|
||||
Var.squeeze = squeeze
|
||||
|
|
Loading…
Reference in New Issue