diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py index aeee373c..320efcb2 100644 --- a/python/jittor/linalg.py +++ b/python/jittor/linalg.py @@ -486,7 +486,7 @@ def einsum(string, *args): def einsum_outshape(einsum_expr, inputs): shps = np_cpu.concatenate([in_.shape for in_ in inputs]) - p = einsum_expr.split(',') + p = einsum_expr.replace(" ", "").split(',') s = p[:-1] + p[-1].split('->') if s[-1]=='': return ()