diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index a1bdc29f..40cbdba9 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -186,7 +186,7 @@ def norm(x, k, dim): if k==1: return x.abs().sum(dim) if k==2: - return (x**2).sum(dim).sqrt() + return x.sqr().sum(dim).sqrt() Var.norm = norm origin_reshape = reshape