add unfold,fold,bicubic,qr,eye (#180)

* fix bicubic,add fold.

* add eye.

* add test for qr,bicubic,fold,unfold.

* fix bicubic and fold to code_op ver.add grad test.

* add docs.update pinv to support (..,M,N) shape

* edit maintainer and testfunc's name.

* fix nn

* fix nn

Co-authored-by: Exusial <2247838039@qq.com>
Co-authored-by: Gword <471184555@qq.com>
This commit is contained in:
Exusial 2021-03-02 16:38:28 +08:00 committed by GitHub
parent 3c6b5c3b53
commit 0d7f0db1b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 479 additions and 122 deletions

View File

@ -3,8 +3,55 @@ jittor.linalg
这里是Jittor的线性代数函数的API文档您可以通过`from jittor import linalg`来获取该模块。
## 基本函数简介
#### 基本线性代数运算API
- linalg.inv(a)
对a进行求逆运算
- linalg.pinv(a)
对a进行广义求逆运算。该运算不要求原矩阵a可逆。
- linalg.slogdet(a)
对a求取slogdet。会返回值以及符号。
- linalg.det(a)
对a求行列式。
- linalg.solve(a,b)
求解线性方程Ax=b的解。
#### 分解API
- linalg.cholesky(a)
对a进行cholesky分解。
- linalg.qr(a)
对a进行qr分解。
- linalg.svd
对a进行奇异值分解。
#### 特征值API
- linalg.eig(a)
求取a的特征值以及特征向量。
- linalg.eigh(a)
针对埃尔米特矩阵或者对称矩阵求特征值以及特征向量。
目前的linalg库支持
```eval_rst
.. automodule:: jittor.linalg
:members:
:undoc-members:
```

View File

@ -11,9 +11,21 @@
import jittor as jt
from functools import partial
#TODO:full_matrices=1
def svd(x):
r'''
calculate the Singular Value Decomposition of x.It follows the below fomula:
x = usv*
only support full matrices == False ver now, which means:
x's shape (...,M,K)
u's shape (...,M,K)
s's shape (...,K)
v's shape (...,K,N)
where K is min(M,N).
:param x:
:return:u,s,v.
'''
def forward_code(np, data):
a = data["inputs"][0]
u, s, v = data["outputs"]
@ -81,8 +93,15 @@ def svd(x):
)
return u, s, v
def eigh(x):
def eigh(x):
r"""
calculate the eigenvalues and eigenvectors of x.
:param x (...,M,M):
:return:w, v.
w (...,M) : the eigenvalues.
v (...,M,M) : normalized eigenvectors.
"""
def forward_code(np, data):
a = data["inputs"][0]
w, v = data["outputs"]
@ -122,8 +141,13 @@ def eigh(x):
)
return w, v
def inv(x):
def inv(x):
r"""
calculate the inverse of x.
:param x (...,M,M):
:return:x^-1 (...,M,M).
"""
def forward_code(np, data):
a = data["inputs"][0]
m_a = data["outputs"][0]
@ -151,8 +175,13 @@ def inv(x):
mx = lmx[0]
return mx
def pinv(x):
def pinv(x):
r"""
calculate the pseudo-inverse of a x.
:param x (...,M,N)
:return: x's pinv (...N,M)
"""
def forward_code(np, data):
a = data["inputs"][0]
m_a = data["outputs"][0]
@ -174,9 +203,9 @@ def pinv(x):
+ _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx)
)
np.copyto(out, t)
sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
lmx = jt.numpy_code(
[x.shape],
[sw],
[x.dtype],
[x],
forward_code,
@ -185,8 +214,13 @@ def pinv(x):
mx = lmx[0]
return mx
def det(x):
def det(x):
r"""
calculate the determinant of x.
:param x (...,M,M):
:return:|x| (...,1)
"""
def forward_code(np, data):
a = data["inputs"][0]
L = data["outputs"][0]
@ -220,7 +254,15 @@ def det(x):
det = l_det[0]
return det
def slogdet(x):
r"""
calculate the sign and log of the determinant of x.
:param x (...,M,M):
:return sign, x's logdet.
sign array decides the sign of determinant and their values can be -1,0,1.Only Real number now.0 means det is 0 and logdet is -inf.
logdet in shape (...,1).
"""
def forward_code(np, data):
a = data["inputs"][0]
sign, m_a = data["outputs"]
@ -256,8 +298,15 @@ def slogdet(x):
)
return sign, mx
def cholesky(x):
def cholesky(x):
r"""
do Cholesky decomposition of x in the form of below formula:
x = LL^T
x must be a Hermite and positive-definite matrix. L is a lower-triangular matrix.
:param x (...,M,M):
:return: L (...,M,M).
"""
def forward_code(np, data):
a = data["inputs"][0]
L = data["outputs"][0]
@ -291,8 +340,14 @@ def cholesky(x):
L = lL[0]
return L
def solve(a,b):
def solve(a,b):
r"""
Solve a linear matrix equation Ax = B.This is done by calculating x = A^-1B.So A must not be singular.
:param a:(...,M,M)
:param b:(...,M)
:return:solution of Ax = b formula.x in the shape of (...M)
"""
def forward_code(np, data):
a, b = data["inputs"]
L = data["outputs"][0]
@ -323,4 +378,55 @@ def solve(a,b):
[backward_code1, backward_code2],
)
ans = l_ans[0]
return ans
return ans
def qr(x):
r"""
do the qr factorization of x in the below formula:
x = QR where Q is orthogonal matrix and R is upper-triangle matrix.
:param x (...,M,M):
:return:q,r as the result of qr factorization.They are both in the shape of (...,M,M).
"""
def forward_code(np, data):
a = data["inputs"][0]
q, r = data["outputs"]
Q, R = np.linalg.qr(a)
np.copyto(q,Q)
np.copyto(r,R)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
_harmard = partial(np.einsum, '...ij,...ij->...ij')
dout = data["dout"]
out = data["outputs"][0]
q, r = data["f_outputs"]
out_index = data["out_index"]
#pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags
if out_index == 0: # Q_TERM
q_t = _dot(T(q),dout)
rhs_solve = q_t - T(q_t)
rhs_solve = T(np.tril(rhs_solve,-1))
qsolve = np.linalg.solve(r,rhs_solve)
qsolve = T(qsolve)
tq = _dot(q,qsolve)
np.copyto(out,tq)
else: #R_TERM
r_t = _dot(r ,T(dout))
rhs_solve = r_t - T(r_t)
rhs_solve = np.tril(rhs_solve,-1)
rhs_solve = T(rhs_solve)
r_solve = np.linalg.solve(r,rhs_solve)
tr = _dot(q,(T(r_solve) + dout))
np.copyto(out,tr)
q, r = jt.numpy_code(
[x.shape,x.shape],
[x.dtype,x.dtype],
[x],
forward_code,
[backward_code],
)
return q, r

View File

@ -953,11 +953,11 @@ def hardtanh(x,min_val=-1,max_val=1):
class Softplus(Module):
r'''
SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
Args:
[in] beta (float): the beta value for the Softplus formulation. Default: 1.
[in] threshold (float): values above this revert to a linear function. Default: 20.
'''
def __init__(self, beta=1, threshold=20):
@ -976,59 +976,94 @@ class Resize(Module):
def execute(self, x):
return resize(x, self.size, self.mode, self.align_corners)
def _bicubic(x, a, func):
# normal ver
if func == 1:
return (a+2)*(jt.abs(x)**3)-(a+3)*(x**2)+1
if func == 2:
return a*(jt.abs(x)**3)-5*a*(x**2)+8*a*(jt.abs(x))-4*a
return 0
def _interpolate(img, x, y, ids, mode):
if mode=="nearest":
if mode == "nearest":
return img.reindex([*ids, x.floor(), y.floor()])
if mode=="bilinear":
if mode == "bilinear":
fx, fy = x.floor(), y.floor()
cx, cy = fx+1, fy+1
dx, dy = x-fx, y-fy
cx, cy = fx + 1, fy + 1
dx, dy = x - fx, y - fy
a = img.reindex_var([*ids, fx, fy])
b = img.reindex_var([*ids, cx, fy])
c = img.reindex_var([*ids, fx, cy])
d = img.reindex_var([*ids, cx, cy])
dnx, dny = 1-dx, 1-dy
ab = dx*b + dnx*a
cd = dx*d + dnx*c
o = ab*dny + cd*dy
dnx, dny = 1 - dx, 1 - dy
ab = dx * b + dnx * a
cd = dx * d + dnx * c
o = ab * dny + cd * dy
return o
raise(f"Not support interpolation mode: {mode}")
if mode=="bicubic": # ugly ver.
n,c,h,w = img.shape
fx, fy = x.floor(), y.floor()
dix, diy = x - fx, y - fy
ax, ay = _bicubic(dix+1,-0.75,2), _bicubic(diy+1,-0.75,2)
bx, by = _bicubic(dix,-0.75,1), _bicubic(diy,-0.75,1)
cx, cy = _bicubic(1-dix,-0.75,1), _bicubic(1-diy,-0.75,1)
dx, dy = _bicubic(2-dix,-0.75,2), _bicubic(2-diy,-0.75,2)
afx, afy = jt.maximum(jt.minimum(fx-1,h-1),0), jt.maximum(jt.minimum(fy-1,w-1),0)
bfx, bfy = jt.maximum(jt.minimum(fx,h-1),0), jt.maximum(jt.minimum(fy,w-1),0)
cfx, cfy = jt.maximum(jt.minimum(fx+1,h-1),0), jt.maximum(jt.minimum(fy+1,w-1),0)
dfx, dfy = jt.maximum(jt.minimum(fx+2,h-1),0), jt.maximum(jt.minimum(fy+2,w-1),0)
a = ax*(img.reindex_var([*ids,afx,afy])*ay+img.reindex_var([*ids,afx,bfy])*by+img.reindex_var([*ids,afx,cfy])*cy+img.reindex_var([*ids,afx,dfy])*dy)
b = bx*(img.reindex_var([*ids,bfx,afy])*ay+img.reindex_var([*ids,bfx,bfy])*by+img.reindex_var([*ids,bfx,cfy])*cy+img.reindex_var([*ids,bfx,dfy])*dy)
c = cx*(img.reindex_var([*ids,cfx,afy])*ay+img.reindex_var([*ids,cfx,bfy])*by+img.reindex_var([*ids,cfx,cfy])*cy+img.reindex_var([*ids,cfx,dfy])*dy)
d = dx*(img.reindex_var([*ids,dfx,afy])*ay+img.reindex_var([*ids,dfx,bfy])*by+img.reindex_var([*ids,dfx,cfy])*cy+img.reindex_var([*ids,dfx,dfy])*dy)
o = a + b + c + d
return o
raise (f"Not support interpolation mode: {mode}")
def resize(img, size, mode="nearest", align_corners=False):
n,c,h,w = img.shape
H,W = size
nid, cid, hid, wid = jt.index((n,c,H,W))
n, c, h, w = img.shape
H, W = size
nid, cid, hid, wid = jt.index((n, c, H, W))
if align_corners:
x = hid * ((h-1) / max(1, H-1))
y = wid * ((w-1) / max(1, W-1))
x = hid * ((h - 1) / max(1, H - 1))
y = wid * ((w - 1) / max(1, W - 1))
else:
x = hid * (h / H) + (h/H*0.5 - 0.5)
if H>h: x = x.clamp(0, h-1)
y = wid * (w / W) + (w/W*0.5 - 0.5)
if W>w: y = y.clamp(0, w-1)
return _interpolate(img, x, y, (nid,cid), mode)
x = hid * (h / H) + (h / H * 0.5 - 0.5)
if H > h: x = x.clamp(0, h - 1)
y = wid * (w / W) + (w / W * 0.5 - 0.5)
if W > w: y = y.clamp(0, w - 1)
return _interpolate(img, x, y, (nid, cid), mode)
def upsample(img, size, mode="nearest", align_corners=False):
n,c,h,w = img.shape
H,W = size
nid, cid, hid, wid = jt.index((n,c,H,W))
n, c, h, w = img.shape
H, W = size
nid, cid, hid, wid = jt.index((n, c, H, W))
if align_corners:
x = hid * ((h-1) / max(1, H-1))
y = wid * ((w-1) / max(1, W-1))
x = hid * ((h - 1) / max(1, H - 1))
y = wid * ((w - 1) / max(1, W - 1))
elif mode == "bicubic":
x = (hid + 0.5) * (h / H) - 0.5
y = (wid + 0.5) * (w / W) - 0.5
else:
x = hid * (h / H)
y = wid * (w / W)
return _interpolate(img, x, y, (nid,cid), mode)
return _interpolate(img, x, y, (nid, cid), mode)
def interpolate(X,size=None,scale_factor=None,mode='bilinear',align_corners=False):
def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False):
if scale_factor is not None:
size = [X.shape[-2]*scale_factor,X.shape[-1]*scale_factor]
if isinstance(size,int):
size = (size,size)
if scale_factor is not None and scale_factor>1:
return upsample(X,size,mode,align_corners)
size = [X.shape[-2] * scale_factor, X.shape[-1] * scale_factor]
if isinstance(size, int):
size = (size, size)
if scale_factor is not None and scale_factor > 1:
return upsample(X, size, mode, align_corners)
else:
return resize(X,size,mode,align_corners)
return resize(X, size, mode, align_corners)
def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
r'''
@ -1045,7 +1080,7 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
[in] mode (string): the interpolate way, default: bilinear.
[in] padding_mode (string): the padding way, default: zeros.
[out] output (var): the output var, whose shape is (N, C, Ho, Wo)
Example:
@ -1069,10 +1104,10 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
assert Ni == No
assert len(input.shape) == 4 and len(grid.shape)
nid, cid, hid, wid = jt.index((Ni,Ci,Ho,Wo))
x = ((grid[:,:,:,1].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Hi - 1)
y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1)
return _interpolate(input, x, y, (nid,cid), mode)
nid, cid, hid, wid = jt.index((Ni, Ci, Ho, Wo))
x = ((grid[:, :, :, 1].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Hi - 1)
y = ((grid[:, :, :, 0].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Wi - 1)
return _interpolate(input, x, y, (nid, cid), mode)
def linspace_from_neg_one(grid,num_steps,align_corners):
@ -1327,4 +1362,49 @@ class Sequential(Module):
def __len__(self):
return len(self.layers)
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
assert X.ndim == 4
if not isinstance(kernel_size, tuple):
kernel_size = (kernel_size, kernel_size)
if not isinstance(dilation, tuple):
dilation = (dilation, dilation)
if not isinstance(padding, tuple):
padding = (padding, padding)
if not isinstance(stride, tuple):
stride = (stride, stride)
n, c, h, w = X.shape
shape = X.shape
area = kernel_size[0] * kernel_size[1]
block_nums = []
for i in range(2, 4):
block_nums.append(
(shape[i] + 2 * padding[i - 2] - dilation[i - 2] * (kernel_size[i - 2] - 1) - 1) // stride[i - 2] + 1)
if padding[0] != 0 or padding[1] != 0:
X = X.reindex([n, c, h + padding[0] * 2, w + padding[1] * 2],
["i0", "i1", f"i2-{padding[0]}", f"i3-{padding[1]}"])
output = X.reindex([n, c * area, block_nums[0] * block_nums[1]], ["i0", f"i1/{area}",
f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",
f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"])
return output
def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):
assert X.ndim==3
if not isinstance(kernel_size,tuple):
kernel_size = (kernel_size,kernel_size)
if not isinstance(dilation,tuple):
dilation = (dilation,dilation)
if not isinstance(padding,tuple):
padding = (padding,padding)
if not isinstance(stride,tuple):
stride = (stride,stride)
n,cl,num = X.shape
area = kernel_size[0] * kernel_size[1]
block_nums = []
for i in range(2,4):
block_nums.append((output_size[i-2]+2*padding[i-2]-dilation[i-2]*(kernel_size[i-2]-1)-1) // stride[i-2]+1)
output = X.reindex_reduce("add",[n,cl // area,output_size[0]+2*padding[0],output_size[1]+2*padding[1]],["i0",f"i1/{area}",f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"])
return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]]
ModuleList = Sequential

View File

@ -0,0 +1,45 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import torch
from torch.nn import functional as F
import numpy as np
class TestBicubicInterpolate(unittest.TestCase):
# this is for testing bicubic interpolate
def test_bicubic(self):
for _ in range(20):
try:
tn = np.random.randn(1,1,5,5).astype('float32')
ja = jt.array(tn)
ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True)
# test upsample
ju = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic')
tu = F.interpolate(ta,scale_factor=2,mode='bicubic')
assert np.allclose(ju.data,tu.detach().numpy(),rtol=1e-03,atol=1e-06)
gju = jt.grad(ju,ja)
gtu = torch.autograd.grad(tu,ta,torch.ones_like(tu),retain_graph=True)[0]
assert np.allclose(gju.data,gtu.detach().numpy(),rtol=1e-03,atol=1e-06)
# test align
je = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic',align_corners=True)
te = F.interpolate(ta,scale_factor=2,mode='bicubic',align_corners=True)
assert np.allclose(je.data,te.detach().numpy(),rtol=1e-03,atol=1e-06)
gje = jt.grad(je,ja)
gte = torch.autograd.grad(te,ta,torch.ones_like(tu),retain_graph=True)[0]
assert np.allclose(gje.data,gte.detach().numpy(),rtol=1e-03,atol=1e-06)
except AssertionError:
print(ju,tu)
print(je,te)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,40 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import torch
from torch.nn import functional as F
import numpy as np
class TestFoldOp(unittest.TestCase):
def test_fold(self):
# test unfold first and the test fold.
for i in range(4,10):
tn = np.random.randn(1,3,i,i).astype('float32')
ja = jt.array(tn)
ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True)
juf = jt.nn.unfold(ja,kernel_size=2,stride=2,dilation=2,padding=2)
tuf = F.unfold(ta,kernel_size=2,stride=2,dilation=2,padding=2)
assert np.allclose(juf.data,tuf.detach().numpy())
gjuf = jt.grad(juf,ja)
gtuf = torch.autograd.grad(tuf,ta,torch.ones_like(tuf),retain_graph=True)[0]
assert np.allclose(gjuf.data,gtuf.detach().numpy())
# test fold
jf = jt.nn.fold(juf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2)
tf = F.fold(tuf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2)
assert np.allclose(jf.data,tf.detach().numpy())
gjf = jt.grad(jf,juf)
gtf = torch.autograd.grad(tf,tuf,torch.ones_like(tf),retain_graph=True)[0]
assert np.allclose(gjf.data,gtf.detach().numpy())
if __name__ == "__main__":
unittest.main()

View File

@ -8,49 +8,52 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import torch
from torch.autograd import Variable
import jittor as jt
import numpy as np
import unittest
try:
import autograd.numpy as anp
from autograd import jacobian
has_autograd = True
except:
has_autograd = False
@unittest.skipIf(not has_autograd, "No autograd found.")
class TestCodeOp(unittest.TestCase):
class TestLinalgOp(unittest.TestCase):
def test_svd(self):
def check_svd(a):
u,s,v = anp.linalg.svd(a, full_matrices=0)
return u,s,v
u, s, v = anp.linalg.svd(a, full_matrices=0)
return u, s, v
def check_u(a):
u,s,v = anp.linalg.svd(a, full_matrices=0)
u, s, v = anp.linalg.svd(a, full_matrices=0)
return u
def check_s(a):
u,s,v = anp.linalg.svd(a, full_matrices=0)
u, s, v = anp.linalg.svd(a, full_matrices=0)
return s
def check_v(a):
u,s,v = anp.linalg.svd(a, full_matrices=0)
u, s, v = anp.linalg.svd(a, full_matrices=0)
return v
for i in range(50):
#not for full-matrices!
a = jt.random((2,2,5,4))
# not for full-matrices!
a = jt.random((2, 2, 5, 4))
c_a = anp.array(a.data)
u,s,v = jt.linalg.svd(a)
tu,ts,tv = check_svd(c_a)
assert np.allclose(tu,u.data)
assert np.allclose(ts,s.data)
assert np.allclose(tv,v.data)
ju = jt.grad(u,a)
js = jt.grad(s,a)
jv = jt.grad(v,a)
u, s, v = jt.linalg.svd(a)
tu, ts, tv = check_svd(c_a)
assert np.allclose(tu, u.data)
assert np.allclose(ts, s.data)
assert np.allclose(tv, v.data)
ju = jt.grad(u, a)
js = jt.grad(s, a)
jv = jt.grad(v, a)
grad_u = jacobian(check_u)
gu = grad_u(c_a)
gu = np.sum(gu, 4)
@ -69,56 +72,56 @@ class TestCodeOp(unittest.TestCase):
gv = np.sum(gv, 2)
gv = np.sum(gv, 2)
try:
assert np.allclose(ju.data,gu,atol=1e-5)
assert np.allclose(ju.data, gu, atol=1e-5)
except AssertionError:
print(ju.data)
print(gu)
try:
assert np.allclose(js.data,gs,atol=1e-5)
assert np.allclose(js.data, gs, atol=1e-5)
except AssertionError:
print(js.data)
print(gs)
try:
assert np.allclose(jv.data,gv,atol=1e-5)
assert np.allclose(jv.data, gv, atol=1e-5)
except AssertionError:
print(jv.data)
print(gv)
def test_eigh(self):
def check_eigh(a,UPLO='L'):
w, v = anp.linalg.eigh(a,UPLO)
def check_eigh(a, UPLO='L'):
w, v = anp.linalg.eigh(a, UPLO)
return w, v
def check_w(a,UPLO='L'):
w, v = anp.linalg.eigh(a,UPLO)
def check_w(a, UPLO='L'):
w, v = anp.linalg.eigh(a, UPLO)
return w
def check_v(a,UPLO='L'):
w, v = anp.linalg.eigh(a,UPLO)
def check_v(a, UPLO='L'):
w, v = anp.linalg.eigh(a, UPLO)
return v
for i in range(50):
a = jt.random((2,2,3,3))
a = jt.random((2, 2, 3, 3))
c_a = a.data
w, v = jt.linalg.eigh(a)
tw, tv = check_eigh(c_a)
assert np.allclose(w.data,tw)
assert np.allclose(v.data,tv)
assert np.allclose(w.data, tw)
assert np.allclose(v.data, tv)
jw = jt.grad(w, a)
jv = jt.grad(v, a)
check_gw = jacobian(check_w)
check_gv = jacobian(check_v)
gw = check_gw(c_a)
gw = np.sum(gw,4)
gw = np.sum(gw,2)
gw = np.sum(gw,2)
assert np.allclose(gw,jw.data,rtol = 1,atol = 5e-8)
gw = np.sum(gw, 4)
gw = np.sum(gw, 2)
gw = np.sum(gw, 2)
assert np.allclose(gw, jw.data, rtol=1, atol=5e-8)
gv = check_gv(c_a)
gv = np.sum(gv,4)
gv = np.sum(gv,4)
gv = np.sum(gv,2)
gv = np.sum(gv,2)
assert np.allclose(gv,jv.data,rtol = 1,atol = 5e-8)
gv = np.sum(gv, 4)
gv = np.sum(gv, 4)
gv = np.sum(gv, 2)
gv = np.sum(gv, 2)
assert np.allclose(gv, jv.data, rtol=1, atol=5e-8)
def test_pinv(self):
def check_pinv(a):
@ -126,34 +129,35 @@ class TestCodeOp(unittest.TestCase):
return w
for i in range(50):
x = jt.random((2,2,4,4))
x = jt.random((2, 2, 4, 3))
c_a = x.data
mx = jt.linalg.pinv(x)
tx = check_pinv(c_a)
np.allclose(mx.data,tx)
jx = jt.grad(mx,x)
np.allclose(mx.data, tx)
jx = jt.grad(mx, x)
check_grad = jacobian(check_pinv)
gx = check_grad(c_a)
np.allclose(gx,jx.data)
np.allclose(gx, jx.data)
def test_inv(self):
def check_inv(a):
w = anp.linalg.inv(a)
return w
for i in range(50):
tn = np.random.randn(4,4).astype('float32')*5
while np.allclose(np.linalg.det(tn),0):
tn = np.random.randn((4,4)).astype('float32')*5
tn = np.random.randn(4, 4).astype('float32') * 5
while np.allclose(np.linalg.det(tn), 0):
tn = np.random.randn((4, 4)).astype('float32') * 5
x = jt.array(tn)
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
c_a = x.data
mx = jt.linalg.inv(x)
tx = check_inv(c_a)
np.allclose(mx.data,tx)
jx = jt.grad(mx,x)
np.allclose(mx.data, tx)
jx = jt.grad(mx, x)
check_grad = jacobian(check_inv)
gx = check_grad(c_a)
np.allclose(gx,jx.data)
np.allclose(gx, jx.data)
def test_slogdet(self):
def check_ans(a):
@ -165,11 +169,11 @@ class TestCodeOp(unittest.TestCase):
return w
for i in range(50):
tn = np.random.randn(4,4).astype('float32')*10
while np.allclose(np.linalg.det(tn),0):
tn = np.random.randn((4,4)).astype('float32')*10
tn = np.random.randn(4, 4).astype('float32') * 10
while np.allclose(np.linalg.det(tn), 0):
tn = np.random.randn((4, 4)).astype('float32') * 10
x = jt.array(tn)
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
s = list(x.shape)
det_s = s[:-2]
if len(det_s) == 0:
@ -178,12 +182,12 @@ class TestCodeOp(unittest.TestCase):
ts, ta = check_ans(x.data)
assert np.allclose(sign.data, ts)
assert np.allclose(mx.data, ta)
jx = jt.grad(mx,x)
jx = jt.grad(mx, x)
check_sgrad = jacobian(check_slogdet)
gx = check_sgrad(x.data)
gx = np.sum(gx,2)
gx = np.sum(gx,2)
assert np.allclose(gx,jx.data)
gx = np.sum(gx, 2)
gx = np.sum(gx, 2)
assert np.allclose(gx, jx.data)
def test_cholesky(self):
def check_cholesky(a):
@ -192,39 +196,39 @@ class TestCodeOp(unittest.TestCase):
for i in range(50):
x = jt.array(np.diag((np.random.rand(3) + 1) * 2))
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
tx = x.data
L = jt.linalg.cholesky(x)
tL = check_cholesky(tx)
assert np.allclose(tL,L.data)
jx = jt.grad(L,x)
assert np.allclose(tL, L.data)
jx = jt.grad(L, x)
check_grad = jacobian(check_cholesky)
gx = check_grad(tx)
gx = np.sum(gx, 0)
gx = np.sum(gx, 0)
gx = np.sum(gx, 0)
gx = np.sum(gx, 0)
assert np.allclose(jx.data,gx)
assert np.allclose(jx.data, gx)
def test_solve(self):
def check_solve(a,b):
ans = anp.linalg.solve(a,b)
def check_solve(a, b):
ans = anp.linalg.solve(a, b)
return ans
for i in range(50):
a = jt.random((2,2,3,3))
b = jt.random((2,2,3))
ans = jt.linalg.solve(a,b)
ta = check_solve(a.data,b.data)
a = jt.random((2, 2, 3, 3))
b = jt.random((2, 2, 3))
ans = jt.linalg.solve(a, b)
ta = check_solve(a.data, b.data)
assert np.allclose(ans.data, ta)
jx = jt.grad(ans, a)
check_sgrad = jacobian(check_solve)
gx = check_sgrad(a.data,b.data)
gx = np.sum(gx,0)
gx = np.sum(gx,0)
gx = np.sum(gx,0)
gx = check_sgrad(a.data, b.data)
gx = np.sum(gx, 0)
gx = np.sum(gx, 0)
gx = np.sum(gx, 0)
try:
assert np.allclose(gx, jx.data,rtol=1)
assert np.allclose(gx, jx.data, rtol=1)
except AssertionError:
print(gx)
print(jx.data)
@ -254,7 +258,42 @@ class TestCodeOp(unittest.TestCase):
gx = np.sum(gx, 2)
assert np.allclose(gx, jx.data)
def test_qr(self):
for i in range(50):
tn = np.random.randn(3, 3).astype('float32')
while np.allclose(np.linalg.det(tn), 0):
tn = np.random.randn((3, 3)).astype('float32')
x = jt.array(tn)
# x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
t_x = torch.from_numpy(tn)
t_x = Variable(t_x, requires_grad=True)
jq, jr = jt.linalg.qr(x)
tq, tr = torch.qr(t_x)
try:
assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
assert np.allclose(jr.data, tr.detach().numpy(), rtol=1e-4, atol=1e-6)
except AssertionError:
print("ours' qr results:")
print(jq)
print(jr)
print("pytorch's qr results:")
print(tq)
print(tr)
gq = jt.grad(jq, x).data
gr = jt.grad(jr, x).data
tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True)
tgr = torch.autograd.grad(tr, t_x, torch.ones_like(tr), retain_graph=True)
try:
assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
except AssertionError:
print("ours' qr grad results:")
print(gq)
print(gr)
print("pytorch's qr grad result")
print(tgq[0])
print(tgr[0])
if __name__ == "__main__":
unittest.main()