mirror of https://github.com/Jittor/Jittor
add repeat_interleave
This commit is contained in:
parent
fded962b54
commit
75e60d74b4
|
@ -96,6 +96,27 @@ def repeat(x, *shape):
|
|||
|
||||
jt.Var.repeat = repeat
|
||||
|
||||
def repeat_interleave(x,repeats,dim=None):
|
||||
# TODO repeats is jt.Var
|
||||
assert isinstance(repeats,int)
|
||||
if dim == None:
|
||||
x = x.reshape(-1)
|
||||
dim=0
|
||||
if dim<0: dim+=x.ndim
|
||||
|
||||
tar_shape = list(x.shape)
|
||||
x_shape = list(x.shape)
|
||||
tar_shape[dim] = tar_shape[dim]*repeats
|
||||
dims = []
|
||||
for i in range(len(tar_shape)):
|
||||
if dim==i:
|
||||
dims.append(f"i{i}/{repeats}")
|
||||
else:
|
||||
dims.append(f"i{i}")
|
||||
return x.reindex(tar_shape,dims)
|
||||
|
||||
jt.Var.repeat_interleave = repeat_interleave
|
||||
|
||||
def chunk(x, chunks, dim=0):
|
||||
r'''
|
||||
Splits a var into a specific number of chunks. Each chunk is a view of the input var.
|
||||
|
|
Loading…
Reference in New Issue