mirror of https://github.com/Jittor/Jittor
some fix
This commit is contained in:
parent
45aba7c99e
commit
eda31dcacf
|
@ -1,13 +1,18 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Authors:
|
||||
# Haoyang Peng
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
#TODO:full_matrices=1
|
||||
def svd(x):
|
||||
from functools import partial
|
||||
import copy
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
|
@ -19,6 +24,9 @@ def svd(x):
|
|||
np.copyto(v, tv)
|
||||
|
||||
def backward_code(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
inp = data["inputs"][0]
|
||||
|
@ -56,15 +64,13 @@ def svd(x):
|
|||
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
|
||||
np.copyto(out, t)
|
||||
|
||||
s = jt.array(x.shape).data.tolist()
|
||||
m, n = x.shape[-2:]
|
||||
k = np.min((m, n))
|
||||
k = int(k)
|
||||
s1 = copy.deepcopy(s)
|
||||
k = min(m, n)
|
||||
s1 = list(x.shape)
|
||||
s1[-1] = k
|
||||
s2 = copy.deepcopy(s)
|
||||
s2 = list(x.shape)
|
||||
s2[-2] = k
|
||||
s3 = s[:-2]
|
||||
s3 = list(x.shape)[:-2]
|
||||
s3.append(k)
|
||||
u, s, v = jt.numpy_code(
|
||||
[s1, s3, s2],
|
||||
|
@ -76,11 +82,6 @@ def svd(x):
|
|||
return u, s, v
|
||||
|
||||
def eigh(x):
|
||||
from functools import partial
|
||||
import copy
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
|
@ -90,6 +91,9 @@ def eigh(x):
|
|||
np.copyto(v, tv)
|
||||
|
||||
def backward_code(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
inp = data["inputs"][0]
|
||||
|
@ -107,10 +111,8 @@ def eigh(x):
|
|||
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
|
||||
np.copyto(out, t)
|
||||
|
||||
s = jt.array(x.shape).data.tolist()
|
||||
sw = s[:-2]
|
||||
sw.append(s[-1])
|
||||
sv = copy.deepcopy(s)
|
||||
sw = x.shape[:-2] + x.shape[-1:]
|
||||
sv = x.shape
|
||||
w, v = jt.numpy_code(
|
||||
[sw, sv],
|
||||
[x.dtype, x.dtype],
|
||||
|
@ -121,10 +123,6 @@ def eigh(x):
|
|||
return w, v
|
||||
|
||||
def inv(x):
|
||||
from functools import partial
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
|
@ -133,6 +131,9 @@ def inv(x):
|
|||
np.copyto(m_a, t_a)
|
||||
|
||||
def backward_code(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
lmx = data["f_outputs"]
|
||||
|
@ -151,10 +152,6 @@ def inv(x):
|
|||
return mx
|
||||
|
||||
def pinv(x):
|
||||
from functools import partial
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
|
@ -163,6 +160,9 @@ def pinv(x):
|
|||
np.copyto(m_a, t_a)
|
||||
|
||||
def backward_code(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
inp = data["inputs"][0]
|
||||
|
@ -186,10 +186,6 @@ def pinv(x):
|
|||
return mx
|
||||
|
||||
def det(x):
|
||||
from functools import partial
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
|
@ -198,6 +194,9 @@ def det(x):
|
|||
np.copyto(L, tL)
|
||||
|
||||
def backward_code(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
f_out = data["f_outputs"][0]
|
||||
|
@ -207,7 +206,7 @@ def det(x):
|
|||
s = n_d * n_o * T(np.linalg.inv(inp))
|
||||
np.copyto(out, s)
|
||||
|
||||
s = jt.array(x.shape).data.tolist()
|
||||
s = x.shape
|
||||
x_s = s[:-2]
|
||||
if len(s) == 2:
|
||||
x_s.append(1)
|
||||
|
@ -222,10 +221,6 @@ def det(x):
|
|||
return det
|
||||
|
||||
def slogdet(x):
|
||||
from functools import partial
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
sign, m_a = data["outputs"]
|
||||
|
@ -234,6 +229,9 @@ def slogdet(x):
|
|||
np.copyto(sign, sign_)
|
||||
|
||||
def backward_code(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
inp = data["inputs"][0]
|
||||
|
@ -245,7 +243,7 @@ def slogdet(x):
|
|||
t = t * T(np.linalg.inv(inp))
|
||||
np.copyto(out, t)
|
||||
|
||||
s = jt.array(x.shape).data.tolist()
|
||||
s = x.shape
|
||||
det_s = s[:-2]
|
||||
if len(det_s) == 0:
|
||||
det_s.append(1)
|
||||
|
@ -259,10 +257,6 @@ def slogdet(x):
|
|||
return sign, mx
|
||||
|
||||
def cholesky(x):
|
||||
from functools import partial
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
|
@ -271,6 +265,9 @@ def cholesky(x):
|
|||
np.copyto(L, tL)
|
||||
|
||||
def backward_code(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
f_out = data["f_outputs"][0]
|
||||
|
@ -295,10 +292,6 @@ def cholesky(x):
|
|||
return L
|
||||
|
||||
def solve(a,b):
|
||||
from functools import partial
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
|
||||
def forward_code(np, data):
|
||||
a, b = data["inputs"]
|
||||
|
@ -307,6 +300,9 @@ def solve(a,b):
|
|||
np.copyto(L, ans)
|
||||
|
||||
def backward_code1(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
f_out = data["f_outputs"][0]
|
||||
|
|
|
@ -1,9 +1,26 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Authors:
|
||||
# Haoyang Peng
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import autograd.numpy as anp
|
||||
from autograd import jacobian
|
||||
import unittest
|
||||
|
||||
|
||||
try:
|
||||
import autograd.numpy as anp
|
||||
from autograd import jacobian
|
||||
has_autograd = True
|
||||
except:
|
||||
has_autograd = False
|
||||
|
||||
@unittest.skipIf(not has_autograd, "No autograd found.")
|
||||
class TestCodeOp(unittest.TestCase):
|
||||
def test_svd(self):
|
||||
def check_svd(a):
|
||||
|
@ -153,7 +170,7 @@ class TestCodeOp(unittest.TestCase):
|
|||
tn = np.random.randn((4,4)).astype('float32')*10
|
||||
x = jt.array(tn)
|
||||
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
|
||||
s = jt.array(x.shape).data.tolist()
|
||||
s = list(x.shape)
|
||||
det_s = s[:-2]
|
||||
if len(det_s) == 0:
|
||||
det_s.append(1)
|
||||
|
@ -223,7 +240,7 @@ class TestCodeOp(unittest.TestCase):
|
|||
tn = np.random.randn((3, 3)).astype('float32') * 5
|
||||
x = jt.array(tn)
|
||||
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
|
||||
s = jt.array(x.shape).data.tolist()
|
||||
s = list(x.shape)
|
||||
x_s = s[:-2]
|
||||
if len(s) == 2:
|
||||
x_s.append(1)
|
||||
|
|
Loading…
Reference in New Issue