mirror of https://github.com/Jittor/Jittor
fix unsqueeze
This commit is contained in:
parent
8022cd9c3d
commit
0f0b08e30d
|
@ -273,7 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace
|
|||
|
||||
def unsqueeze(x, dim):
|
||||
shape = list(x.shape)
|
||||
if dim < 0: dim += len(shape)
|
||||
if dim < 0: dim += len(shape) + 1
|
||||
assert dim <= len(shape)
|
||||
return x.reshape(shape[:dim] + [1] + shape[dim:])
|
||||
Var.unsqueeze = unsqueeze
|
||||
|
|
Loading…
Reference in New Issue