mirror of https://github.com/Jittor/Jittor
fix unsqueeze
This commit is contained in:
parent
0a5bd61bf4
commit
8022cd9c3d
|
@ -273,6 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace
|
||||||
|
|
||||||
def unsqueeze(x, dim):
|
def unsqueeze(x, dim):
|
||||||
shape = list(x.shape)
|
shape = list(x.shape)
|
||||||
|
if dim < 0: dim += len(shape)
|
||||||
assert dim <= len(shape)
|
assert dim <= len(shape)
|
||||||
return x.reshape(shape[:dim] + [1] + shape[dim:])
|
return x.reshape(shape[:dim] + [1] + shape[dim:])
|
||||||
Var.unsqueeze = unsqueeze
|
Var.unsqueeze = unsqueeze
|
||||||
|
|
Loading…
Reference in New Issue