mirror of https://github.com/Jittor/Jittor
Merge branch 'lxl' of https://github.com/jittor/jittor into lxl
This commit is contained in:
commit
342214d7f7
|
@ -76,10 +76,24 @@ def repeat(x, *shape):
|
|||
x = x.broadcast(x_shape)
|
||||
elif len_x_shape > len_shape:
|
||||
rep_shape = (len_x_shape - len_shape) * [1] + shape
|
||||
|
||||
reshape_shape = []
|
||||
broadcast_shape = []
|
||||
for x_s,r_s in zip(x_shape,rep_shape):
|
||||
reshape_shape.append(1)
|
||||
reshape_shape.append(x_s)
|
||||
|
||||
broadcast_shape.append(r_s)
|
||||
broadcast_shape.append(1)
|
||||
|
||||
x = x.reshape(reshape_shape)
|
||||
x = x.broadcast(broadcast_shape)
|
||||
|
||||
tar_shape = (np.array(x_shape) * np.array(rep_shape)).tolist()
|
||||
dims = []
|
||||
for i in range(len(tar_shape)): dims.append(f"i{i}%{x_shape[i]}")
|
||||
return x.reindex(tar_shape, dims)
|
||||
|
||||
x = x.reshape(tar_shape)
|
||||
return x
|
||||
|
||||
jt.Var.repeat = repeat
|
||||
|
||||
def chunk(x, chunks, dim=0):
|
||||
|
|
Loading…
Reference in New Issue