JittorMirror/python/jittor/linalg.py

326 lines
9.4 KiB
Python

# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from functools import partial
#TODO:full_matrices=1
def svd(x):
def forward_code(np, data):
a = data["inputs"][0]
u, s, v = data["outputs"]
#TODO:remove copyto
tu, ts, tv = np.linalg.svd(a, full_matrices=0)
np.copyto(u, tu)
np.copyto(s, ts)
np.copyto(v, tv)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
inp = data["inputs"][0]
out_index = data["out_index"]
u, s, v = data["f_outputs"]
v = T(v)
m, n = inp.shape[-2:]
k = np.min((m, n))
i = np.reshape(np.eye(k), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k))))
if out_index == 0:
f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
gu = dout
utgu = _dot(T(u), gu)
t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :]
t = _dot(_dot(u, t), T(v))
if m > n:
i_minus_uut = (np.reshape(np.eye(m), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (m, m)))) -
_dot(u, np.conj(T(u))))
t = t + T(_dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut))
np.copyto(out, t)
elif out_index == 1:
gs = dout
t = i * gs[..., :, np.newaxis]
t = _dot(_dot(u, t), T(v))
np.copyto(out, t)
elif out_index == 2:
f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i)
gv = dout
vtgv = _dot(T(v), gv)
t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv)))
t = _dot(_dot(u, t), T(v))
if m < n:
i_minus_vvt = (np.reshape(np.eye(n), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (n, n)))) -
_dot(v, np.conj(T(v))))
t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
np.copyto(out, t)
m, n = x.shape[-2:]
k = min(m, n)
s1 = list(x.shape)
s1[-1] = k
s2 = list(x.shape)
s2[-2] = k
s3 = list(x.shape)[:-2]
s3.append(k)
u, s, v = jt.numpy_code(
[s1, s3, s2],
[x.dtype, x.dtype, x.dtype],
[x],
forward_code,
[backward_code],
)
return u, s, v
def eigh(x):
def forward_code(np, data):
a = data["inputs"][0]
w, v = data["outputs"]
tw, tv = np.linalg.eigh(a, UPLO='L')
np.copyto(w, tw)
np.copyto(v, tv)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
inp = data["inputs"][0]
out_index = data["out_index"]
w, v = data["f_outputs"]
k = int(inp.shape[-1])
w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1)
if out_index == 0:
t = _dot(v * dout[..., np.newaxis, :], T(v))
np.copyto(out, t)
elif out_index == 1:
if np.any(dout):
off_diag = np.ones((k, k)) - np.eye(k)
F = off_diag / (T(w_repeated) - w_repeated + np.eye(k))
t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
np.copyto(out, t)
sw = x.shape[:-2] + x.shape[-1:]
sv = x.shape
w, v = jt.numpy_code(
[sw, sv],
[x.dtype, x.dtype],
[x],
forward_code,
[backward_code],
)
return w, v
def inv(x):
def forward_code(np, data):
a = data["inputs"][0]
m_a = data["outputs"][0]
t_a = np.linalg.inv(a)
np.copyto(m_a, t_a)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
lmx = data["f_outputs"]
mx = lmx[0]
t = -_dot(_dot(T(mx), dout), T(mx))
np.copyto(out, t)
lmx = jt.numpy_code(
[x.shape],
[x.dtype],
[x],
forward_code,
[backward_code],
)
mx = lmx[0]
return mx
def pinv(x):
def forward_code(np, data):
a = data["inputs"][0]
m_a = data["outputs"][0]
t_a = np.linalg.pinv(a)
np.copyto(m_a, t_a)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
inp = data["inputs"][0]
lmx = data["f_outputs"]
mx = lmx[0]
t = T(
-_dot(_dot(mx, T(dout)), mx)
+ _dot(_dot(_dot(mx, T(mx)), dout), np.eye(inp.shape[-2]) - _dot(inp, mx))
+ _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx)
)
np.copyto(out, t)
lmx = jt.numpy_code(
[x.shape],
[x.dtype],
[x],
forward_code,
[backward_code],
)
mx = lmx[0]
return mx
def det(x):
def forward_code(np, data):
a = data["inputs"][0]
L = data["outputs"][0]
tL = np.linalg.det(a)
np.copyto(L, tL)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
f_out = data["f_outputs"][0]
inp = data["inputs"][0]
n_d = np.reshape(dout, np.shape(dout) + (1, 1))
n_o = np.reshape(f_out, np.shape(f_out) + (1, 1))
s = n_d * n_o * T(np.linalg.inv(inp))
np.copyto(out, s)
s = x.shape
x_s = s[:-2]
if len(s) == 2:
x_s.append(1)
l_det = jt.numpy_code(
[x_s],
[x.dtype],
[x],
forward_code,
[backward_code],
)
det = l_det[0]
return det
def slogdet(x):
def forward_code(np, data):
a = data["inputs"][0]
sign, m_a = data["outputs"]
sign_, t_a = np.linalg.slogdet(a)
np.copyto(m_a, t_a)
np.copyto(sign, sign_)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
inp = data["inputs"][0]
out_index = data["out_index"]
if out_index == 0:
np.copyto(out, 0)
if out_index == 1:
t = np.reshape(dout, np.shape(dout) + (1, 1))
t = t * T(np.linalg.inv(inp))
np.copyto(out, t)
s = x.shape
det_s = s[:-2]
if len(det_s) == 0:
det_s.append(1)
sign, mx = jt.numpy_code(
[det_s, det_s],
[x.dtype, x.dtype],
[x],
forward_code,
[backward_code],
)
return sign, mx
def cholesky(x):
def forward_code(np, data):
a = data["inputs"][0]
L = data["outputs"][0]
tL = np.linalg.cholesky(a)
np.copyto(L, tL)
def backward_code(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
f_out = data["f_outputs"][0]
solve_trans = lambda a, b: np.linalg.solve(T(a), b)
phi = lambda X: np.tril(X) / (1. + np.eye(X.shape[-1]))
def conjugate_solve(L, X):
return solve_trans(L, T(solve_trans(L, T(X))))
s = conjugate_solve(f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
s = (s + T(s)) / 2.
np.copyto(out, s)
lL = jt.numpy_code(
[x.shape],
[x.dtype],
[x],
forward_code,
[backward_code],
)
L = lL[0]
return L
def solve(a,b):
def forward_code(np, data):
a, b = data["inputs"]
L = data["outputs"][0]
ans = np.linalg.solve(a, b)
np.copyto(L, ans)
def backward_code1(np, data):
def T(x):
return np.swapaxes(x, -1, -2)
_dot = partial(np.einsum, '...ij,...jk->...ik')
dout = data["dout"]
out = data["outputs"][0]
f_out = data["f_outputs"][0]
inp = data["inputs"][0]
updim = lambda x: x if x.ndim == a.ndim else x[..., None]
t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
np.copyto(out, t)
def backward_code2(np, data):
out = data["outputs"][0]
np.copyto(out, 0)
l_ans = jt.numpy_code(
[b.shape],
[b.dtype],
[a, b],
forward_code,
[backward_code1, backward_code2],
)
ans = l_ans[0]
return ans