mirror of https://github.com/Jittor/Jittor
fix unsqueeze
This commit is contained in:
parent
0a5bd61bf4
commit
8022cd9c3d
|
@ -273,6 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace
|
|||
|
||||
def unsqueeze(x, dim):
|
||||
shape = list(x.shape)
|
||||
if dim < 0: dim += len(shape)
|
||||
assert dim <= len(shape)
|
||||
return x.reshape(shape[:dim] + [1] + shape[dim:])
|
||||
Var.unsqueeze = unsqueeze
|
||||
|
|
Loading…
Reference in New Issue