diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py index 320efcb2..522719bc 100644 --- a/python/jittor/linalg.py +++ b/python/jittor/linalg.py @@ -488,6 +488,24 @@ def einsum(string, *args): shps = np_cpu.concatenate([in_.shape for in_ in inputs]) p = einsum_expr.replace(" ", "").split(',') s = p[:-1] + p[-1].split('->') + rec_shape = [] + ellip_expr = None + const_rep = '1234567890' # assume tensor shape no more than 10 dimensions + for idx, expr in enumerate(s[:-1]): + if "..." in expr: + assert "..." in s[-1] + else: + continue + shp = inputs[idx].shape + ellipsis_pos = len(expr.replace("...", "")) + nellip_expr = const_rep[0 : len(shp) - ellipsis_pos] + if ellip_expr is None: + ellip_expr = nellip_expr + else: + assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis." + s[idx] = expr.replace("...", ellip_expr) + if ellip_expr: + s[-1] = s[-1].replace("...", ellip_expr) if s[-1]=='': return () else: