mirror of https://github.com/Jittor/Jittor
Add files via upload
This commit is contained in:
parent
ab7e0c59ed
commit
212b2a8564
|
@ -185,6 +185,42 @@ def pinv(x):
|
||||||
mx = lmx[0]
|
mx = lmx[0]
|
||||||
return mx
|
return mx
|
||||||
|
|
||||||
|
def det(x):
|
||||||
|
from functools import partial
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
|
|
||||||
|
def forward_code(np, data):
|
||||||
|
a = data["inputs"][0]
|
||||||
|
L = data["outputs"][0]
|
||||||
|
tL = np.linalg.det(a)
|
||||||
|
np.copyto(L, tL)
|
||||||
|
|
||||||
|
def backward_code(np, data):
|
||||||
|
dout = data["dout"]
|
||||||
|
out = data["outputs"][0]
|
||||||
|
f_out = data["f_outputs"][0]
|
||||||
|
inp = data["inputs"][0]
|
||||||
|
n_d = np.reshape(dout, np.shape(dout) + (1, 1))
|
||||||
|
n_o = np.reshape(f_out, np.shape(f_out) + (1, 1))
|
||||||
|
s = n_d * n_o * T(np.linalg.inv(inp))
|
||||||
|
np.copyto(out, s)
|
||||||
|
|
||||||
|
s = jt.array(x.shape).data.tolist()
|
||||||
|
x_s = s[:-2]
|
||||||
|
if len(s) == 2:
|
||||||
|
x_s.append(1)
|
||||||
|
l_det = jt.numpy_code(
|
||||||
|
[x_s],
|
||||||
|
[x.dtype],
|
||||||
|
[x],
|
||||||
|
forward_code,
|
||||||
|
[backward_code],
|
||||||
|
)
|
||||||
|
det = l_det[0]
|
||||||
|
return det
|
||||||
|
|
||||||
def slogdet(x):
|
def slogdet(x):
|
||||||
from functools import partial
|
from functools import partial
|
||||||
def T(x):
|
def T(x):
|
||||||
|
@ -221,3 +257,74 @@ def slogdet(x):
|
||||||
[backward_code],
|
[backward_code],
|
||||||
)
|
)
|
||||||
return sign, mx
|
return sign, mx
|
||||||
|
|
||||||
|
def cholesky(x):
|
||||||
|
from functools import partial
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
|
|
||||||
|
def forward_code(np, data):
|
||||||
|
a = data["inputs"][0]
|
||||||
|
L = data["outputs"][0]
|
||||||
|
tL = np.linalg.cholesky(a)
|
||||||
|
np.copyto(L, tL)
|
||||||
|
|
||||||
|
def backward_code(np, data):
|
||||||
|
dout = data["dout"]
|
||||||
|
out = data["outputs"][0]
|
||||||
|
f_out = data["f_outputs"][0]
|
||||||
|
solve_trans = lambda a, b: np.linalg.solve(T(a), b)
|
||||||
|
phi = lambda X: np.tril(X) / (1. + np.eye(X.shape[-1]))
|
||||||
|
|
||||||
|
def conjugate_solve(L, X):
|
||||||
|
return solve_trans(L, T(solve_trans(L, T(X))))
|
||||||
|
|
||||||
|
s = conjugate_solve(f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
|
||||||
|
s = (s + T(s)) / 2.
|
||||||
|
np.copyto(out, s)
|
||||||
|
|
||||||
|
lL = jt.numpy_code(
|
||||||
|
[x.shape],
|
||||||
|
[x.dtype],
|
||||||
|
[x],
|
||||||
|
forward_code,
|
||||||
|
[backward_code],
|
||||||
|
)
|
||||||
|
L = lL[0]
|
||||||
|
return L
|
||||||
|
|
||||||
|
def solve(a,b):
|
||||||
|
from functools import partial
|
||||||
|
def T(x):
|
||||||
|
return np.swapaxes(x, -1, -2)
|
||||||
|
_dot = partial(np.einsum, '...ij,...jk->...ik')
|
||||||
|
|
||||||
|
def forward_code(np, data):
|
||||||
|
a, b = data["inputs"]
|
||||||
|
L = data["outputs"][0]
|
||||||
|
ans = np.linalg.solve(a, b)
|
||||||
|
np.copyto(L, ans)
|
||||||
|
|
||||||
|
def backward_code1(np, data):
|
||||||
|
dout = data["dout"]
|
||||||
|
out = data["outputs"][0]
|
||||||
|
f_out = data["f_outputs"][0]
|
||||||
|
inp = data["inputs"][0]
|
||||||
|
updim = lambda x: x if x.ndim == a.ndim else x[..., None]
|
||||||
|
t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
|
||||||
|
np.copyto(out, t)
|
||||||
|
|
||||||
|
def backward_code2(np, data):
|
||||||
|
out = data["outputs"][0]
|
||||||
|
np.copyto(out, 0)
|
||||||
|
|
||||||
|
l_ans = jt.numpy_code(
|
||||||
|
[b.shape],
|
||||||
|
[b.dtype],
|
||||||
|
[a, b],
|
||||||
|
forward_code,
|
||||||
|
[backward_code1, backward_code2],
|
||||||
|
)
|
||||||
|
ans = l_ans[0]
|
||||||
|
return ans
|
Loading…
Reference in New Issue