mirror of https://github.com/Jittor/Jittor
fix einsum space interpret error.
This commit is contained in:
parent
b6c9421a9a
commit
6339f62f56
|
@ -486,7 +486,7 @@ def einsum(string, *args):
|
||||||
|
|
||||||
def einsum_outshape(einsum_expr, inputs):
|
def einsum_outshape(einsum_expr, inputs):
|
||||||
shps = np_cpu.concatenate([in_.shape for in_ in inputs])
|
shps = np_cpu.concatenate([in_.shape for in_ in inputs])
|
||||||
p = einsum_expr.split(',')
|
p = einsum_expr.replace(" ", "").split(',')
|
||||||
s = p[:-1] + p[-1].split('->')
|
s = p[:-1] + p[-1].split('->')
|
||||||
if s[-1]=='':
|
if s[-1]=='':
|
||||||
return ()
|
return ()
|
||||||
|
|
Loading…
Reference in New Issue