add repeat_interleave

This commit is contained in:
li-xl 2020-12-15 15:17:51 +08:00 committed by Dun Liang
parent fded962b54
commit 75e60d74b4
1 changed files with 21 additions and 0 deletions

View File

@ -96,6 +96,27 @@ def repeat(x, *shape):
jt.Var.repeat = repeat
def repeat_interleave(x,repeats,dim=None):
# TODO repeats is jt.Var
assert isinstance(repeats,int)
if dim == None:
x = x.reshape(-1)
dim=0
if dim<0: dim+=x.ndim
tar_shape = list(x.shape)
x_shape = list(x.shape)
tar_shape[dim] = tar_shape[dim]*repeats
dims = []
for i in range(len(tar_shape)):
if dim==i:
dims.append(f"i{i}/{repeats}")
else:
dims.append(f"i{i}")
return x.reindex(tar_shape,dims)
jt.Var.repeat_interleave = repeat_interleave
def chunk(x, chunks, dim=0):
r'''
Splits a var into a specific number of chunks. Each chunk is a view of the input var.