# Mirror of https://github.com/Jittor/Jittor (Python, 300 lines, 9.8 KiB).
# ***************************************************************
|
|
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
|
# Maintainers:
|
|
# Haoyang Peng <2247838039@qq.com>
|
|
# Guowei Yang <471184555@qq.com>
|
|
# Dun Liang <randonlang@gmail.com>.
|
|
#
|
|
# This file is subject to the terms and conditions defined in
|
|
# file 'LICENSE.txt', which is part of this source code package.
|
|
# ***************************************************************
|
|
import jittor as jt
|
|
import numpy as np
|
|
import unittest
|
|
|
|
# torch and autograd are optional test dependencies. ``has_autograd`` records
# whether all of them imported, so the test class can be skipped instead of
# erroring at import time.
try:
    import torch
    from torch.autograd import Variable
    import autograd.numpy as anp
    from autograd import jacobian

    has_autograd = True
except ImportError:
    # Catch only import failures; a bare ``except:`` would also swallow
    # KeyboardInterrupt/SystemExit raised while importing.
    has_autograd = False
|
|
|
|
|
@unittest.skipIf(not has_autograd, "No autograd found.")
class TestLinalgOp(unittest.TestCase):
    """Check ``jt.linalg`` ops and their gradients against reference libraries.

    Forward results are compared with ``autograd.numpy``, except QR, which is
    compared with PyTorch. Gradients from ``jt.grad(y, x)`` are compared with
    an ``autograd`` Jacobian summed over its output axes, i.e. the gradient of
    ``sum(y)``. These comparisons assume ``jt.grad`` on a non-scalar output
    behaves like the gradient of its sum (TODO confirm against Jittor docs).

    The SVD, solve and QR gradient comparisons only print mismatches instead
    of failing. That leniency is kept as-is.
    """

    @staticmethod
    def _grad_of_sum(jac, out_ndim):
        """Return d(sum(y))/dx from an autograd Jacobian of ``y`` w.r.t. ``x``.

        ``autograd.jacobian`` lays its result out as ``y.shape + x.shape``.
        Summing the leading ``out_ndim`` axes therefore leaves an array
        shaped like ``x``.
        """
        return np.sum(jac, axis=tuple(range(out_ndim)))

    @staticmethod
    def _random_nonsingular(n, scale=1):
        """Return a random ``n x n`` float32 matrix whose determinant is not ~0.

        Entries are standard-normal samples multiplied by ``scale``. Samples
        are redrawn while ``np.allclose(det, 0)`` holds.
        """
        # np.random.randn takes the dimensions as separate ints; passing a
        # tuple raises TypeError.
        tn = np.random.randn(n, n).astype('float32') * scale
        while np.allclose(np.linalg.det(tn), 0):
            tn = np.random.randn(n, n).astype('float32') * scale
        return tn

    @staticmethod
    def _tile_2x2(x):
        """Broadcast a 2-D ``jt.Var`` to shape ``(2, 2) + x.shape``.

        The reindex maps output index ``(i0, i1, i2, i3)`` to ``x[i2, i3]``,
        so every batch entry is a copy of ``x``.
        """
        return x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])

    def test_svd(self):
        def check_svd(a):
            u, s, v = anp.linalg.svd(a, full_matrices=0)
            return u, s, v

        def check_u(a):
            return check_svd(a)[0]

        def check_s(a):
            return check_svd(a)[1]

        def check_v(a):
            return check_svd(a)[2]

        # The Jacobian functions do not depend on the sample; build them once.
        grad_u = jacobian(check_u)
        grad_s = jacobian(check_s)
        grad_v = jacobian(check_v)
        for _ in range(50):
            # Compared against the reduced SVD only (full_matrices=0).
            a = jt.random((2, 2, 5, 4))
            c_a = anp.array(a.data)
            u, s, v = jt.linalg.svd(a)
            tu, ts, tv = check_svd(c_a)
            assert np.allclose(tu, u.data)
            assert np.allclose(ts, s.data)
            assert np.allclose(tv, v.data)
            ju = jt.grad(u, a)
            js = jt.grad(s, a)
            jv = jt.grad(v, a)
            gu = self._grad_of_sum(grad_u(c_a), tu.ndim)
            gs = self._grad_of_sum(grad_s(c_a), ts.ndim)
            gv = self._grad_of_sum(grad_v(c_a), tv.ndim)
            # Gradient mismatches are only reported, not fatal.
            for ours, ref in ((ju, gu), (js, gs), (jv, gv)):
                if not np.allclose(ours.data, ref, atol=1e-5):
                    print(ours.data)
                    print(ref)

    def test_eigh(self):
        def check_eigh(a, UPLO='L'):
            w, v = anp.linalg.eigh(a, UPLO)
            return w, v

        def check_w(a, UPLO='L'):
            w, v = anp.linalg.eigh(a, UPLO)
            return w

        def check_v(a, UPLO='L'):
            w, v = anp.linalg.eigh(a, UPLO)
            return v

        check_gw = jacobian(check_w)
        check_gv = jacobian(check_v)
        for _ in range(50):
            a = jt.random((2, 2, 3, 3))
            c_a = a.data
            w, v = jt.linalg.eigh(a)
            tw, tv = check_eigh(c_a)
            assert np.allclose(w.data, tw)
            assert np.allclose(v.data, tv)
            jw = jt.grad(w, a)
            jv = jt.grad(v, a)
            # Very loose relative tolerance (rtol=1) is kept unchanged.
            gw = self._grad_of_sum(check_gw(c_a), tw.ndim)
            assert np.allclose(gw, jw.data, rtol=1, atol=5e-8)
            gv = self._grad_of_sum(check_gv(c_a), tv.ndim)
            assert np.allclose(gv, jv.data, rtol=1, atol=5e-8)

    def test_pinv(self):
        def check_pinv(a):
            return anp.linalg.pinv(a)

        check_grad = jacobian(check_pinv)
        for _ in range(50):
            x = jt.random((2, 2, 4, 3))
            c_a = x.data
            mx = jt.linalg.pinv(x)
            tx = check_pinv(c_a)
            # Tolerances are looser than np.allclose's defaults to absorb
            # single-precision round-off (inputs presumably float32).
            assert np.allclose(mx.data, tx, rtol=1e-4, atol=1e-5)
            jx = jt.grad(mx, x)
            # Reduce the (out + in)-shaped Jacobian to an in-shaped gradient
            # before comparing.
            gx = self._grad_of_sum(check_grad(c_a), tx.ndim)
            assert np.allclose(gx, jx.data, rtol=1e-3, atol=1e-4)

    def test_inv(self):
        def check_inv(a):
            return anp.linalg.inv(a)

        check_grad = jacobian(check_inv)
        for _ in range(50):
            x = self._tile_2x2(jt.array(self._random_nonsingular(4, 5)))
            c_a = x.data
            mx = jt.linalg.inv(x)
            tx = check_inv(c_a)
            # float32 input, so tolerances are looser than the defaults.
            assert np.allclose(mx.data, tx, rtol=1e-4, atol=1e-5)
            jx = jt.grad(mx, x)
            gx = self._grad_of_sum(check_grad(c_a), tx.ndim)
            assert np.allclose(gx, jx.data, rtol=1e-3, atol=1e-4)

    def test_slogdet(self):
        def check_ans(a):
            s, w = anp.linalg.slogdet(a)
            return s, w

        def check_slogdet(a):
            s, w = anp.linalg.slogdet(a)
            return w

        check_sgrad = jacobian(check_slogdet)
        for _ in range(50):
            x = self._tile_2x2(jt.array(self._random_nonsingular(4, 10)))
            sign, mx = jt.linalg.slogdet(x)
            ts, ta = check_ans(x.data)
            assert np.allclose(sign.data, ts)
            assert np.allclose(mx.data, ta)
            jx = jt.grad(mx, x)
            gx = self._grad_of_sum(check_sgrad(x.data), ta.ndim)
            assert np.allclose(gx, jx.data)

    def test_cholesky(self):
        def check_cholesky(a):
            return anp.linalg.cholesky(a)

        check_grad = jacobian(check_cholesky)
        for _ in range(50):
            # A diagonal matrix with entries in [2, 4) is symmetric positive
            # definite, so the Cholesky factorization exists.
            x = self._tile_2x2(jt.array(np.diag((np.random.rand(3) + 1) * 2)))
            tx = x.data
            L = jt.linalg.cholesky(x)
            tL = check_cholesky(tx)
            assert np.allclose(tL, L.data)
            jx = jt.grad(L, x)
            gx = self._grad_of_sum(check_grad(tx), tL.ndim)
            assert np.allclose(jx.data, gx)

    def test_solve(self):
        def check_solve(a, b):
            return anp.linalg.solve(a, b)

        # jacobian differentiates w.r.t. the first positional argument, ``a``.
        check_sgrad = jacobian(check_solve)
        for _ in range(50):
            a = jt.random((2, 2, 3, 3))
            b = jt.random((2, 2, 3))
            ans = jt.linalg.solve(a, b)
            ta = check_solve(a.data, b.data)
            assert np.allclose(ans.data, ta)
            jx = jt.grad(ans, a)
            gx = self._grad_of_sum(check_sgrad(a.data, b.data), ta.ndim)
            # Gradient mismatches are only reported, not fatal.
            if not np.allclose(gx, jx.data, rtol=1):
                print(gx)
                print(jx.data)

    def test_det(self):
        def check_det(a):
            return anp.linalg.det(a)

        check_sgrad = jacobian(check_det)
        for _ in range(50):
            x = self._tile_2x2(jt.array(self._random_nonsingular(3, 5)))
            det = jt.linalg.det(x)
            ta = check_det(x.data)
            assert np.allclose(det.data, ta)
            jx = jt.grad(det, x)
            gx = self._grad_of_sum(check_sgrad(x.data), ta.ndim)
            assert np.allclose(gx, jx.data)

    def test_qr(self):
        tol = dict(rtol=1e-4, atol=1e-6)
        for _ in range(50):
            tn = self._random_nonsingular(3)
            x = jt.array(tn)
            t_x = Variable(torch.from_numpy(tn), requires_grad=True)
            jq, jr = jt.linalg.qr(x)
            tq, tr = torch.qr(t_x)
            # Mismatches against PyTorch are only reported, not fatal.
            if not (np.allclose(jq.data, tq.detach().numpy(), **tol)
                    and np.allclose(jr.data, tr.detach().numpy(), **tol)):
                print("ours' qr results:")
                print(jq)
                print(jr)
                print("pytorch's qr results:")
                print(tq)
                print(tr)
            gq = jt.grad(jq, x).data
            gr = jt.grad(jr, x).data
            # ones_like seeds give the gradient of sum(q) / sum(r).
            tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True)
            tgr = torch.autograd.grad(tr, t_x, torch.ones_like(tr), retain_graph=True)
            if not (np.allclose(gq, tgq[0].numpy(), **tol)
                    and np.allclose(gr, tgr[0].numpy(), **tol)):
                print("ours' qr grad results:")
                print(gq)
                print(gr)
                print("pytorch's qr grad result")
                print(tgq[0])
                print(tgr[0])
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's
    # command-line runner, which discovers the TestCase classes above.
    unittest.main()
|