diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py index e15480ef..a25a1333 100644 --- a/python/jittor/linalg.py +++ b/python/jittor/linalg.py @@ -1,13 +1,18 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. All Rights Reserved. +# Authors: +# Haoyang Peng +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** import jittor as jt -import numpy as np +from functools import partial #TODO:full_matrices=1 def svd(x): - from functools import partial - import copy - def T(x): - return np.swapaxes(x, -1, -2) - _dot = partial(np.einsum, '...ij,...jk->...ik') def forward_code(np, data): a = data["inputs"][0] @@ -19,6 +24,9 @@ def svd(x): np.copyto(v, tv) def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') dout = data["dout"] out = data["outputs"][0] inp = data["inputs"][0] @@ -56,15 +64,13 @@ def svd(x): t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt)) np.copyto(out, t) - s = jt.array(x.shape).data.tolist() m, n = x.shape[-2:] - k = np.min((m, n)) - k = int(k) - s1 = copy.deepcopy(s) + k = min(m, n) + s1 = list(x.shape) s1[-1] = k - s2 = copy.deepcopy(s) + s2 = list(x.shape) s2[-2] = k - s3 = s[:-2] + s3 = list(x.shape)[:-2] s3.append(k) u, s, v = jt.numpy_code( [s1, s3, s2], @@ -76,11 +82,6 @@ def svd(x): return u, s, v def eigh(x): - from functools import partial - import copy - def T(x): - return np.swapaxes(x, -1, -2) - _dot = partial(np.einsum, '...ij,...jk->...ik') def forward_code(np, data): a = data["inputs"][0] @@ -90,6 +91,9 @@ def eigh(x): np.copyto(v, tv) def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') dout = data["dout"] out = data["outputs"][0] inp = data["inputs"][0] @@ -107,10 +111,8 @@ def eigh(x): t = _dot(_dot(v, F * _dot(T(v), dout)), T(v)) np.copyto(out, t) - s = jt.array(x.shape).data.tolist() - sw = s[:-2] - sw.append(s[-1]) - sv = copy.deepcopy(s) + sw = x.shape[:-2] + x.shape[-1:] + sv = x.shape w, v = jt.numpy_code( [sw, sv], [x.dtype, x.dtype], @@ -121,10 +123,6 @@ def eigh(x): return w, v def inv(x): - from functools import partial - def T(x): - return np.swapaxes(x, -1, -2) - _dot = partial(np.einsum, '...ij,...jk->...ik') def forward_code(np, data): a = data["inputs"][0] @@ -133,6 +131,9 @@ def inv(x): np.copyto(m_a, t_a) def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') dout = data["dout"] out = data["outputs"][0] lmx = data["f_outputs"] @@ -151,10 +152,6 @@ def inv(x): return mx def pinv(x): - from functools import partial - def T(x): - return np.swapaxes(x, -1, -2) - _dot = partial(np.einsum, '...ij,...jk->...ik') def forward_code(np, data): a = data["inputs"][0] @@ -163,6 +160,9 @@ def pinv(x): np.copyto(m_a, t_a) def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') dout = data["dout"] out = data["outputs"][0] inp = data["inputs"][0] @@ -186,10 +186,6 @@ def pinv(x): return mx def det(x): - from functools import partial - def T(x): - return np.swapaxes(x, -1, -2) - _dot = partial(np.einsum, '...ij,...jk->...ik') def forward_code(np, data): a = data["inputs"][0] @@ -198,6 +194,9 @@ def det(x): np.copyto(L, tL) def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') dout = data["dout"] out = data["outputs"][0] f_out = data["f_outputs"][0] @@ -207,7 +206,7 @@ def det(x): s = n_d * n_o * T(np.linalg.inv(inp)) np.copyto(out, s) - s = jt.array(x.shape).data.tolist() + s = x.shape x_s = s[:-2] if len(s) == 2: x_s.append(1) @@ -222,10 +221,6 @@ def det(x): return det def slogdet(x): - from functools import partial - def T(x): - return np.swapaxes(x, -1, -2) - _dot = partial(np.einsum, '...ij,...jk->...ik') def forward_code(np, data): a = data["inputs"][0] sign, m_a = data["outputs"] @@ -234,6 +229,9 @@ def slogdet(x): np.copyto(sign, sign_) def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') dout = data["dout"] out = data["outputs"][0] inp = data["inputs"][0] @@ -245,7 +243,7 @@ def slogdet(x): t = t * T(np.linalg.inv(inp)) np.copyto(out, t) - s = jt.array(x.shape).data.tolist() + s = x.shape det_s = s[:-2] if len(det_s) == 0: det_s.append(1) @@ -259,10 +257,6 @@ def slogdet(x): return sign, mx def cholesky(x): - from functools import partial - def T(x): - return np.swapaxes(x, -1, -2) - _dot = partial(np.einsum, '...ij,...jk->...ik') def forward_code(np, data): a = data["inputs"][0] @@ -271,6 +265,9 @@ def cholesky(x): np.copyto(L, tL) def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') dout = data["dout"] out = data["outputs"][0] f_out = data["f_outputs"][0] @@ -295,10 +292,6 @@ def cholesky(x): return L def solve(a,b): - from functools import partial - def T(x): - return np.swapaxes(x, -1, -2) - _dot = partial(np.einsum, '...ij,...jk->...ik') def forward_code(np, data): a, b = data["inputs"] @@ -307,6 +300,9 @@ def solve(a,b): np.copyto(L, ans) def backward_code1(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') dout = data["dout"] out = data["outputs"][0] f_out = data["f_outputs"][0] diff --git a/python/jittor/test/test_linalg.py b/python/jittor/test/test_linalg.py index 6824c55b..80279bda 100644 --- a/python/jittor/test/test_linalg.py +++ b/python/jittor/test/test_linalg.py @@ -1,9 +1,26 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. All Rights Reserved. +# Authors: +# Haoyang Peng +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** import jittor as jt import numpy as np -import autograd.numpy as anp -from autograd import jacobian import unittest + +try: + import autograd.numpy as anp + from autograd import jacobian + has_autograd = True +except: + has_autograd = False + +@unittest.skipIf(not has_autograd, "No autograd found.") class TestCodeOp(unittest.TestCase): def test_svd(self): def check_svd(a): @@ -153,7 +170,7 @@ class TestCodeOp(unittest.TestCase): tn = np.random.randn((4,4)).astype('float32')*10 x = jt.array(tn) x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"]) - s = jt.array(x.shape).data.tolist() + s = list(x.shape) det_s = s[:-2] if len(det_s) == 0: det_s.append(1) @@ -223,7 +240,7 @@ class TestCodeOp(unittest.TestCase): tn = np.random.randn((3, 3)).astype('float32') * 5 x = jt.array(tn) x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) - s = jt.array(x.shape).data.tolist() + s = list(x.shape) x_s = s[:-2] if len(s) == 2: x_s.append(1)