mirror of https://github.com/Jittor/Jittor
add epillsis (numpy version) support.
This commit is contained in:
parent
6339f62f56
commit
5eaccf538d
|
@ -488,6 +488,24 @@ def einsum(string, *args):
|
||||||
shps = np_cpu.concatenate([in_.shape for in_ in inputs])
|
shps = np_cpu.concatenate([in_.shape for in_ in inputs])
|
||||||
p = einsum_expr.replace(" ", "").split(',')
|
p = einsum_expr.replace(" ", "").split(',')
|
||||||
s = p[:-1] + p[-1].split('->')
|
s = p[:-1] + p[-1].split('->')
|
||||||
|
rec_shape = []
|
||||||
|
ellip_expr = None
|
||||||
|
const_rep = '1234567890' # assume tensor shape no more than 10 dimensions
|
||||||
|
for idx, expr in enumerate(s[:-1]):
|
||||||
|
if "..." in expr:
|
||||||
|
assert "..." in s[-1]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
shp = inputs[idx].shape
|
||||||
|
ellipsis_pos = len(expr.replace("...", ""))
|
||||||
|
nellip_expr = const_rep[0 : len(shp) - ellipsis_pos]
|
||||||
|
if ellip_expr is None:
|
||||||
|
ellip_expr = nellip_expr
|
||||||
|
else:
|
||||||
|
assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis."
|
||||||
|
s[idx] = expr.replace("...", ellip_expr)
|
||||||
|
if ellip_expr:
|
||||||
|
s[-1] = s[-1].replace("...", ellip_expr)
|
||||||
if s[-1]=='':
|
if s[-1]=='':
|
||||||
return ()
|
return ()
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue