Merge branch 'lxl' of https://github.com/jittor/jittor into lxl

This commit is contained in:
li-xl 2020-11-26 17:34:40 +08:00
commit 342214d7f7
1 changed files with 17 additions and 3 deletions

View File

@ -76,10 +76,24 @@ def repeat(x, *shape):
x = x.broadcast(x_shape)
elif len_x_shape > len_shape:
rep_shape = (len_x_shape - len_shape) * [1] + shape
reshape_shape = []
broadcast_shape = []
for x_s,r_s in zip(x_shape,rep_shape):
reshape_shape.append(1)
reshape_shape.append(x_s)
broadcast_shape.append(r_s)
broadcast_shape.append(1)
x = x.reshape(reshape_shape)
x = x.broadcast(broadcast_shape)
tar_shape = (np.array(x_shape) * np.array(rep_shape)).tolist()
dims = []
for i in range(len(tar_shape)): dims.append(f"i{i}%{x_shape[i]}")
return x.reindex(tar_shape, dims)
x = x.reshape(tar_shape)
return x
jt.Var.repeat = repeat
def chunk(x, chunks, dim=0):