This commit is contained in:
Dun Liang 2020-08-11 20:30:26 +08:00
parent 45aba7c99e
commit eda31dcacf
2 changed files with 64 additions and 51 deletions

View File

@ -1,13 +1,18 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Authors:
# Haoyang Peng
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
from functools import partial
#TODO:full_matrices=1
def svd(x):
from functools import partial
import copy
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data):
a = data["inputs"][0]
@ -19,6 +24,9 @@ def svd(x):
np.copyto(v, tv)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
inp = data["inputs"][0]
@ -56,15 +64,13 @@ def svd(x):
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
np.copyto(out, t)
s = jt.array(x.shape).data.tolist()
m, n = x.shape[-2:]
k = np.min((m, n))
k = int(k)
s1 = copy.deepcopy(s)
k = min(m, n)
s1 = list(x.shape)
s1[-1] = k
s2 = copy.deepcopy(s)
s2 = list(x.shape)
s2[-2] = k
s3 = s[:-2]
s3 = list(x.shape)[:-2]
s3.append(k)
u, s, v = jt.numpy_code(
[s1, s3, s2],
@ -76,11 +82,6 @@ def svd(x):
return u, s, v
def eigh(x):
from functools import partial
import copy
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data):
a = data["inputs"][0]
@ -90,6 +91,9 @@ def eigh(x):
np.copyto(v, tv)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
inp = data["inputs"][0]
@ -107,10 +111,8 @@ def eigh(x):
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
np.copyto(out, t)
s = jt.array(x.shape).data.tolist()
sw = s[:-2]
sw.append(s[-1])
sv = copy.deepcopy(s)
sw = x.shape[:-2] + x.shape[-1:]
sv = x.shape
w, v = jt.numpy_code(
[sw, sv],
[x.dtype, x.dtype],
@ -121,10 +123,6 @@ def eigh(x):
return w, v
def inv(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data):
a = data["inputs"][0]
@ -133,6 +131,9 @@ def inv(x):
np.copyto(m_a, t_a)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
lmx = data["f_outputs"]
@ -151,10 +152,6 @@ def inv(x):
return mx
def pinv(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data):
a = data["inputs"][0]
@ -163,6 +160,9 @@ def pinv(x):
np.copyto(m_a, t_a)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
inp = data["inputs"][0]
@ -186,10 +186,6 @@ def pinv(x):
return mx
def det(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data):
a = data["inputs"][0]
@ -198,6 +194,9 @@ def det(x):
np.copyto(L, tL)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
f_out = data["f_outputs"][0]
@ -207,7 +206,7 @@ def det(x):
s = n_d * n_o * T(np.linalg.inv(inp))
np.copyto(out, s)
s = jt.array(x.shape).data.tolist()
s = x.shape
x_s = s[:-2]
if len(s) == 2:
x_s.append(1)
@ -222,10 +221,6 @@ def det(x):
return det
def slogdet(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data):
a = data["inputs"][0]
sign, m_a = data["outputs"]
@ -234,6 +229,9 @@ def slogdet(x):
np.copyto(sign, sign_)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
inp = data["inputs"][0]
@ -245,7 +243,7 @@ def slogdet(x):
t = t * T(np.linalg.inv(inp))
np.copyto(out, t)
s = jt.array(x.shape).data.tolist()
s = x.shape
det_s = s[:-2]
if len(det_s) == 0:
det_s.append(1)
@ -259,10 +257,6 @@ def slogdet(x):
return sign, mx
def cholesky(x):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data):
a = data["inputs"][0]
@ -271,6 +265,9 @@ def cholesky(x):
np.copyto(L, tL)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
f_out = data["f_outputs"][0]
@ -295,10 +292,6 @@ def cholesky(x):
return L
def solve(a,b):
from functools import partial
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
def forward_code(np, data):
a, b = data["inputs"]
@ -307,6 +300,9 @@ def solve(a,b):
np.copyto(L, ans)
def backward_code1(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
f_out = data["f_outputs"][0]

View File

@ -1,9 +1,26 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Authors:
# Haoyang Peng
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
import autograd.numpy as anp
from autograd import jacobian
import unittest
try:
import autograd.numpy as anp
from autograd import jacobian
has_autograd = True
except:
has_autograd = False
@unittest.skipIf(not has_autograd, "No autograd found.")
class TestCodeOp(unittest.TestCase):
def test_svd(self):
def check_svd(a):
@ -153,7 +170,7 @@ class TestCodeOp(unittest.TestCase):
tn = np.random.randn((4,4)).astype('float32')*10
x = jt.array(tn)
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
s = jt.array(x.shape).data.tolist()
s = list(x.shape)
det_s = s[:-2]
if len(det_s) == 0:
det_s.append(1)
@ -223,7 +240,7 @@ class TestCodeOp(unittest.TestCase):
tn = np.random.randn((3, 3)).astype('float32') * 5
x = jt.array(tn)
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
s = jt.array(x.shape).data.tolist()
s = list(x.shape)
x_s = s[:-2]
if len(s) == 2:
x_s.append(1)