mirror of https://github.com/Jittor/Jittor
some fix
This commit is contained in:
parent
45aba7c99e
commit
eda31dcacf
|
@ -1,13 +1,18 @@
|
||||||
|
# ***************************************************************
|
||||||
|
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||||
|
# Authors:
|
||||||
|
# Haoyang Peng
|
||||||
|
# Guowei Yang <471184555@qq.com>
|
||||||
|
# Dun Liang <randonlang@gmail.com>.
|
||||||
|
#
|
||||||
|
# This file is subject to the terms and conditions defined in
|
||||||
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
# ***************************************************************
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
import numpy as np
|
from functools import partial
|
||||||
|
|
||||||
#TODO:full_matrices=1
|
#TODO:full_matrices=1
|
||||||
def svd(x):
|
def svd(x):
|
||||||
from functools import partial
|
|
||||||
import copy
|
|
||||||
def T(x):
|
|
||||||
return np.swapaxes(x, -1, -2)
|
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
|
||||||
|
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
|
@ -19,6 +24,9 @@ def svd(x):
|
||||||
np.copyto(v, tv)
|
np.copyto(v, tv)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
|
@ -56,15 +64,13 @@ def svd(x):
|
||||||
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
|
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
|
|
||||||
s = jt.array(x.shape).data.tolist()
|
|
||||||
m, n = x.shape[-2:]
|
m, n = x.shape[-2:]
|
||||||
k = np.min((m, n))
|
k = min(m, n)
|
||||||
k = int(k)
|
s1 = list(x.shape)
|
||||||
s1 = copy.deepcopy(s)
|
|
||||||
s1[-1] = k
|
s1[-1] = k
|
||||||
s2 = copy.deepcopy(s)
|
s2 = list(x.shape)
|
||||||
s2[-2] = k
|
s2[-2] = k
|
||||||
s3 = s[:-2]
|
s3 = list(x.shape)[:-2]
|
||||||
s3.append(k)
|
s3.append(k)
|
||||||
u, s, v = jt.numpy_code(
|
u, s, v = jt.numpy_code(
|
||||||
[s1, s3, s2],
|
[s1, s3, s2],
|
||||||
|
@ -76,11 +82,6 @@ def svd(x):
|
||||||
return u, s, v
|
return u, s, v
|
||||||
|
|
||||||
def eigh(x):
|
def eigh(x):
|
||||||
from functools import partial
|
|
||||||
import copy
|
|
||||||
def T(x):
|
|
||||||
return np.swapaxes(x, -1, -2)
|
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
|
||||||
|
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
|
@ -90,6 +91,9 @@ def eigh(x):
|
||||||
np.copyto(v, tv)
|
np.copyto(v, tv)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
|
@ -107,10 +111,8 @@ def eigh(x):
|
||||||
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
|
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
|
|
||||||
s = jt.array(x.shape).data.tolist()
|
sw = x.shape[:-2] + x.shape[-1:]
|
||||||
sw = s[:-2]
|
sv = x.shape
|
||||||
sw.append(s[-1])
|
|
||||||
sv = copy.deepcopy(s)
|
|
||||||
w, v = jt.numpy_code(
|
w, v = jt.numpy_code(
|
||||||
[sw, sv],
|
[sw, sv],
|
||||||
[x.dtype, x.dtype],
|
[x.dtype, x.dtype],
|
||||||
|
@ -121,10 +123,6 @@ def eigh(x):
|
||||||
return w, v
|
return w, v
|
||||||
|
|
||||||
def inv(x):
|
def inv(x):
|
||||||
from functools import partial
|
|
||||||
def T(x):
|
|
||||||
return np.swapaxes(x, -1, -2)
|
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
|
||||||
|
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
|
@ -133,6 +131,9 @@ def inv(x):
|
||||||
np.copyto(m_a, t_a)
|
np.copyto(m_a, t_a)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
lmx = data["f_outputs"]
|
lmx = data["f_outputs"]
|
||||||
|
@ -151,10 +152,6 @@ def inv(x):
|
||||||
return mx
|
return mx
|
||||||
|
|
||||||
def pinv(x):
|
def pinv(x):
|
||||||
from functools import partial
|
|
||||||
def T(x):
|
|
||||||
return np.swapaxes(x, -1, -2)
|
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
|
||||||
|
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
|
@ -163,6 +160,9 @@ def pinv(x):
|
||||||
np.copyto(m_a, t_a)
|
np.copyto(m_a, t_a)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
|
@ -186,10 +186,6 @@ def pinv(x):
|
||||||
return mx
|
return mx
|
||||||
|
|
||||||
def det(x):
|
def det(x):
|
||||||
from functools import partial
|
|
||||||
def T(x):
|
|
||||||
return np.swapaxes(x, -1, -2)
|
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
|
||||||
|
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
|
@ -198,6 +194,9 @@ def det(x):
|
||||||
np.copyto(L, tL)
|
np.copyto(L, tL)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
f_out = data["f_outputs"][0]
|
f_out = data["f_outputs"][0]
|
||||||
|
@ -207,7 +206,7 @@ def det(x):
|
||||||
s = n_d * n_o * T(np.linalg.inv(inp))
|
s = n_d * n_o * T(np.linalg.inv(inp))
|
||||||
np.copyto(out, s)
|
np.copyto(out, s)
|
||||||
|
|
||||||
s = jt.array(x.shape).data.tolist()
|
s = x.shape
|
||||||
x_s = s[:-2]
|
x_s = s[:-2]
|
||||||
if len(s) == 2:
|
if len(s) == 2:
|
||||||
x_s.append(1)
|
x_s.append(1)
|
||||||
|
@ -222,10 +221,6 @@ def det(x):
|
||||||
return det
|
return det
|
||||||
|
|
||||||
def slogdet(x):
|
def slogdet(x):
|
||||||
from functools import partial
|
|
||||||
def T(x):
|
|
||||||
return np.swapaxes(x, -1, -2)
|
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
sign, m_a = data["outputs"]
|
sign, m_a = data["outputs"]
|
||||||
|
@ -234,6 +229,9 @@ def slogdet(x):
|
||||||
np.copyto(sign, sign_)
|
np.copyto(sign, sign_)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
inp = data["inputs"][0]
|
inp = data["inputs"][0]
|
||||||
|
@ -245,7 +243,7 @@ def slogdet(x):
|
||||||
t = t * T(np.linalg.inv(inp))
|
t = t * T(np.linalg.inv(inp))
|
||||||
np.copyto(out, t)
|
np.copyto(out, t)
|
||||||
|
|
||||||
s = jt.array(x.shape).data.tolist()
|
s = x.shape
|
||||||
det_s = s[:-2]
|
det_s = s[:-2]
|
||||||
if len(det_s) == 0:
|
if len(det_s) == 0:
|
||||||
det_s.append(1)
|
det_s.append(1)
|
||||||
|
@ -259,10 +257,6 @@ def slogdet(x):
|
||||||
return sign, mx
|
return sign, mx
|
||||||
|
|
||||||
def cholesky(x):
|
def cholesky(x):
|
||||||
from functools import partial
|
|
||||||
def T(x):
|
|
||||||
return np.swapaxes(x, -1, -2)
|
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
|
||||||
|
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a = data["inputs"][0]
|
a = data["inputs"][0]
|
||||||
|
@ -271,6 +265,9 @@ def cholesky(x):
|
||||||
np.copyto(L, tL)
|
np.copyto(L, tL)
|
||||||
|
|
||||||
def backward_code(np, data):
|
def backward_code(np, data):
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
f_out = data["f_outputs"][0]
|
f_out = data["f_outputs"][0]
|
||||||
|
@ -295,10 +292,6 @@ def cholesky(x):
|
||||||
return L
|
return L
|
||||||
|
|
||||||
def solve(a,b):
|
def solve(a,b):
|
||||||
from functools import partial
|
|
||||||
def T(x):
|
|
||||||
return np.swapaxes(x, -1, -2)
|
|
||||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
|
||||||
|
|
||||||
def forward_code(np, data):
|
def forward_code(np, data):
|
||||||
a, b = data["inputs"]
|
a, b = data["inputs"]
|
||||||
|
@ -307,6 +300,9 @@ def solve(a,b):
|
||||||
np.copyto(L, ans)
|
np.copyto(L, ans)
|
||||||
|
|
||||||
def backward_code1(np, data):
|
def backward_code1(np, data):
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
dout = data["dout"]
|
dout = data["dout"]
|
||||||
out = data["outputs"][0]
|
out = data["outputs"][0]
|
||||||
f_out = data["f_outputs"][0]
|
f_out = data["f_outputs"][0]
|
||||||
|
|
|
@ -1,9 +1,26 @@
|
||||||
|
# ***************************************************************
|
||||||
|
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||||
|
# Authors:
|
||||||
|
# Haoyang Peng
|
||||||
|
# Guowei Yang <471184555@qq.com>
|
||||||
|
# Dun Liang <randonlang@gmail.com>.
|
||||||
|
#
|
||||||
|
# This file is subject to the terms and conditions defined in
|
||||||
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
# ***************************************************************
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import autograd.numpy as anp
|
|
||||||
from autograd import jacobian
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import autograd.numpy as anp
|
||||||
|
from autograd import jacobian
|
||||||
|
has_autograd = True
|
||||||
|
except:
|
||||||
|
has_autograd = False
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_autograd, "No autograd found.")
|
||||||
class TestCodeOp(unittest.TestCase):
|
class TestCodeOp(unittest.TestCase):
|
||||||
def test_svd(self):
|
def test_svd(self):
|
||||||
def check_svd(a):
|
def check_svd(a):
|
||||||
|
@ -153,7 +170,7 @@ class TestCodeOp(unittest.TestCase):
|
||||||
tn = np.random.randn((4,4)).astype('float32')*10
|
tn = np.random.randn((4,4)).astype('float32')*10
|
||||||
x = jt.array(tn)
|
x = jt.array(tn)
|
||||||
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
|
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
|
||||||
s = jt.array(x.shape).data.tolist()
|
s = list(x.shape)
|
||||||
det_s = s[:-2]
|
det_s = s[:-2]
|
||||||
if len(det_s) == 0:
|
if len(det_s) == 0:
|
||||||
det_s.append(1)
|
det_s.append(1)
|
||||||
|
@ -223,7 +240,7 @@ class TestCodeOp(unittest.TestCase):
|
||||||
tn = np.random.randn((3, 3)).astype('float32') * 5
|
tn = np.random.randn((3, 3)).astype('float32') * 5
|
||||||
x = jt.array(tn)
|
x = jt.array(tn)
|
||||||
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
|
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
|
||||||
s = jt.array(x.shape).data.tolist()
|
s = list(x.shape)
|
||||||
x_s = s[:-2]
|
x_s = s[:-2]
|
||||||
if len(s) == 2:
|
if len(s) == 2:
|
||||||
x_s.append(1)
|
x_s.append(1)
|
||||||
|
|
Loading…
Reference in New Issue