This commit is contained in:
Dun Liang 2020-08-11 20:30:26 +08:00
parent 45aba7c99e
commit eda31dcacf
2 changed files with 64 additions and 51 deletions

View File

@ -1,13 +1,18 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Authors:
# Haoyang Peng
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt import jittor as jt
import numpy as np from functools import partial
#TODO:full_matrices=1 #TODO:full_matrices=1
def svd(x): def svd(x):
from functools import partial
import copy
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
@ -19,6 +24,9 @@ def svd(x):
np.copyto(v, tv) np.copyto(v, tv)
def backward_code(np, data): def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
@ -56,15 +64,13 @@ def svd(x):
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt)) t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
np.copyto(out, t) np.copyto(out, t)
s = jt.array(x.shape).data.tolist()
m, n = x.shape[-2:] m, n = x.shape[-2:]
k = np.min((m, n)) k = min(m, n)
k = int(k) s1 = list(x.shape)
s1 = copy.deepcopy(s)
s1[-1] = k s1[-1] = k
s2 = copy.deepcopy(s) s2 = list(x.shape)
s2[-2] = k s2[-2] = k
s3 = s[:-2] s3 = list(x.shape)[:-2]
s3.append(k) s3.append(k)
u, s, v = jt.numpy_code( u, s, v = jt.numpy_code(
[s1, s3, s2], [s1, s3, s2],
@ -76,11 +82,6 @@ def svd(x):
return u, s, v return u, s, v
def eigh(x): def eigh(x):
from functools import partial
import copy
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
@ -90,6 +91,9 @@ def eigh(x):
np.copyto(v, tv) np.copyto(v, tv)
def backward_code(np, data): def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
@ -107,10 +111,8 @@ def eigh(x):
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v)) t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
np.copyto(out, t) np.copyto(out, t)
s = jt.array(x.shape).data.tolist() sw = x.shape[:-2] + x.shape[-1:]
sw = s[:-2] sv = x.shape
sw.append(s[-1])
sv = copy.deepcopy(s)
w, v = jt.numpy_code( w, v = jt.numpy_code(
[sw, sv], [sw, sv],
[x.dtype, x.dtype], [x.dtype, x.dtype],
@ -121,10 +123,6 @@ def eigh(x):
return w, v return w, v
def inv(x): def inv(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
@ -133,6 +131,9 @@ def inv(x):
np.copyto(m_a, t_a) np.copyto(m_a, t_a)
def backward_code(np, data): def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
lmx = data["f_outputs"] lmx = data["f_outputs"]
@ -151,10 +152,6 @@ def inv(x):
return mx return mx
def pinv(x): def pinv(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
@ -163,6 +160,9 @@ def pinv(x):
np.copyto(m_a, t_a) np.copyto(m_a, t_a)
def backward_code(np, data): def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
@ -186,10 +186,6 @@ def pinv(x):
return mx return mx
def det(x): def det(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
@ -198,6 +194,9 @@ def det(x):
np.copyto(L, tL) np.copyto(L, tL)
def backward_code(np, data): def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
f_out = data["f_outputs"][0] f_out = data["f_outputs"][0]
@ -207,7 +206,7 @@ def det(x):
s = n_d * n_o * T(np.linalg.inv(inp)) s = n_d * n_o * T(np.linalg.inv(inp))
np.copyto(out, s) np.copyto(out, s)
s = jt.array(x.shape).data.tolist() s = x.shape
x_s = s[:-2] x_s = s[:-2]
if len(s) == 2: if len(s) == 2:
x_s.append(1) x_s.append(1)
@ -222,10 +221,6 @@ def det(x):
return det return det
def slogdet(x): def slogdet(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
sign, m_a = data["outputs"] sign, m_a = data["outputs"]
@ -234,6 +229,9 @@ def slogdet(x):
np.copyto(sign, sign_) np.copyto(sign, sign_)
def backward_code(np, data): def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
inp = data["inputs"][0] inp = data["inputs"][0]
@ -245,7 +243,7 @@ def slogdet(x):
t = t * T(np.linalg.inv(inp)) t = t * T(np.linalg.inv(inp))
np.copyto(out, t) np.copyto(out, t)
s = jt.array(x.shape).data.tolist() s = x.shape
det_s = s[:-2] det_s = s[:-2]
if len(det_s) == 0: if len(det_s) == 0:
det_s.append(1) det_s.append(1)
@ -259,10 +257,6 @@ def slogdet(x):
return sign, mx return sign, mx
def cholesky(x): def cholesky(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data): def forward_code(np, data):
a = data["inputs"][0] a = data["inputs"][0]
@ -271,6 +265,9 @@ def cholesky(x):
np.copyto(L, tL) np.copyto(L, tL)
def backward_code(np, data): def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
f_out = data["f_outputs"][0] f_out = data["f_outputs"][0]
@ -295,10 +292,6 @@ def cholesky(x):
return L return L
def solve(a,b): def solve(a,b):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data): def forward_code(np, data):
a, b = data["inputs"] a, b = data["inputs"]
@ -307,6 +300,9 @@ def solve(a,b):
np.copyto(L, ans) np.copyto(L, ans)
def backward_code1(np, data): def backward_code1(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"] dout = data["dout"]
out = data["outputs"][0] out = data["outputs"][0]
f_out = data["f_outputs"][0] f_out = data["f_outputs"][0]

View File

@ -1,9 +1,26 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Authors:
# Haoyang Peng
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt import jittor as jt
import numpy as np import numpy as np
import autograd.numpy as anp
from autograd import jacobian
import unittest import unittest
try:
import autograd.numpy as anp
from autograd import jacobian
has_autograd = True
except:
has_autograd = False
@unittest.skipIf(not has_autograd, "No autograd found.")
class TestCodeOp(unittest.TestCase): class TestCodeOp(unittest.TestCase):
def test_svd(self): def test_svd(self):
def check_svd(a): def check_svd(a):
@ -153,7 +170,7 @@ class TestCodeOp(unittest.TestCase):
tn = np.random.randn((4,4)).astype('float32')*10 tn = np.random.randn((4,4)).astype('float32')*10
x = jt.array(tn) x = jt.array(tn)
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"]) x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
s = jt.array(x.shape).data.tolist() s = list(x.shape)
det_s = s[:-2] det_s = s[:-2]
if len(det_s) == 0: if len(det_s) == 0:
det_s.append(1) det_s.append(1)
@ -223,7 +240,7 @@ class TestCodeOp(unittest.TestCase):
tn = np.random.randn((3, 3)).astype('float32') * 5 tn = np.random.randn((3, 3)).astype('float32') * 5
x = jt.array(tn) x = jt.array(tn)
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
s = jt.array(x.shape).data.tolist() s = list(x.shape)
x_s = s[:-2] x_s = s[:-2]
if len(s) == 2: if len(s) == 2:
x_s.append(1) x_s.append(1)