mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor
This commit is contained in:
commit
6b9bfef1da
|
@ -3,8 +3,55 @@ jittor.linalg
|
|||
|
||||
这里是Jittor的线性代数函数的API文档,您可以通过`from jittor import linalg`来获取该模块。
|
||||
|
||||
## 基本函数简介
|
||||
#### 基本线性代数运算API
|
||||
- linalg.inv(a)
|
||||
|
||||
对a进行求逆运算
|
||||
|
||||
- linalg.pinv(a)
|
||||
|
||||
对a进行广义求逆运算。该运算不要求原矩阵a可逆。
|
||||
|
||||
- linalg.slogdet(a)
|
||||
|
||||
对a求取slogdet。会返回值以及符号。
|
||||
|
||||
- linalg.det(a)
|
||||
|
||||
对a求行列式。
|
||||
|
||||
- linalg.solve(a,b)
|
||||
|
||||
求解线性方程Ax=b的解。
|
||||
|
||||
#### 分解API
|
||||
- linalg.cholesky(a)
|
||||
|
||||
对a进行cholesky分解。
|
||||
|
||||
- linalg.qr(a)
|
||||
|
||||
对a进行qr分解。
|
||||
|
||||
- linalg.svd
|
||||
|
||||
对a进行奇异值分解。
|
||||
#### 特征值API
|
||||
- linalg.eig(a)
|
||||
|
||||
求取a的特征值以及特征向量。
|
||||
|
||||
- linalg.eigh(a)
|
||||
|
||||
针对埃尔米特矩阵或者对称矩阵求特征值以及特征向量。
|
||||
|
||||
|
||||
目前的linalg库支持
|
||||
|
||||
```eval_rst
|
||||
.. automodule:: jittor.linalg
|
||||
:members:
|
||||
:undoc-members:
|
||||
```
|
||||
|
||||
|
|
|
@ -11,9 +11,21 @@
|
|||
import jittor as jt
|
||||
from functools import partial
|
||||
|
||||
|
||||
#TODO:full_matrices=1
|
||||
def svd(x):
|
||||
|
||||
r'''
|
||||
calculate the Singular Value Decomposition of x.It follows the below fomula:
|
||||
x = usv*
|
||||
only support full matrices == False ver now, which means:
|
||||
x's shape (...,M,K)
|
||||
u's shape (...,M,K)
|
||||
s's shape (...,K)
|
||||
v's shape (...,K,N)
|
||||
where K is min(M,N).
|
||||
:param x:
|
||||
:return:u,s,v.
|
||||
'''
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
u, s, v = data["outputs"]
|
||||
|
@ -81,8 +93,15 @@ def svd(x):
|
|||
)
|
||||
return u, s, v
|
||||
|
||||
def eigh(x):
|
||||
|
||||
def eigh(x):
|
||||
r"""
|
||||
calculate the eigenvalues and eigenvectors of x.
|
||||
:param x (...,M,M):
|
||||
:return:w, v.
|
||||
w (...,M) : the eigenvalues.
|
||||
v (...,M,M) : normalized eigenvectors.
|
||||
"""
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
w, v = data["outputs"]
|
||||
|
@ -122,8 +141,13 @@ def eigh(x):
|
|||
)
|
||||
return w, v
|
||||
|
||||
def inv(x):
|
||||
|
||||
def inv(x):
|
||||
r"""
|
||||
calculate the inverse of x.
|
||||
:param x (...,M,M):
|
||||
:return:x^-1 (...,M,M).
|
||||
"""
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
m_a = data["outputs"][0]
|
||||
|
@ -151,8 +175,13 @@ def inv(x):
|
|||
mx = lmx[0]
|
||||
return mx
|
||||
|
||||
def pinv(x):
|
||||
|
||||
def pinv(x):
|
||||
r"""
|
||||
calculate the pseudo-inverse of a x.
|
||||
:param x (...,M,N)
|
||||
:return: x's pinv (...N,M)
|
||||
"""
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
m_a = data["outputs"][0]
|
||||
|
@ -174,9 +203,9 @@ def pinv(x):
|
|||
+ _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx)
|
||||
)
|
||||
np.copyto(out, t)
|
||||
|
||||
sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
|
||||
lmx = jt.numpy_code(
|
||||
[x.shape],
|
||||
[sw],
|
||||
[x.dtype],
|
||||
[x],
|
||||
forward_code,
|
||||
|
@ -185,8 +214,13 @@ def pinv(x):
|
|||
mx = lmx[0]
|
||||
return mx
|
||||
|
||||
def det(x):
|
||||
|
||||
def det(x):
|
||||
r"""
|
||||
calculate the determinant of x.
|
||||
:param x (...,M,M):
|
||||
:return:|x| (...,1)
|
||||
"""
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
L = data["outputs"][0]
|
||||
|
@ -220,7 +254,15 @@ def det(x):
|
|||
det = l_det[0]
|
||||
return det
|
||||
|
||||
|
||||
def slogdet(x):
|
||||
r"""
|
||||
calculate the sign and log of the determinant of x.
|
||||
:param x (...,M,M):
|
||||
:return sign, x's logdet.
|
||||
sign array decides the sign of determinant and their values can be -1,0,1.Only Real number now.0 means det is 0 and logdet is -inf.
|
||||
logdet in shape (...,1).
|
||||
"""
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
sign, m_a = data["outputs"]
|
||||
|
@ -256,8 +298,15 @@ def slogdet(x):
|
|||
)
|
||||
return sign, mx
|
||||
|
||||
def cholesky(x):
|
||||
|
||||
def cholesky(x):
|
||||
r"""
|
||||
do Cholesky decomposition of x in the form of below formula:
|
||||
x = LL^T
|
||||
x must be a Hermite and positive-definite matrix. L is a lower-triangular matrix.
|
||||
:param x (...,M,M):
|
||||
:return: L (...,M,M).
|
||||
"""
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
L = data["outputs"][0]
|
||||
|
@ -291,8 +340,14 @@ def cholesky(x):
|
|||
L = lL[0]
|
||||
return L
|
||||
|
||||
def solve(a,b):
|
||||
|
||||
def solve(a,b):
|
||||
r"""
|
||||
Solve a linear matrix equation Ax = B.This is done by calculating x = A^-1B.So A must not be singular.
|
||||
:param a:(...,M,M)
|
||||
:param b:(...,M)
|
||||
:return:solution of Ax = b formula.x in the shape of (...M)
|
||||
"""
|
||||
def forward_code(np, data):
|
||||
a, b = data["inputs"]
|
||||
L = data["outputs"][0]
|
||||
|
@ -323,4 +378,55 @@ def solve(a,b):
|
|||
[backward_code1, backward_code2],
|
||||
)
|
||||
ans = l_ans[0]
|
||||
return ans
|
||||
return ans
|
||||
|
||||
|
||||
def qr(x):
|
||||
r"""
|
||||
do the qr factorization of x in the below formula:
|
||||
x = QR where Q is orthogonal matrix and R is upper-triangle matrix.
|
||||
:param x (...,M,M):
|
||||
:return:q,r as the result of qr factorization.They are both in the shape of (...,M,M).
|
||||
"""
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
q, r = data["outputs"]
|
||||
Q, R = np.linalg.qr(a)
|
||||
np.copyto(q,Q)
|
||||
np.copyto(r,R)
|
||||
|
||||
def backward_code(np, data):
|
||||
def T(x):
|
||||
return np.swapaxes(x, -1, -2)
|
||||
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||
_harmard = partial(np.einsum, '...ij,...ij->...ij')
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
q, r = data["f_outputs"]
|
||||
out_index = data["out_index"]
|
||||
#pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags
|
||||
if out_index == 0: # Q_TERM
|
||||
q_t = _dot(T(q),dout)
|
||||
rhs_solve = q_t - T(q_t)
|
||||
rhs_solve = T(np.tril(rhs_solve,-1))
|
||||
qsolve = np.linalg.solve(r,rhs_solve)
|
||||
qsolve = T(qsolve)
|
||||
tq = _dot(q,qsolve)
|
||||
np.copyto(out,tq)
|
||||
else: #R_TERM
|
||||
r_t = _dot(r ,T(dout))
|
||||
rhs_solve = r_t - T(r_t)
|
||||
rhs_solve = np.tril(rhs_solve,-1)
|
||||
rhs_solve = T(rhs_solve)
|
||||
r_solve = np.linalg.solve(r,rhs_solve)
|
||||
tr = _dot(q,(T(r_solve) + dout))
|
||||
np.copyto(out,tr)
|
||||
|
||||
q, r = jt.numpy_code(
|
||||
[x.shape,x.shape],
|
||||
[x.dtype,x.dtype],
|
||||
[x],
|
||||
forward_code,
|
||||
[backward_code],
|
||||
)
|
||||
return q, r
|
||||
|
|
|
@ -953,11 +953,11 @@ def hardtanh(x,min_val=-1,max_val=1):
|
|||
class Softplus(Module):
|
||||
r'''
|
||||
SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
|
||||
|
||||
|
||||
Args:
|
||||
|
||||
[in] beta (float): the beta value for the Softplus formulation. Default: 1.
|
||||
|
||||
|
||||
[in] threshold (float): values above this revert to a linear function. Default: 20.
|
||||
'''
|
||||
def __init__(self, beta=1, threshold=20):
|
||||
|
@ -976,59 +976,94 @@ class Resize(Module):
|
|||
def execute(self, x):
|
||||
return resize(x, self.size, self.mode, self.align_corners)
|
||||
|
||||
|
||||
def _bicubic(x, a, func):
|
||||
# normal ver
|
||||
if func == 1:
|
||||
return (a+2)*(jt.abs(x)**3)-(a+3)*(x**2)+1
|
||||
if func == 2:
|
||||
return a*(jt.abs(x)**3)-5*a*(x**2)+8*a*(jt.abs(x))-4*a
|
||||
return 0
|
||||
|
||||
|
||||
def _interpolate(img, x, y, ids, mode):
|
||||
if mode=="nearest":
|
||||
if mode == "nearest":
|
||||
return img.reindex([*ids, x.floor(), y.floor()])
|
||||
if mode=="bilinear":
|
||||
if mode == "bilinear":
|
||||
fx, fy = x.floor(), y.floor()
|
||||
cx, cy = fx+1, fy+1
|
||||
dx, dy = x-fx, y-fy
|
||||
cx, cy = fx + 1, fy + 1
|
||||
dx, dy = x - fx, y - fy
|
||||
a = img.reindex_var([*ids, fx, fy])
|
||||
b = img.reindex_var([*ids, cx, fy])
|
||||
c = img.reindex_var([*ids, fx, cy])
|
||||
d = img.reindex_var([*ids, cx, cy])
|
||||
dnx, dny = 1-dx, 1-dy
|
||||
ab = dx*b + dnx*a
|
||||
cd = dx*d + dnx*c
|
||||
o = ab*dny + cd*dy
|
||||
dnx, dny = 1 - dx, 1 - dy
|
||||
ab = dx * b + dnx * a
|
||||
cd = dx * d + dnx * c
|
||||
o = ab * dny + cd * dy
|
||||
return o
|
||||
raise(f"Not support interpolation mode: {mode}")
|
||||
if mode=="bicubic": # ugly ver.
|
||||
n,c,h,w = img.shape
|
||||
fx, fy = x.floor(), y.floor()
|
||||
dix, diy = x - fx, y - fy
|
||||
ax, ay = _bicubic(dix+1,-0.75,2), _bicubic(diy+1,-0.75,2)
|
||||
bx, by = _bicubic(dix,-0.75,1), _bicubic(diy,-0.75,1)
|
||||
cx, cy = _bicubic(1-dix,-0.75,1), _bicubic(1-diy,-0.75,1)
|
||||
dx, dy = _bicubic(2-dix,-0.75,2), _bicubic(2-diy,-0.75,2)
|
||||
afx, afy = jt.maximum(jt.minimum(fx-1,h-1),0), jt.maximum(jt.minimum(fy-1,w-1),0)
|
||||
bfx, bfy = jt.maximum(jt.minimum(fx,h-1),0), jt.maximum(jt.minimum(fy,w-1),0)
|
||||
cfx, cfy = jt.maximum(jt.minimum(fx+1,h-1),0), jt.maximum(jt.minimum(fy+1,w-1),0)
|
||||
dfx, dfy = jt.maximum(jt.minimum(fx+2,h-1),0), jt.maximum(jt.minimum(fy+2,w-1),0)
|
||||
a = ax*(img.reindex_var([*ids,afx,afy])*ay+img.reindex_var([*ids,afx,bfy])*by+img.reindex_var([*ids,afx,cfy])*cy+img.reindex_var([*ids,afx,dfy])*dy)
|
||||
b = bx*(img.reindex_var([*ids,bfx,afy])*ay+img.reindex_var([*ids,bfx,bfy])*by+img.reindex_var([*ids,bfx,cfy])*cy+img.reindex_var([*ids,bfx,dfy])*dy)
|
||||
c = cx*(img.reindex_var([*ids,cfx,afy])*ay+img.reindex_var([*ids,cfx,bfy])*by+img.reindex_var([*ids,cfx,cfy])*cy+img.reindex_var([*ids,cfx,dfy])*dy)
|
||||
d = dx*(img.reindex_var([*ids,dfx,afy])*ay+img.reindex_var([*ids,dfx,bfy])*by+img.reindex_var([*ids,dfx,cfy])*cy+img.reindex_var([*ids,dfx,dfy])*dy)
|
||||
o = a + b + c + d
|
||||
return o
|
||||
raise (f"Not support interpolation mode: {mode}")
|
||||
|
||||
|
||||
def resize(img, size, mode="nearest", align_corners=False):
|
||||
n,c,h,w = img.shape
|
||||
H,W = size
|
||||
nid, cid, hid, wid = jt.index((n,c,H,W))
|
||||
n, c, h, w = img.shape
|
||||
H, W = size
|
||||
nid, cid, hid, wid = jt.index((n, c, H, W))
|
||||
if align_corners:
|
||||
x = hid * ((h-1) / max(1, H-1))
|
||||
y = wid * ((w-1) / max(1, W-1))
|
||||
x = hid * ((h - 1) / max(1, H - 1))
|
||||
y = wid * ((w - 1) / max(1, W - 1))
|
||||
else:
|
||||
x = hid * (h / H) + (h/H*0.5 - 0.5)
|
||||
if H>h: x = x.clamp(0, h-1)
|
||||
y = wid * (w / W) + (w/W*0.5 - 0.5)
|
||||
if W>w: y = y.clamp(0, w-1)
|
||||
return _interpolate(img, x, y, (nid,cid), mode)
|
||||
x = hid * (h / H) + (h / H * 0.5 - 0.5)
|
||||
if H > h: x = x.clamp(0, h - 1)
|
||||
y = wid * (w / W) + (w / W * 0.5 - 0.5)
|
||||
if W > w: y = y.clamp(0, w - 1)
|
||||
return _interpolate(img, x, y, (nid, cid), mode)
|
||||
|
||||
|
||||
def upsample(img, size, mode="nearest", align_corners=False):
|
||||
n,c,h,w = img.shape
|
||||
H,W = size
|
||||
nid, cid, hid, wid = jt.index((n,c,H,W))
|
||||
n, c, h, w = img.shape
|
||||
H, W = size
|
||||
nid, cid, hid, wid = jt.index((n, c, H, W))
|
||||
if align_corners:
|
||||
x = hid * ((h-1) / max(1, H-1))
|
||||
y = wid * ((w-1) / max(1, W-1))
|
||||
x = hid * ((h - 1) / max(1, H - 1))
|
||||
y = wid * ((w - 1) / max(1, W - 1))
|
||||
elif mode == "bicubic":
|
||||
x = (hid + 0.5) * (h / H) - 0.5
|
||||
y = (wid + 0.5) * (w / W) - 0.5
|
||||
else:
|
||||
x = hid * (h / H)
|
||||
y = wid * (w / W)
|
||||
return _interpolate(img, x, y, (nid,cid), mode)
|
||||
return _interpolate(img, x, y, (nid, cid), mode)
|
||||
|
||||
def interpolate(X,size=None,scale_factor=None,mode='bilinear',align_corners=False):
|
||||
|
||||
def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False):
|
||||
if scale_factor is not None:
|
||||
size = [X.shape[-2]*scale_factor,X.shape[-1]*scale_factor]
|
||||
if isinstance(size,int):
|
||||
size = (size,size)
|
||||
if scale_factor is not None and scale_factor>1:
|
||||
return upsample(X,size,mode,align_corners)
|
||||
size = [X.shape[-2] * scale_factor, X.shape[-1] * scale_factor]
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
if scale_factor is not None and scale_factor > 1:
|
||||
return upsample(X, size, mode, align_corners)
|
||||
else:
|
||||
return resize(X,size,mode,align_corners)
|
||||
return resize(X, size, mode, align_corners)
|
||||
|
||||
|
||||
def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
|
||||
r'''
|
||||
|
@ -1045,7 +1080,7 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
|
|||
[in] mode (string): the interpolate way, default: bilinear.
|
||||
|
||||
[in] padding_mode (string): the padding way, default: zeros.
|
||||
|
||||
|
||||
[out] output (var): the output var, whose shape is (N, C, Ho, Wo)
|
||||
|
||||
Example:
|
||||
|
@ -1069,10 +1104,10 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
|
|||
assert Ni == No
|
||||
assert len(input.shape) == 4 and len(grid.shape)
|
||||
|
||||
nid, cid, hid, wid = jt.index((Ni,Ci,Ho,Wo))
|
||||
x = ((grid[:,:,:,1].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Hi - 1)
|
||||
y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1)
|
||||
return _interpolate(input, x, y, (nid,cid), mode)
|
||||
nid, cid, hid, wid = jt.index((Ni, Ci, Ho, Wo))
|
||||
x = ((grid[:, :, :, 1].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Hi - 1)
|
||||
y = ((grid[:, :, :, 0].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Wi - 1)
|
||||
return _interpolate(input, x, y, (nid, cid), mode)
|
||||
|
||||
|
||||
def linspace_from_neg_one(grid,num_steps,align_corners):
|
||||
|
@ -1327,4 +1362,49 @@ class Sequential(Module):
|
|||
def __len__(self):
|
||||
return len(self.layers)
|
||||
|
||||
|
||||
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
|
||||
assert X.ndim == 4
|
||||
if not isinstance(kernel_size, tuple):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
if not isinstance(dilation, tuple):
|
||||
dilation = (dilation, dilation)
|
||||
if not isinstance(padding, tuple):
|
||||
padding = (padding, padding)
|
||||
if not isinstance(stride, tuple):
|
||||
stride = (stride, stride)
|
||||
n, c, h, w = X.shape
|
||||
shape = X.shape
|
||||
area = kernel_size[0] * kernel_size[1]
|
||||
block_nums = []
|
||||
for i in range(2, 4):
|
||||
block_nums.append(
|
||||
(shape[i] + 2 * padding[i - 2] - dilation[i - 2] * (kernel_size[i - 2] - 1) - 1) // stride[i - 2] + 1)
|
||||
if padding[0] != 0 or padding[1] != 0:
|
||||
X = X.reindex([n, c, h + padding[0] * 2, w + padding[1] * 2],
|
||||
["i0", "i1", f"i2-{padding[0]}", f"i3-{padding[1]}"])
|
||||
output = X.reindex([n, c * area, block_nums[0] * block_nums[1]], ["i0", f"i1/{area}",
|
||||
f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",
|
||||
f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"])
|
||||
return output
|
||||
|
||||
|
||||
def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):
|
||||
assert X.ndim==3
|
||||
if not isinstance(kernel_size,tuple):
|
||||
kernel_size = (kernel_size,kernel_size)
|
||||
if not isinstance(dilation,tuple):
|
||||
dilation = (dilation,dilation)
|
||||
if not isinstance(padding,tuple):
|
||||
padding = (padding,padding)
|
||||
if not isinstance(stride,tuple):
|
||||
stride = (stride,stride)
|
||||
n,cl,num = X.shape
|
||||
area = kernel_size[0] * kernel_size[1]
|
||||
block_nums = []
|
||||
for i in range(2,4):
|
||||
block_nums.append((output_size[i-2]+2*padding[i-2]-dilation[i-2]*(kernel_size[i-2]-1)-1) // stride[i-2]+1)
|
||||
output = X.reindex_reduce("add",[n,cl // area,output_size[0]+2*padding[0],output_size[1]+2*padding[1]],["i0",f"i1/{area}",f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"])
|
||||
return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]]
|
||||
|
||||
ModuleList = Sequential
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Haoyang Peng <2247838039@qq.com>
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestBicubicInterpolate(unittest.TestCase):
|
||||
# this is for testing bicubic interpolate
|
||||
def test_bicubic(self):
|
||||
for _ in range(20):
|
||||
try:
|
||||
tn = np.random.randn(1,1,5,5).astype('float32')
|
||||
ja = jt.array(tn)
|
||||
ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True)
|
||||
# test upsample
|
||||
ju = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic')
|
||||
tu = F.interpolate(ta,scale_factor=2,mode='bicubic')
|
||||
assert np.allclose(ju.data,tu.detach().numpy(),rtol=1e-03,atol=1e-06)
|
||||
gju = jt.grad(ju,ja)
|
||||
gtu = torch.autograd.grad(tu,ta,torch.ones_like(tu),retain_graph=True)[0]
|
||||
assert np.allclose(gju.data,gtu.detach().numpy(),rtol=1e-03,atol=1e-06)
|
||||
# test align
|
||||
je = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic',align_corners=True)
|
||||
te = F.interpolate(ta,scale_factor=2,mode='bicubic',align_corners=True)
|
||||
assert np.allclose(je.data,te.detach().numpy(),rtol=1e-03,atol=1e-06)
|
||||
gje = jt.grad(je,ja)
|
||||
gte = torch.autograd.grad(te,ta,torch.ones_like(tu),retain_graph=True)[0]
|
||||
assert np.allclose(gje.data,gte.detach().numpy(),rtol=1e-03,atol=1e-06)
|
||||
except AssertionError:
|
||||
print(ju,tu)
|
||||
print(je,te)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,40 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Haoyang Peng <2247838039@qq.com>
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestFoldOp(unittest.TestCase):
|
||||
def test_fold(self):
|
||||
# test unfold first and the test fold.
|
||||
for i in range(4,10):
|
||||
tn = np.random.randn(1,3,i,i).astype('float32')
|
||||
ja = jt.array(tn)
|
||||
ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True)
|
||||
juf = jt.nn.unfold(ja,kernel_size=2,stride=2,dilation=2,padding=2)
|
||||
tuf = F.unfold(ta,kernel_size=2,stride=2,dilation=2,padding=2)
|
||||
assert np.allclose(juf.data,tuf.detach().numpy())
|
||||
gjuf = jt.grad(juf,ja)
|
||||
gtuf = torch.autograd.grad(tuf,ta,torch.ones_like(tuf),retain_graph=True)[0]
|
||||
assert np.allclose(gjuf.data,gtuf.detach().numpy())
|
||||
# test fold
|
||||
jf = jt.nn.fold(juf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2)
|
||||
tf = F.fold(tuf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2)
|
||||
assert np.allclose(jf.data,tf.detach().numpy())
|
||||
gjf = jt.grad(jf,juf)
|
||||
gtf = torch.autograd.grad(tf,tuf,torch.ones_like(tf),retain_graph=True)[0]
|
||||
assert np.allclose(gjf.data,gtf.detach().numpy())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -8,49 +8,52 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
||||
try:
|
||||
import autograd.numpy as anp
|
||||
from autograd import jacobian
|
||||
|
||||
has_autograd = True
|
||||
except:
|
||||
has_autograd = False
|
||||
|
||||
|
||||
@unittest.skipIf(not has_autograd, "No autograd found.")
|
||||
class TestCodeOp(unittest.TestCase):
|
||||
class TestLinalgOp(unittest.TestCase):
|
||||
def test_svd(self):
|
||||
def check_svd(a):
|
||||
u,s,v = anp.linalg.svd(a, full_matrices=0)
|
||||
return u,s,v
|
||||
u, s, v = anp.linalg.svd(a, full_matrices=0)
|
||||
return u, s, v
|
||||
|
||||
def check_u(a):
|
||||
u,s,v = anp.linalg.svd(a, full_matrices=0)
|
||||
u, s, v = anp.linalg.svd(a, full_matrices=0)
|
||||
return u
|
||||
|
||||
def check_s(a):
|
||||
u,s,v = anp.linalg.svd(a, full_matrices=0)
|
||||
u, s, v = anp.linalg.svd(a, full_matrices=0)
|
||||
return s
|
||||
|
||||
def check_v(a):
|
||||
u,s,v = anp.linalg.svd(a, full_matrices=0)
|
||||
u, s, v = anp.linalg.svd(a, full_matrices=0)
|
||||
return v
|
||||
|
||||
for i in range(50):
|
||||
#not for full-matrices!
|
||||
a = jt.random((2,2,5,4))
|
||||
# not for full-matrices!
|
||||
a = jt.random((2, 2, 5, 4))
|
||||
c_a = anp.array(a.data)
|
||||
u,s,v = jt.linalg.svd(a)
|
||||
tu,ts,tv = check_svd(c_a)
|
||||
assert np.allclose(tu,u.data)
|
||||
assert np.allclose(ts,s.data)
|
||||
assert np.allclose(tv,v.data)
|
||||
ju = jt.grad(u,a)
|
||||
js = jt.grad(s,a)
|
||||
jv = jt.grad(v,a)
|
||||
u, s, v = jt.linalg.svd(a)
|
||||
tu, ts, tv = check_svd(c_a)
|
||||
assert np.allclose(tu, u.data)
|
||||
assert np.allclose(ts, s.data)
|
||||
assert np.allclose(tv, v.data)
|
||||
ju = jt.grad(u, a)
|
||||
js = jt.grad(s, a)
|
||||
jv = jt.grad(v, a)
|
||||
grad_u = jacobian(check_u)
|
||||
gu = grad_u(c_a)
|
||||
gu = np.sum(gu, 4)
|
||||
|
@ -69,56 +72,56 @@ class TestCodeOp(unittest.TestCase):
|
|||
gv = np.sum(gv, 2)
|
||||
gv = np.sum(gv, 2)
|
||||
try:
|
||||
assert np.allclose(ju.data,gu,atol=1e-5)
|
||||
assert np.allclose(ju.data, gu, atol=1e-5)
|
||||
except AssertionError:
|
||||
print(ju.data)
|
||||
print(gu)
|
||||
try:
|
||||
assert np.allclose(js.data,gs,atol=1e-5)
|
||||
assert np.allclose(js.data, gs, atol=1e-5)
|
||||
except AssertionError:
|
||||
print(js.data)
|
||||
print(gs)
|
||||
try:
|
||||
assert np.allclose(jv.data,gv,atol=1e-5)
|
||||
assert np.allclose(jv.data, gv, atol=1e-5)
|
||||
except AssertionError:
|
||||
print(jv.data)
|
||||
print(gv)
|
||||
|
||||
def test_eigh(self):
|
||||
def check_eigh(a,UPLO='L'):
|
||||
w, v = anp.linalg.eigh(a,UPLO)
|
||||
def check_eigh(a, UPLO='L'):
|
||||
w, v = anp.linalg.eigh(a, UPLO)
|
||||
return w, v
|
||||
|
||||
def check_w(a,UPLO='L'):
|
||||
w, v = anp.linalg.eigh(a,UPLO)
|
||||
def check_w(a, UPLO='L'):
|
||||
w, v = anp.linalg.eigh(a, UPLO)
|
||||
return w
|
||||
|
||||
def check_v(a,UPLO='L'):
|
||||
w, v = anp.linalg.eigh(a,UPLO)
|
||||
def check_v(a, UPLO='L'):
|
||||
w, v = anp.linalg.eigh(a, UPLO)
|
||||
return v
|
||||
|
||||
for i in range(50):
|
||||
a = jt.random((2,2,3,3))
|
||||
a = jt.random((2, 2, 3, 3))
|
||||
c_a = a.data
|
||||
w, v = jt.linalg.eigh(a)
|
||||
tw, tv = check_eigh(c_a)
|
||||
assert np.allclose(w.data,tw)
|
||||
assert np.allclose(v.data,tv)
|
||||
assert np.allclose(w.data, tw)
|
||||
assert np.allclose(v.data, tv)
|
||||
jw = jt.grad(w, a)
|
||||
jv = jt.grad(v, a)
|
||||
check_gw = jacobian(check_w)
|
||||
check_gv = jacobian(check_v)
|
||||
gw = check_gw(c_a)
|
||||
gw = np.sum(gw,4)
|
||||
gw = np.sum(gw,2)
|
||||
gw = np.sum(gw,2)
|
||||
assert np.allclose(gw,jw.data,rtol = 1,atol = 5e-8)
|
||||
gw = np.sum(gw, 4)
|
||||
gw = np.sum(gw, 2)
|
||||
gw = np.sum(gw, 2)
|
||||
assert np.allclose(gw, jw.data, rtol=1, atol=5e-8)
|
||||
gv = check_gv(c_a)
|
||||
gv = np.sum(gv,4)
|
||||
gv = np.sum(gv,4)
|
||||
gv = np.sum(gv,2)
|
||||
gv = np.sum(gv,2)
|
||||
assert np.allclose(gv,jv.data,rtol = 1,atol = 5e-8)
|
||||
gv = np.sum(gv, 4)
|
||||
gv = np.sum(gv, 4)
|
||||
gv = np.sum(gv, 2)
|
||||
gv = np.sum(gv, 2)
|
||||
assert np.allclose(gv, jv.data, rtol=1, atol=5e-8)
|
||||
|
||||
def test_pinv(self):
|
||||
def check_pinv(a):
|
||||
|
@ -126,34 +129,35 @@ class TestCodeOp(unittest.TestCase):
|
|||
return w
|
||||
|
||||
for i in range(50):
|
||||
x = jt.random((2,2,4,4))
|
||||
x = jt.random((2, 2, 4, 3))
|
||||
c_a = x.data
|
||||
mx = jt.linalg.pinv(x)
|
||||
tx = check_pinv(c_a)
|
||||
np.allclose(mx.data,tx)
|
||||
jx = jt.grad(mx,x)
|
||||
np.allclose(mx.data, tx)
|
||||
jx = jt.grad(mx, x)
|
||||
check_grad = jacobian(check_pinv)
|
||||
gx = check_grad(c_a)
|
||||
np.allclose(gx,jx.data)
|
||||
np.allclose(gx, jx.data)
|
||||
|
||||
def test_inv(self):
|
||||
def check_inv(a):
|
||||
w = anp.linalg.inv(a)
|
||||
return w
|
||||
|
||||
for i in range(50):
|
||||
tn = np.random.randn(4,4).astype('float32')*5
|
||||
while np.allclose(np.linalg.det(tn),0):
|
||||
tn = np.random.randn((4,4)).astype('float32')*5
|
||||
tn = np.random.randn(4, 4).astype('float32') * 5
|
||||
while np.allclose(np.linalg.det(tn), 0):
|
||||
tn = np.random.randn((4, 4)).astype('float32') * 5
|
||||
x = jt.array(tn)
|
||||
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
|
||||
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
|
||||
c_a = x.data
|
||||
mx = jt.linalg.inv(x)
|
||||
tx = check_inv(c_a)
|
||||
np.allclose(mx.data,tx)
|
||||
jx = jt.grad(mx,x)
|
||||
np.allclose(mx.data, tx)
|
||||
jx = jt.grad(mx, x)
|
||||
check_grad = jacobian(check_inv)
|
||||
gx = check_grad(c_a)
|
||||
np.allclose(gx,jx.data)
|
||||
np.allclose(gx, jx.data)
|
||||
|
||||
def test_slogdet(self):
|
||||
def check_ans(a):
|
||||
|
@ -165,11 +169,11 @@ class TestCodeOp(unittest.TestCase):
|
|||
return w
|
||||
|
||||
for i in range(50):
|
||||
tn = np.random.randn(4,4).astype('float32')*10
|
||||
while np.allclose(np.linalg.det(tn),0):
|
||||
tn = np.random.randn((4,4)).astype('float32')*10
|
||||
tn = np.random.randn(4, 4).astype('float32') * 10
|
||||
while np.allclose(np.linalg.det(tn), 0):
|
||||
tn = np.random.randn((4, 4)).astype('float32') * 10
|
||||
x = jt.array(tn)
|
||||
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
|
||||
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
|
||||
s = list(x.shape)
|
||||
det_s = s[:-2]
|
||||
if len(det_s) == 0:
|
||||
|
@ -178,12 +182,12 @@ class TestCodeOp(unittest.TestCase):
|
|||
ts, ta = check_ans(x.data)
|
||||
assert np.allclose(sign.data, ts)
|
||||
assert np.allclose(mx.data, ta)
|
||||
jx = jt.grad(mx,x)
|
||||
jx = jt.grad(mx, x)
|
||||
check_sgrad = jacobian(check_slogdet)
|
||||
gx = check_sgrad(x.data)
|
||||
gx = np.sum(gx,2)
|
||||
gx = np.sum(gx,2)
|
||||
assert np.allclose(gx,jx.data)
|
||||
gx = np.sum(gx, 2)
|
||||
gx = np.sum(gx, 2)
|
||||
assert np.allclose(gx, jx.data)
|
||||
|
||||
def test_cholesky(self):
|
||||
def check_cholesky(a):
|
||||
|
@ -192,39 +196,39 @@ class TestCodeOp(unittest.TestCase):
|
|||
|
||||
for i in range(50):
|
||||
x = jt.array(np.diag((np.random.rand(3) + 1) * 2))
|
||||
x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
|
||||
x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
|
||||
tx = x.data
|
||||
L = jt.linalg.cholesky(x)
|
||||
tL = check_cholesky(tx)
|
||||
assert np.allclose(tL,L.data)
|
||||
jx = jt.grad(L,x)
|
||||
assert np.allclose(tL, L.data)
|
||||
jx = jt.grad(L, x)
|
||||
check_grad = jacobian(check_cholesky)
|
||||
gx = check_grad(tx)
|
||||
gx = np.sum(gx, 0)
|
||||
gx = np.sum(gx, 0)
|
||||
gx = np.sum(gx, 0)
|
||||
gx = np.sum(gx, 0)
|
||||
assert np.allclose(jx.data,gx)
|
||||
assert np.allclose(jx.data, gx)
|
||||
|
||||
def test_solve(self):
|
||||
def check_solve(a,b):
|
||||
ans = anp.linalg.solve(a,b)
|
||||
def check_solve(a, b):
|
||||
ans = anp.linalg.solve(a, b)
|
||||
return ans
|
||||
|
||||
for i in range(50):
|
||||
a = jt.random((2,2,3,3))
|
||||
b = jt.random((2,2,3))
|
||||
ans = jt.linalg.solve(a,b)
|
||||
ta = check_solve(a.data,b.data)
|
||||
a = jt.random((2, 2, 3, 3))
|
||||
b = jt.random((2, 2, 3))
|
||||
ans = jt.linalg.solve(a, b)
|
||||
ta = check_solve(a.data, b.data)
|
||||
assert np.allclose(ans.data, ta)
|
||||
jx = jt.grad(ans, a)
|
||||
check_sgrad = jacobian(check_solve)
|
||||
gx = check_sgrad(a.data,b.data)
|
||||
gx = np.sum(gx,0)
|
||||
gx = np.sum(gx,0)
|
||||
gx = np.sum(gx,0)
|
||||
gx = check_sgrad(a.data, b.data)
|
||||
gx = np.sum(gx, 0)
|
||||
gx = np.sum(gx, 0)
|
||||
gx = np.sum(gx, 0)
|
||||
try:
|
||||
assert np.allclose(gx, jx.data,rtol=1)
|
||||
assert np.allclose(gx, jx.data, rtol=1)
|
||||
except AssertionError:
|
||||
print(gx)
|
||||
print(jx.data)
|
||||
|
@ -254,7 +258,42 @@ class TestCodeOp(unittest.TestCase):
|
|||
gx = np.sum(gx, 2)
|
||||
assert np.allclose(gx, jx.data)
|
||||
|
||||
def test_qr(self):
|
||||
for i in range(50):
|
||||
tn = np.random.randn(3, 3).astype('float32')
|
||||
while np.allclose(np.linalg.det(tn), 0):
|
||||
tn = np.random.randn((3, 3)).astype('float32')
|
||||
x = jt.array(tn)
|
||||
# x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
|
||||
t_x = torch.from_numpy(tn)
|
||||
t_x = Variable(t_x, requires_grad=True)
|
||||
jq, jr = jt.linalg.qr(x)
|
||||
tq, tr = torch.qr(t_x)
|
||||
try:
|
||||
assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
|
||||
assert np.allclose(jr.data, tr.detach().numpy(), rtol=1e-4, atol=1e-6)
|
||||
except AssertionError:
|
||||
print("ours' qr results:")
|
||||
print(jq)
|
||||
print(jr)
|
||||
print("pytorch's qr results:")
|
||||
print(tq)
|
||||
print(tr)
|
||||
gq = jt.grad(jq, x).data
|
||||
gr = jt.grad(jr, x).data
|
||||
tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True)
|
||||
tgr = torch.autograd.grad(tr, t_x, torch.ones_like(tr), retain_graph=True)
|
||||
try:
|
||||
assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
|
||||
assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
|
||||
except AssertionError:
|
||||
print("ours' qr grad results:")
|
||||
print(gq)
|
||||
print(gr)
|
||||
print("pytorch's qr grad result")
|
||||
print(tgq[0])
|
||||
print(tgr[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
|
Loading…
Reference in New Issue