mirror of https://github.com/Jittor/Jittor

Commit 8b987b928e: Merge branch 'master' into doc
@@ -3,8 +3,55 @@ jittor.linalg

This is the API documentation for Jittor's linear algebra functions. You can access the module via `from jittor import linalg`.

## Overview of basic functions

#### Basic linear algebra APIs

- linalg.inv(a)

  Computes the inverse of a.

- linalg.pinv(a)

  Computes the generalized (Moore-Penrose) pseudo-inverse of a. The matrix a does not need to be invertible.

- linalg.slogdet(a)

  Computes the slogdet of a; returns both the log-determinant and its sign.

- linalg.det(a)

  Computes the determinant of a.

- linalg.solve(a,b)

  Solves the linear system Ax = b.

#### Decomposition APIs

- linalg.cholesky(a)

  Computes the Cholesky decomposition of a.

- linalg.qr(a)

  Computes the QR decomposition of a.

- linalg.svd(a)

  Computes the singular value decomposition of a.

#### Eigenvalue APIs

- linalg.eig(a)

  Computes the eigenvalues and eigenvectors of a.

- linalg.eigh(a)

  Computes the eigenvalues and eigenvectors of a Hermitian or symmetric matrix.

The operations currently supported by the linalg module are listed below:

```eval_rst
.. automodule:: jittor.linalg
    :members:
    :undoc-members:
```
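A minimal usage sketch of the APIs listed above (the matrices here are illustrative; the `solve`, `slogdet`, and `svd` signatures follow the docstrings further down in this diff):

```python
import numpy as np
import jittor as jt
from jittor import linalg

# a well-conditioned float32 matrix, so inv/det/solve are stable
a = jt.array((np.random.rand(3, 3) + 3 * np.eye(3)).astype('float32'))
b = jt.array(np.random.rand(3).astype('float32'))

x = linalg.solve(a, b)             # solution of a x = b
inv_a = linalg.inv(a)              # matrix inverse
sign, logdet = linalg.slogdet(a)   # sign and log of the determinant
u, s, v = linalg.svd(a)            # reduced SVD (full_matrices == False)
```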
@@ -8,3 +8,11 @@ jittor.optim
    :members:
    :undoc-members:
```

Below is the API documentation for Jittor's learning rate scheduling module. A scheduler is meant to be used together with an optimizer; you can access the module via `from jittor import lr_scheduler`.

```eval_rst
.. automodule:: jittor.lr_scheduler
    :members:
    :undoc-members:
```
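A sketch of how a scheduler pairs with an optimizer. The `CosineAnnealingLR` class and its `T_max` argument are assumptions based on the PyTorch-style API this module mirrors; substitute whichever scheduler your build provides:

```python
import jittor as jt
from jittor import nn, lr_scheduler

model = nn.Linear(10, 2)
optimizer = jt.optim.SGD(model.parameters(), lr=0.1)
# assumption: CosineAnnealingLR exists here with a PyTorch-like signature
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    # ... run one training epoch with optimizer.step(loss) ...
    scheduler.step()   # update the learning rate once per epoch
```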
@@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.34'
__version__ = '1.2.2.48'
from . import lock
with lock.lock_scope():
    ori_int = int
@@ -709,7 +709,7 @@ class Module:
    def __init__(self, *args, **kw):
        pass
    def execute(self, *args, **kw):
        pass
        raise NotImplementedError
    def __call__(self, *args, **kw):
        return self.execute(*args, **kw)
    def __repr__(self):
@@ -82,6 +82,7 @@ def setup_mkl():

def install_cub(root_folder):
    url = "https://github.com/NVIDIA/cub/archive/1.11.0.tar.gz"
    url = "https://codeload.github.com/NVIDIA/cub/tar.gz/1.11.0"
    filename = "cub-1.11.0.tgz"
    md5 = "97196a885598e40592100e1caaf3d5ea"
    fullname = os.path.join(root_folder, filename)
@@ -196,6 +197,7 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
def install_cutt(root_folder):
    # Modified from: https://github.com/ap-hynninen/cutt
    url = "https://github.com/Jittor/cutt/archive/master.zip"
    url = "https://codeload.github.com/Jittor/cutt/zip/master"

    filename = "cutt-master.zip"
    fullname = os.path.join(root_folder, filename)
@@ -270,12 +272,13 @@ def setup_cutt():


def install_nccl(root_folder):
    url = "https://github.com/NVIDIA/nccl/archive/v2.6.4-1.tar.gz"
    url = "https://github.com/NVIDIA/nccl/archive/v2.8.4-1.tar.gz"
    url = "https://codeload.github.com/NVIDIA/nccl/tar.gz/v2.8.4-1"

    filename = "nccl.tgz"
    fullname = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, "nccl-2.6.4-1")
    true_md5 = "38d7a9e98d95a99df0a4f1ad6fb50fa7"
    dirname = os.path.join(root_folder, "nccl-2.8.4-1")
    true_md5 = "900666558c5bc43e0a5e84045b88a06f"

    if os.path.exists(fullname):
        md5 = run_cmd('md5sum '+fullname).split()[0]
@@ -155,12 +155,12 @@ def slice_var_index(x, slices):
    x.stop_fuse()
    return (out_shape, out_index, 0, [], extras)

def slice_var(x, slices):
def _slice_var_old(x, slices):
    reindex_args = slice_var_index(x, slices)
    x.stop_fuse()
    return x.reindex(*reindex_args).stop_fuse()

def setitem(x, slices, value):
def _setitem_old(x, slices, value):
    reindex_args = slice_var_index(x, slices)
    reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
    xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
@@ -176,14 +176,11 @@ def setitem(x, slices, value):
    x.assign(out)
    return x

jt.Var.__getitem__ = jt.Var.slice_var = slice_var
jt.Var.__setitem__ = setitem

# PATCH
def getitem(x, slices):
    if isinstance(slices, jt.Var) and slices.dtype == "bool":
        return getitem(x, slices.where())
    if isinstance(slices, Sequence):
        return getitem(x, tuple(slices.where()))
    if isinstance(slices, tuple):
        ss = []
        for s in slices:
            if isinstance(s, jt.Var) and s.dtype == "bool":
@@ -195,10 +192,8 @@ def getitem(x, slices):

def setitem(x, slices, value):
    if isinstance(slices, jt.Var) and slices.dtype == "bool":
        mask = jt.broadcast(slices, x)
        value = jt.broadcast(value, x)
        return x.assign(mask.ternary(value, x))
    if isinstance(slices, Sequence):
        slices = tuple(slices.where())
    elif isinstance(slices, tuple):
        ss = []
        for s in slices:
            if isinstance(s, jt.Var) and s.dtype == "bool":
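The bool-mask branches above are what give `jt.Var` NumPy-style boolean indexing; a small sketch of the resulting behavior (the same cases are exercised by `test_setitem_bool` near the end of this diff):

```python
import jittor as jt

a = jt.array([1, 2, 3, 4])
mask = jt.array([True, False, True, False])
print(a[mask])                # bool masks are routed through mask.where()
a[mask] = jt.array([-1, -2])
print(a)                      # [-1, 2, -2, 4]
```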
@@ -218,8 +218,11 @@ class Dataset(object):
                img_open_hook.duration
                img_open_hook.duration = 0.0
            except:
                import traceback
                line = traceback.format_exc()
                print(line)
                os.kill(os.getppid(), signal.SIGINT)
                raise
        exit(0)

    def display_worker_status(self):
        ''' Display dataset worker status. When dataset.num_workers > 0, it will display information as below:
@@ -85,10 +85,10 @@ class MNIST(Dataset):
        Download mnist data set function, this function will be called when download is True.
        '''
        resources = [
            ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
            ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
            ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
            ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
            ("https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
            ("https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
            ("https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
            ("https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
        ]

        for url, md5 in resources:
@@ -11,9 +11,21 @@
import jittor as jt
from functools import partial


# TODO: full_matrices=1
def svd(x):
    r'''
    Calculate the singular value decomposition of x. It follows the formula:
        x = u s v*
    Only full_matrices == False is supported for now, which means:
        x's shape (...,M,N)
        u's shape (...,M,K)
        s's shape (...,K)
        v's shape (...,K,N)
    where K = min(M,N).
    :param x: the input matrix.
    :return: u, s, v.
    '''
    def forward_code(np, data):
        a = data["inputs"][0]
        u, s, v = data["outputs"]
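A shape-level sketch of the reduced SVD described in the docstring (array contents are arbitrary; the shape matches the `jt.random((2,2,5,4))` case used in the tests below):

```python
import jittor as jt

x = jt.random((2, 2, 5, 4))        # (..., M, N) with M=5, N=4, so K=min(M,N)=4
u, s, v = jt.linalg.svd(x)
print(u.shape, s.shape, v.shape)   # [2,2,5,4], [2,2,4], [2,2,4,4]
```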
@@ -81,8 +93,15 @@ def svd(x):
    )
    return u, s, v

def eigh(x):

def eigh(x):
    r"""
    Calculate the eigenvalues and eigenvectors of x.
    :param x: (...,M,M) matrix.
    :return: w, v.
        w (...,M): the eigenvalues.
        v (...,M,M): normalized eigenvectors.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        w, v = data["outputs"]
@@ -122,8 +141,13 @@ def eigh(x):
    )
    return w, v

def inv(x):

def inv(x):
    r"""
    Calculate the inverse of x.
    :param x: (...,M,M) matrix.
    :return: x^-1, in shape (...,M,M).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        m_a = data["outputs"][0]
@@ -151,8 +175,13 @@ def inv(x):
    mx = lmx[0]
    return mx

def pinv(x):

def pinv(x):
    r"""
    Calculate the pseudo-inverse of x.
    :param x: (...,M,N) matrix.
    :return: the pseudo-inverse of x, in shape (...,N,M).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        m_a = data["outputs"][0]
@@ -174,9 +203,9 @@ def pinv(x):
            + _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx)
        )
        np.copyto(out, t)

    sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
    lmx = jt.numpy_code(
        [x.shape],
        [sw],
        [x.dtype],
        [x],
        forward_code,
@@ -185,8 +214,13 @@ def pinv(x):
    mx = lmx[0]
    return mx

def det(x):

def det(x):
    r"""
    Calculate the determinant of x.
    :param x: (...,M,M) matrix.
    :return: |x|, in shape (...,1).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        L = data["outputs"][0]
@@ -220,7 +254,15 @@ def det(x):
    det = l_det[0]
    return det


def slogdet(x):
    r"""
    Calculate the sign and (natural) log of the determinant of x.
    :param x: (...,M,M) matrix.
    :return: sign, logdet of x.
        sign decides the sign of the determinant; its values can be -1, 0, or 1 (only real numbers for now). 0 means the determinant is 0 and the logdet is -inf.
        logdet is in shape (...,1).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        sign, m_a = data["outputs"]
@@ -256,8 +298,15 @@ def slogdet(x):
    )
    return sign, mx

def cholesky(x):

def cholesky(x):
    r"""
    Perform the Cholesky decomposition of x:
        x = L L^T
    x must be a Hermitian, positive-definite matrix; L is a lower-triangular matrix.
    :param x: (...,M,M) matrix.
    :return: L, in shape (...,M,M).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        L = data["outputs"][0]
@@ -291,8 +340,14 @@ def cholesky(x):
    L = lL[0]
    return L

def solve(a,b):

def solve(a,b):
    r"""
    Solve a linear matrix equation Ax = b by computing x = A^-1 b, so A must not be singular.
    :param a: (...,M,M)
    :param b: (...,M)
    :return: the solution x of Ax = b, in shape (...,M).
    """
    def forward_code(np, data):
        a, b = data["inputs"]
        L = data["outputs"][0]
@@ -323,4 +378,55 @@ def solve(a,b):
        [backward_code1, backward_code2],
    )
    ans = l_ans[0]
    return ans
    return ans


def qr(x):
    r"""
    Perform the QR factorization of x:
        x = QR, where Q is an orthogonal matrix and R is an upper-triangular matrix.
    :param x: (...,M,M) matrix.
    :return: q, r as the result of the QR factorization; both are in shape (...,M,M).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        q, r = data["outputs"]
        Q, R = np.linalg.qr(a)
        np.copyto(q, Q)
        np.copyto(r, R)

    def backward_code(np, data):
        def T(x):
            return np.swapaxes(x, -1, -2)
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        _harmard = partial(np.einsum, '...ij,...ij->...ij')
        dout = data["dout"]
        out = data["outputs"][0]
        q, r = data["f_outputs"]
        out_index = data["out_index"]
        #pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags
        if out_index == 0: # Q_TERM
            q_t = _dot(T(q), dout)
            rhs_solve = q_t - T(q_t)
            rhs_solve = T(np.tril(rhs_solve, -1))
            qsolve = np.linalg.solve(r, rhs_solve)
            qsolve = T(qsolve)
            tq = _dot(q, qsolve)
            np.copyto(out, tq)
        else: # R_TERM
            r_t = _dot(r, T(dout))
            rhs_solve = r_t - T(r_t)
            rhs_solve = np.tril(rhs_solve, -1)
            rhs_solve = T(rhs_solve)
            r_solve = np.linalg.solve(r, rhs_solve)
            tr = _dot(q, (T(r_solve) + dout))
            np.copyto(out, tr)

    q, r = jt.numpy_code(
        [x.shape, x.shape],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return q, r
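A quick round-trip sketch for the `qr` wrapper above (square input, as the docstring requires):

```python
import numpy as np
import jittor as jt

x = jt.array(np.random.randn(3, 3).astype('float32'))
q, r = jt.linalg.qr(x)
# q is orthogonal and r upper-triangular, so q @ r reconstructs x
print(np.allclose(jt.matmul(q, r).data, x.data, atol=1e-5))
```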
@@ -13,6 +13,28 @@ import numpy as np
import math
from collections.abc import Sequence, Iterable

def index_add_(x, dim, index, tensor):
    """ Select positions of x along dimension dim according to index, and add the corresponding rows of tensor to them in place.

    Example:

        x = jt.ones((5,3))
        tensor = jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        index = jt.array([0,4,2])
        x.index_add_(0, index, tensor)
        print(x)

        >>> jt.Var([[ 2., 3., 4.],
            [ 1., 1., 1.],
            [ 8., 9., 10.],
            [ 1., 1., 1.],
            [ 5., 6., 7.]])
    """
    assert len(index.shape) == 1
    assert tensor.shape[0] == index.shape[0]
    x[(slice(None,),)*dim+(index,)] += tensor
jt.Var.index_add_ = index_add_

def __copy__(x):
    return x.copy().detach()
jt.Var.__copy__ = __copy__
@@ -897,7 +919,7 @@ def auto_parallel(n, src, **kw):
        tid_def += f"\nauto tnum{i} = 1<<tn{i};"
        tid_def += f"\ntid = tid>>tn{i};"
    for i in range(n):
        tid_loop += f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tn{i})"
        tid_loop += f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tnum{i})"
        call_args.append(pnargs2[i])
        call_args.append(f"i{i}")
    call_args += oargs2
@@ -1073,3 +1095,89 @@ inline static void searchsorted(
        cpu_src=_searchsorted_src,
        cuda_header=_searchsorted_header,
        cuda_src=_searchsorted_src)


def scatter(x:jt.Var, dim:int, index:jt.Var, src:jt.Var, reduce='void'):
    ''' if x is a 3-D array, rewrite x like:

        self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
        self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
        self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

    Parameters::

        * x (jt.Var) – input array
        * dim (int) – the axis along which to index
        * index (jt.Var) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
        * src (jt.Var) – the source element(s) to scatter.
        * reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'.

    Example::

        src = jt.arange(1, 11).reshape((2, 5))
        index = jt.array([[0, 1, 2, 0]])
        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
        assert (x.data ==
            [[1, 0, 0, 4, 0],
             [0, 2, 0, 0, 0],
             [0, 0, 3, 0, 0]]).all()
        index = jt.array([[0, 1, 2], [0, 1, 4]])
        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src)
        assert (x.data ==
            [[1, 2, 3, 0, 0],
             [6, 7, 0, 0, 8],
             [0, 0, 0, 0, 0]]).all()
        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
                jt.array(1.23), reduce='multiply')
        assert np.allclose(x.data,
            [[2.0000, 2.0000, 2.4600, 2.0000],
             [2.0000, 2.0000, 2.0000, 2.4600]]), x
        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
                jt.array(1.23), reduce='add')
        assert np.allclose(x.data,
            [[2.0000, 2.0000, 3.2300, 2.0000],
             [2.0000, 2.0000, 2.0000, 3.2300]])

    '''
    shape = index.shape
    if src.shape != shape and src.numel() != 1:
        src = src[tuple( slice(None,s) for s in shape )]
    indexes = [ f'i{i}' for i in range(len(shape)) ]
    indexes[dim] = index
    return x.setitem(tuple(indexes), src, reduce)

def scatter_(x, dim, index, src, reduce='void'):
    return x.assign(x.scatter(dim, index, src, reduce))

jt.Var.scatter = scatter
jt.Var.scatter_ = scatter_

def gather(x, dim, index):
    ''' if x is a 3-D array, reindex x like:

        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    Parameters::

        * input (jt.Var) – the source array
        * dim (int) – the axis along which to index
        * index (jt.Var) – the indices of elements to gather

    Example::

        t = jt.array([[1, 2], [3, 4]])
        data = t.gather(1, jt.array([[0, 0], [1, 0]]))
        assert (data.data == [[ 1, 1], [ 4, 3]]).all()
        data = t.gather(0, jt.array([[0, 0], [1, 0]]))
        assert (data.data == [[ 1, 2], [ 3, 2]]).all()

    '''
    shape = index.shape
    indexes = [ f'i{i}' for i in range(len(shape)) ]
    indexes[dim] = index
    return x.getitem(tuple(indexes))

jt.Var.gather = gather
@@ -618,13 +618,17 @@ class Conv1d(Module):
        self.bias = bias
        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
        assert out_channels % groups == 0, 'out_channels must be divisible by groups'
        self.conv = Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)
        # wrap the inner Conv in a list so it escapes the module DFS
        self._conv = [Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)]
        self.weight = self._conv[0].weight.squeeze(-1)
        self.bias = self._conv[0].bias

    def execute(self, x):
        N,C,D = x.shape
        assert C==self.in_channels
        self._conv[0].weight = self.weight.unsqueeze(-1)
        x = x.unsqueeze(-1)
        x = self.conv(x)
        x = self._conv[0](x)
        y = x.squeeze(-1)
        return y
|
|||
self.pr = self.padding
|
||||
self.pt = self.padding
|
||||
self.pb = self.padding
|
||||
elif isinstance(self.padding, tuple):
|
||||
elif isinstance(self.padding, (tuple,list)):
|
||||
self.pl, self.pr, self.pt, self.pb = self.padding
|
||||
else:
|
||||
raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}")
|
||||
|
@@ -953,11 +957,11 @@ def hardtanh(x,min_val=-1,max_val=1):
class Softplus(Module):
    r'''
    SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.

    Args:

        [in] beta (float): the beta value for the Softplus formulation. Default: 1.

        [in] threshold (float): values above this revert to a linear function. Default: 20.
    '''
    def __init__(self, beta=1, threshold=20):
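A minimal numpy reference for the formulation described above (a sketch, not the module's implementation):

```python
import numpy as np

def softplus_ref(x, beta=1.0, threshold=20.0):
    # softplus(x) = (1/beta) * log(1 + exp(beta*x)); above `threshold`
    # the function reverts to the identity for numerical stability
    y = beta * x
    return np.where(y > threshold, x,
                    np.log1p(np.exp(np.minimum(y, threshold))) / beta)
```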
@@ -976,59 +980,99 @@ class Resize(Module):
    def execute(self, x):
        return resize(x, self.size, self.mode, self.align_corners)


def _bicubic(x, a, func):
    # normal ver
    if func == 1:
        return (a+2)*(jt.abs(x)**3)-(a+3)*(x**2)+1
    if func == 2:
        return a*(jt.abs(x)**3)-5*a*(x**2)+8*a*(jt.abs(x))-4*a
    return 0


def _interpolate(img, x, y, ids, mode):
    if mode=="nearest":
    if mode == "nearest":
        return img.reindex([*ids, x.floor(), y.floor()])
    if mode=="bilinear":
    if mode == "bilinear":
        fx, fy = x.floor(), y.floor()
        cx, cy = fx+1, fy+1
        dx, dy = x-fx, y-fy
        cx, cy = fx + 1, fy + 1
        dx, dy = x - fx, y - fy
        a = img.reindex_var([*ids, fx, fy])
        b = img.reindex_var([*ids, cx, fy])
        c = img.reindex_var([*ids, fx, cy])
        d = img.reindex_var([*ids, cx, cy])
        dnx, dny = 1-dx, 1-dy
        ab = dx*b + dnx*a
        cd = dx*d + dnx*c
        o = ab*dny + cd*dy
        dnx, dny = 1 - dx, 1 - dy
        ab = dx * b + dnx * a
        cd = dx * d + dnx * c
        o = ab * dny + cd * dy
        return o
    raise(f"Not support interpolation mode: {mode}")
    if mode == "bicubic": # ugly ver.
        n, c, h, w = img.shape
        fx, fy = x.floor(), y.floor()
        dix, diy = x - fx, y - fy
        ax, ay = _bicubic(dix+1,-0.75,2), _bicubic(diy+1,-0.75,2)
        bx, by = _bicubic(dix,-0.75,1), _bicubic(diy,-0.75,1)
        cx, cy = _bicubic(1-dix,-0.75,1), _bicubic(1-diy,-0.75,1)
        dx, dy = _bicubic(2-dix,-0.75,2), _bicubic(2-diy,-0.75,2)
        afx, afy = jt.maximum(jt.minimum(fx-1,h-1),0), jt.maximum(jt.minimum(fy-1,w-1),0)
        bfx, bfy = jt.maximum(jt.minimum(fx,h-1),0), jt.maximum(jt.minimum(fy,w-1),0)
        cfx, cfy = jt.maximum(jt.minimum(fx+1,h-1),0), jt.maximum(jt.minimum(fy+1,w-1),0)
        dfx, dfy = jt.maximum(jt.minimum(fx+2,h-1),0), jt.maximum(jt.minimum(fy+2,w-1),0)
        a = ax*(img.reindex_var([*ids,afx,afy])*ay+img.reindex_var([*ids,afx,bfy])*by+img.reindex_var([*ids,afx,cfy])*cy+img.reindex_var([*ids,afx,dfy])*dy)
        b = bx*(img.reindex_var([*ids,bfx,afy])*ay+img.reindex_var([*ids,bfx,bfy])*by+img.reindex_var([*ids,bfx,cfy])*cy+img.reindex_var([*ids,bfx,dfy])*dy)
        c = cx*(img.reindex_var([*ids,cfx,afy])*ay+img.reindex_var([*ids,cfx,bfy])*by+img.reindex_var([*ids,cfx,cfy])*cy+img.reindex_var([*ids,cfx,dfy])*dy)
        d = dx*(img.reindex_var([*ids,dfx,afy])*ay+img.reindex_var([*ids,dfx,bfy])*by+img.reindex_var([*ids,dfx,cfy])*cy+img.reindex_var([*ids,dfx,dfy])*dy)
        o = a + b + c + d
        return o
    raise (f"Not support interpolation mode: {mode}")


def resize(img, size, mode="nearest", align_corners=False):
    n,c,h,w = img.shape
    H,W = size
    nid, cid, hid, wid = jt.index((n,c,H,W))
    n, c, h, w = img.shape
    H, W = size
    nid, cid, hid, wid = jt.index((n, c, H, W))
    if align_corners:
        x = hid * ((h-1) / max(1, H-1))
        y = wid * ((w-1) / max(1, W-1))
        x = hid * ((h - 1) / max(1, H - 1))
        y = wid * ((w - 1) / max(1, W - 1))
    else:
        x = hid * (h / H) + (h/H*0.5 - 0.5)
        if H>h: x = x.clamp(0, h-1)
        y = wid * (w / W) + (w/W*0.5 - 0.5)
        if W>w: y = y.clamp(0, w-1)
    return _interpolate(img, x, y, (nid,cid), mode)
        x = hid * (h / H) + (h / H * 0.5 - 0.5)
        if H > h: x = x.clamp(0, h - 1)
        y = wid * (w / W) + (w / W * 0.5 - 0.5)
        if W > w: y = y.clamp(0, w - 1)
    return _interpolate(img, x, y, (nid, cid), mode)


def upsample(img, size, mode="nearest", align_corners=False):
    n,c,h,w = img.shape
    H,W = size
    nid, cid, hid, wid = jt.index((n,c,H,W))
    n, c, h, w = img.shape
    H, W = size
    nid, cid, hid, wid = jt.index((n, c, H, W))
    if align_corners:
        x = hid * ((h-1) / max(1, H-1))
        y = wid * ((w-1) / max(1, W-1))
    else:
        x = hid * ((h - 1) / max(1, H - 1))
        y = wid * ((w - 1) / max(1, W - 1))
    elif mode == "bicubic":
        x = (hid + 0.5) * (h / H) - 0.5
        y = (wid + 0.5) * (w / W) - 0.5
    elif mode == 'nearest':
        x = hid * (h / H)
        y = wid * (w / W)
    return _interpolate(img, x, y, (nid,cid), mode)

def interpolate(X,size=None,scale_factor=None,mode='bilinear',align_corners=False):
    if scale_factor is not None:
        size = [X.shape[-2]*scale_factor,X.shape[-1]*scale_factor]
    if isinstance(size,int):
        size = (size,size)
    if scale_factor is not None and scale_factor>1:
        return upsample(X,size,mode,align_corners)
    else:
        return resize(X,size,mode,align_corners)
        x = hid * (h / H) + (h / H * 0.5 - 0.5)
        if H > h: x = x.clamp(0, h - 1)
        y = wid * (w / W) + (w / W * 0.5 - 0.5)
        if W > w: y = y.clamp(0, w - 1)
    return _interpolate(img, x, y, (nid, cid), mode)


def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False):
    if scale_factor is not None:
        size = [X.shape[-2] * scale_factor, X.shape[-1] * scale_factor]
    if isinstance(size, int):
        size = (size, size)
    if scale_factor is not None and scale_factor > 1:
        return upsample(X, size, mode, align_corners)
    else:
        return resize(X, size, mode, align_corners)
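A quick sketch of the dispatch implemented by `interpolate` above: upscaling routes to `upsample`, everything else to `resize` (shapes are illustrative):

```python
import jittor as jt

img = jt.random((1, 3, 8, 8))
up = jt.nn.interpolate(img, scale_factor=2, mode='bilinear')   # -> [1, 3, 16, 16]
down = jt.nn.interpolate(img, size=(4, 4), mode='nearest')     # -> [1, 3, 4, 4]
print(up.shape, down.shape)
```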
def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
    r'''
@@ -1045,7 +1089,7 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
    [in] mode (string): the interpolate way, default: bilinear.

    [in] padding_mode (string): the padding way, default: zeros.

    [out] output (var): the output var, whose shape is (N, C, Ho, Wo)

    Example:
@@ -1069,10 +1113,10 @@ def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
    assert Ni == No
    assert len(input.shape) == 4 and len(grid.shape)

    nid, cid, hid, wid = jt.index((Ni,Ci,Ho,Wo))
    x = ((grid[:,:,:,1].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Hi - 1)
    y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1)
    return _interpolate(input, x, y, (nid,cid), mode)
    nid, cid, hid, wid = jt.index((Ni, Ci, Ho, Wo))
    x = ((grid[:, :, :, 1].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Hi - 1)
    y = ((grid[:, :, :, 0].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Wi - 1)
    return _interpolate(input, x, y, (nid, cid), mode)


def linspace_from_neg_one(grid,num_steps,align_corners):
@@ -1327,4 +1371,49 @@ class Sequential(Module):
    def __len__(self):
        return len(self.layers)


def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
    assert X.ndim == 4
    if not isinstance(kernel_size, tuple):
        kernel_size = (kernel_size, kernel_size)
    if not isinstance(dilation, tuple):
        dilation = (dilation, dilation)
    if not isinstance(padding, tuple):
        padding = (padding, padding)
    if not isinstance(stride, tuple):
        stride = (stride, stride)
    n, c, h, w = X.shape
    shape = X.shape
    area = kernel_size[0] * kernel_size[1]
    block_nums = []
    for i in range(2, 4):
        block_nums.append(
            (shape[i] + 2 * padding[i - 2] - dilation[i - 2] * (kernel_size[i - 2] - 1) - 1) // stride[i - 2] + 1)
    if padding[0] != 0 or padding[1] != 0:
        X = X.reindex([n, c, h + padding[0] * 2, w + padding[1] * 2],
                      ["i0", "i1", f"i2-{padding[0]}", f"i3-{padding[1]}"])
    output = X.reindex([n, c * area, block_nums[0] * block_nums[1]], ["i0", f"i1/{area}",
                       f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",
                       f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"])
    return output


def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1):
    assert X.ndim==3
    if not isinstance(kernel_size,tuple):
        kernel_size = (kernel_size,kernel_size)
    if not isinstance(dilation,tuple):
        dilation = (dilation,dilation)
    if not isinstance(padding,tuple):
        padding = (padding,padding)
    if not isinstance(stride,tuple):
        stride = (stride,stride)
    n,cl,num = X.shape
    area = kernel_size[0] * kernel_size[1]
    block_nums = []
    for i in range(2,4):
        block_nums.append((output_size[i-2]+2*padding[i-2]-dilation[i-2]*(kernel_size[i-2]-1)-1) // stride[i-2]+1)
    output = X.reindex_reduce("add",[n,cl // area,output_size[0]+2*padding[0],output_size[1]+2*padding[1]],["i0",f"i1/{area}",f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"])
    return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]]

ModuleList = Sequential
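The `block_nums` loop above is the standard sliding-window count; a small sketch of the resulting `unfold` output shape (same parameters as the fold test below):

```python
import jittor as jt

X = jt.random((1, 3, 8, 8))
cols = jt.nn.unfold(X, kernel_size=2, stride=2, dilation=2, padding=2)
# blocks per spatial dim: (8 + 2*2 - 2*(2-1) - 1) // 2 + 1 = 5
print(cols.shape)   # [1, 3*2*2, 5*5] == [1, 12, 25]
```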
@@ -0,0 +1,45 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
#     Haoyang Peng <2247838039@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import torch
from torch.nn import functional as F
import numpy as np


class TestBicubicInterpolate(unittest.TestCase):
    # this is for testing bicubic interpolate
    def test_bicubic(self):
        for _ in range(20):
            try:
                tn = np.random.randn(1,1,5,5).astype('float32')
                ja = jt.array(tn)
                ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True)
                # test upsample
                ju = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic')
                tu = F.interpolate(ta,scale_factor=2,mode='bicubic')
                assert np.allclose(ju.data,tu.detach().numpy(),rtol=1e-03,atol=1e-06)
                gju = jt.grad(ju,ja)
                gtu = torch.autograd.grad(tu,ta,torch.ones_like(tu),retain_graph=True)[0]
                assert np.allclose(gju.data,gtu.detach().numpy(),rtol=1e-03,atol=1e-06)
                # test align
                je = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic',align_corners=True)
                te = F.interpolate(ta,scale_factor=2,mode='bicubic',align_corners=True)
                assert np.allclose(je.data,te.detach().numpy(),rtol=1e-03,atol=1e-06)
                gje = jt.grad(je,ja)
                gte = torch.autograd.grad(te,ta,torch.ones_like(tu),retain_graph=True)[0]
                assert np.allclose(gje.data,gte.detach().numpy(),rtol=1e-03,atol=1e-06)
            except AssertionError:
                print(ju,tu)
                print(je,te)

if __name__ == "__main__":
    unittest.main()
@@ -96,6 +96,16 @@ class TestConvTranspose(unittest.TestCase):
        check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)

    def test_conv1d(self):
        conv1d = jt.nn.Conv1d(10,20,5)
        a = jt.rand((3,10,15))
        b = conv1d(a)
        b.sync()
        assert b.shape == [3,20,11]
        b = jt.nn.Conv1d(10,20,5, padding=2)(a)
        assert b.shape == [3,20,15]
        assert sorted(list(conv1d.state_dict().keys())) == ['bias', 'weight'], conv1d.state_dict().keys()

if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,40 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
#     Haoyang Peng <2247838039@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import torch
from torch.nn import functional as F
import numpy as np


class TestFoldOp(unittest.TestCase):
    def test_fold(self):
        # test unfold first and then test fold.
        for i in range(4,10):
            tn = np.random.randn(1,3,i,i).astype('float32')
            ja = jt.array(tn)
            ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True)
            juf = jt.nn.unfold(ja,kernel_size=2,stride=2,dilation=2,padding=2)
            tuf = F.unfold(ta,kernel_size=2,stride=2,dilation=2,padding=2)
            assert np.allclose(juf.data,tuf.detach().numpy())
            gjuf = jt.grad(juf,ja)
            gtuf = torch.autograd.grad(tuf,ta,torch.ones_like(tuf),retain_graph=True)[0]
            assert np.allclose(gjuf.data,gtuf.detach().numpy())
            # test fold
            jf = jt.nn.fold(juf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2)
            tf = F.fold(tuf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2)
            assert np.allclose(jf.data,tf.detach().numpy())
            gjf = jt.grad(jf,juf)
            gtf = torch.autograd.grad(tf,tuf,torch.ones_like(tf),retain_graph=True)[0]
            assert np.allclose(gjf.data,gtf.detach().numpy())

if __name__ == "__main__":
    unittest.main()
@@ -8,49 +8,52 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import torch
from torch.autograd import Variable
import jittor as jt
import numpy as np
import unittest


try:
    import autograd.numpy as anp
    from autograd import jacobian

    has_autograd = True
except:
    has_autograd = False


@unittest.skipIf(not has_autograd, "No autograd found.")
class TestCodeOp(unittest.TestCase):
class TestLinalgOp(unittest.TestCase):
    def test_svd(self):
        def check_svd(a):
            u,s,v = anp.linalg.svd(a, full_matrices=0)
            return u,s,v
            u, s, v = anp.linalg.svd(a, full_matrices=0)
            return u, s, v

        def check_u(a):
            u,s,v = anp.linalg.svd(a, full_matrices=0)
            u, s, v = anp.linalg.svd(a, full_matrices=0)
            return u

        def check_s(a):
            u,s,v = anp.linalg.svd(a, full_matrices=0)
            u, s, v = anp.linalg.svd(a, full_matrices=0)
            return s

        def check_v(a):
            u,s,v = anp.linalg.svd(a, full_matrices=0)
            u, s, v = anp.linalg.svd(a, full_matrices=0)
            return v

        for i in range(50):
            #not for full-matrices!
            a = jt.random((2,2,5,4))
            # not for full-matrices!
            a = jt.random((2, 2, 5, 4))
            c_a = anp.array(a.data)
            u,s,v = jt.linalg.svd(a)
            tu,ts,tv = check_svd(c_a)
            assert np.allclose(tu,u.data)
            assert np.allclose(ts,s.data)
            assert np.allclose(tv,v.data)
            ju = jt.grad(u,a)
            js = jt.grad(s,a)
            jv = jt.grad(v,a)
            u, s, v = jt.linalg.svd(a)
            tu, ts, tv = check_svd(c_a)
            assert np.allclose(tu, u.data)
            assert np.allclose(ts, s.data)
            assert np.allclose(tv, v.data)
            ju = jt.grad(u, a)
            js = jt.grad(s, a)
            jv = jt.grad(v, a)
            grad_u = jacobian(check_u)
            gu = grad_u(c_a)
            gu = np.sum(gu, 4)
@@ -69,56 +72,56 @@ class TestCodeOp(unittest.TestCase):
            gv = np.sum(gv, 2)
            gv = np.sum(gv, 2)
            try:
                assert np.allclose(ju.data,gu,atol=1e-5)
                assert np.allclose(ju.data, gu, atol=1e-5)
            except AssertionError:
                print(ju.data)
                print(gu)
            try:
                assert np.allclose(js.data,gs,atol=1e-5)
                assert np.allclose(js.data, gs, atol=1e-5)
            except AssertionError:
                print(js.data)
                print(gs)
            try:
                assert np.allclose(jv.data,gv,atol=1e-5)
                assert np.allclose(jv.data, gv, atol=1e-5)
            except AssertionError:
                print(jv.data)
                print(gv)

    def test_eigh(self):
        def check_eigh(a,UPLO='L'):
            w, v = anp.linalg.eigh(a,UPLO)
        def check_eigh(a, UPLO='L'):
            w, v = anp.linalg.eigh(a, UPLO)
            return w, v

        def check_w(a,UPLO='L'):
            w, v = anp.linalg.eigh(a,UPLO)
        def check_w(a, UPLO='L'):
            w, v = anp.linalg.eigh(a, UPLO)
            return w

        def check_v(a,UPLO='L'):
            w, v = anp.linalg.eigh(a,UPLO)
        def check_v(a, UPLO='L'):
            w, v = anp.linalg.eigh(a, UPLO)
            return v

        for i in range(50):
            a = jt.random((2,2,3,3))
            a = jt.random((2, 2, 3, 3))
            c_a = a.data
            w, v = jt.linalg.eigh(a)
            tw, tv = check_eigh(c_a)
            assert np.allclose(w.data,tw)
            assert np.allclose(v.data,tv)
            assert np.allclose(w.data, tw)
            assert np.allclose(v.data, tv)
            jw = jt.grad(w, a)
            jv = jt.grad(v, a)
            check_gw = jacobian(check_w)
            check_gv = jacobian(check_v)
            gw = check_gw(c_a)
            gw = np.sum(gw,4)
            gw = np.sum(gw,2)
            gw = np.sum(gw,2)
            assert np.allclose(gw,jw.data,rtol = 1,atol = 5e-8)
            gw = np.sum(gw, 4)
            gw = np.sum(gw, 2)
            gw = np.sum(gw, 2)
            assert np.allclose(gw, jw.data, rtol=1, atol=5e-8)
            gv = check_gv(c_a)
            gv = np.sum(gv,4)
            gv = np.sum(gv,4)
            gv = np.sum(gv,2)
            gv = np.sum(gv,2)
            assert np.allclose(gv,jv.data,rtol = 1,atol = 5e-8)
            gv = np.sum(gv, 4)
            gv = np.sum(gv, 4)
            gv = np.sum(gv, 2)
            gv = np.sum(gv, 2)
            assert np.allclose(gv, jv.data, rtol=1, atol=5e-8)

    def test_pinv(self):
        def check_pinv(a):
@@ -126,34 +129,35 @@ class TestCodeOp(unittest.TestCase):
            return w

        for i in range(50):
            x = jt.random((2,2,4,4))
            x = jt.random((2, 2, 4, 3))
            c_a = x.data
            mx = jt.linalg.pinv(x)
            tx = check_pinv(c_a)
            np.allclose(mx.data,tx)
            jx = jt.grad(mx,x)
            np.allclose(mx.data, tx)
            jx = jt.grad(mx, x)
            check_grad = jacobian(check_pinv)
            gx = check_grad(c_a)
            np.allclose(gx,jx.data)
            np.allclose(gx, jx.data)

    def test_inv(self):
        def check_inv(a):
            w = anp.linalg.inv(a)
            return w

        for i in range(50):
            tn = np.random.randn(4,4).astype('float32')*5
            while np.allclose(np.linalg.det(tn),0):
                tn = np.random.randn(4,4).astype('float32')*5
            tn = np.random.randn(4, 4).astype('float32') * 5
            while np.allclose(np.linalg.det(tn), 0):
                tn = np.random.randn(4, 4).astype('float32') * 5
            x = jt.array(tn)
            x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
            x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
            c_a = x.data
            mx = jt.linalg.inv(x)
            tx = check_inv(c_a)
            np.allclose(mx.data,tx)
            jx = jt.grad(mx,x)
            np.allclose(mx.data, tx)
            jx = jt.grad(mx, x)
            check_grad = jacobian(check_inv)
            gx = check_grad(c_a)
            np.allclose(gx,jx.data)
            np.allclose(gx, jx.data)

    def test_slogdet(self):
        def check_ans(a):
@@ -165,11 +169,11 @@ class TestCodeOp(unittest.TestCase):
            return w

        for i in range(50):
            tn = np.random.randn(4,4).astype('float32')*10
            while np.allclose(np.linalg.det(tn),0):
                tn = np.random.randn(4,4).astype('float32')*10
            tn = np.random.randn(4, 4).astype('float32') * 10
            while np.allclose(np.linalg.det(tn), 0):
                tn = np.random.randn(4, 4).astype('float32') * 10
            x = jt.array(tn)
            x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
            x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
            s = list(x.shape)
            det_s = s[:-2]
            if len(det_s) == 0:
@@ -178,12 +182,12 @@ class TestCodeOp(unittest.TestCase):
            ts, ta = check_ans(x.data)
            assert np.allclose(sign.data, ts)
            assert np.allclose(mx.data, ta)
            jx = jt.grad(mx,x)
            jx = jt.grad(mx, x)
            check_sgrad = jacobian(check_slogdet)
            gx = check_sgrad(x.data)
            gx = np.sum(gx,2)
            gx = np.sum(gx,2)
            assert np.allclose(gx,jx.data)
            gx = np.sum(gx, 2)
            gx = np.sum(gx, 2)
            assert np.allclose(gx, jx.data)

    def test_cholesky(self):
        def check_cholesky(a):
@@ -192,39 +196,39 @@ class TestCodeOp(unittest.TestCase):

        for i in range(50):
            x = jt.array(np.diag((np.random.rand(3) + 1) * 2))
            x = x.reindex([2,2,x.shape[0],x.shape[1]],["i2","i3"])
            x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
            tx = x.data
            L = jt.linalg.cholesky(x)
            tL = check_cholesky(tx)
            assert np.allclose(tL,L.data)
            jx = jt.grad(L,x)
            assert np.allclose(tL, L.data)
            jx = jt.grad(L, x)
            check_grad = jacobian(check_cholesky)
            gx = check_grad(tx)
            gx = np.sum(gx, 0)
            gx = np.sum(gx, 0)
            gx = np.sum(gx, 0)
            gx = np.sum(gx, 0)
            assert np.allclose(jx.data,gx)
            assert np.allclose(jx.data, gx)

    def test_solve(self):
        def check_solve(a,b):
            ans = anp.linalg.solve(a,b)
        def check_solve(a, b):
            ans = anp.linalg.solve(a, b)
            return ans

        for i in range(50):
            a = jt.random((2,2,3,3))
            b = jt.random((2,2,3))
            ans = jt.linalg.solve(a,b)
            ta = check_solve(a.data,b.data)
            a = jt.random((2, 2, 3, 3))
            b = jt.random((2, 2, 3))
            ans = jt.linalg.solve(a, b)
            ta = check_solve(a.data, b.data)
            assert np.allclose(ans.data, ta)
            jx = jt.grad(ans, a)
            check_sgrad = jacobian(check_solve)
            gx = check_sgrad(a.data,b.data)
            gx = np.sum(gx,0)
            gx = np.sum(gx,0)
            gx = np.sum(gx,0)
            gx = check_sgrad(a.data, b.data)
            gx = np.sum(gx, 0)
            gx = np.sum(gx, 0)
            gx = np.sum(gx, 0)
            try:
                assert np.allclose(gx, jx.data,rtol=1)
                assert np.allclose(gx, jx.data, rtol=1)
            except AssertionError:
                print(gx)
                print(jx.data)
@@ -254,7 +258,42 @@ class TestCodeOp(unittest.TestCase):
            gx = np.sum(gx, 2)
            assert np.allclose(gx, jx.data)

    def test_qr(self):
        for i in range(50):
            tn = np.random.randn(3, 3).astype('float32')
            while np.allclose(np.linalg.det(tn), 0):
                tn = np.random.randn(3, 3).astype('float32')
            x = jt.array(tn)
            # x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"])
            t_x = torch.from_numpy(tn)
            t_x = Variable(t_x, requires_grad=True)
            jq, jr = jt.linalg.qr(x)
            tq, tr = torch.qr(t_x)
            try:
                assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6)
                assert np.allclose(jr.data, tr.detach().numpy(), rtol=1e-4, atol=1e-6)
            except AssertionError:
                print("ours' qr results:")
                print(jq)
                print(jr)
                print("pytorch's qr results:")
                print(tq)
                print(tr)
            gq = jt.grad(jq, x).data
            gr = jt.grad(jr, x).data
            tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True)
            tgr = torch.autograd.grad(tr, t_x, torch.ones_like(tr), retain_graph=True)
            try:
                assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6)
                assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6)
            except AssertionError:
                print("ours' qr grad results:")
                print(gq)
                print(gr)
                print("pytorch's qr grad result:")
                print(tgq[0])
                print(tgr[0])


if __name__ == "__main__":
    unittest.main()
@@ -31,6 +31,22 @@ def check_equal(res1, res2, eps=1e-5):

@unittest.skipIf(skip_this_test, "No Torch found")
class TestPad(unittest.TestCase):
    def test_index_add_(self):
        x = np.ones((5,3))
        a1 = torch.Tensor(x)
        a1.index_add_(0, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
        a2 = jt.array(x)
        a2.index_add_(0, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
        check_equal(a1, a2)

        x = np.ones((3,5))
        a1 = torch.Tensor(x)
        a1.index_add_(1, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float))
        a2 = jt.array(x)
        a2.index_add_(1, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
        check_equal(a1, a2)
        print('pass index_add_ test ...')

    def test_repeat(self):
        arr = np.random.randn(16,3,224,224)
        check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4))
@@ -149,26 +149,58 @@ class TestSetitem(unittest.TestCase):
        assert (a[0].numpy() == [-1,2]).all(), a[0].numpy()
        assert (a[1].numpy() == [3,-2]).all(), a[1].numpy()

    # def test_scatter(self):
    #     src = jt.arange(1, 11).reshape((2, 5))
    #     index = jt.array([[0, 1, 2, 0]])
    #     print(index.shape, src.shape)
    #     x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
    #     print(x)
    def test_scatter(self):
        src = jt.arange(1, 11).reshape((2, 5))
        index = jt.array([[0, 1, 2, 0]])
        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
        assert (x.data ==
            [[1, 0, 0, 4, 0],
             [0, 2, 0, 0, 0],
             [0, 0, 3, 0, 0]]).all()
        index = jt.array([[0, 1, 2], [0, 1, 4]])
        x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src)
        assert (x.data ==
            [[1, 2, 3, 0, 0],
             [6, 7, 0, 0, 8],
             [0, 0, 0, 0, 0]]).all()
        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
                jt.array(1.23), reduce='multiply')
        assert np.allclose(x.data,
            [[2.0000, 2.0000, 2.4600, 2.0000],
             [2.0000, 2.0000, 2.0000, 2.4600]]), x
        x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]),
                jt.array(1.23), reduce='add')
        assert np.allclose(x.data,
            [[2.0000, 2.0000, 3.2300, 2.0000],
             [2.0000, 2.0000, 2.0000, 3.2300]])

    def test_gather(self):
        t = jt.array([[1, 2], [3, 4]])
        data = t.gather(1, jt.array([[0, 0], [1, 0]])).data
        assert (data == [[ 1, 1], [ 4, 3]]).all()
        data = t.gather(0, jt.array([[0, 0], [1, 0]])).data
        assert (data == [[ 1, 2], [ 3, 2]]).all()

    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
    @jt.flag_scope(use_cuda=1)
    def test_scatter_cuda(self):
        self.test_scatter()

    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
    @jt.flag_scope(use_cuda=1)
    def test_gather_cuda(self):
        self.test_gather()

    def test_setitem_bool(self):
        a = jt.array([1,2,3,4])
        b = jt.array([True,False,True,False])
        a[b] = jt.array([-1,-2])
        assert (a.data == [-1,2,-2,4]).all()

    def test_slice_none(self):
        a = jt.array([1,2])
        assert a[None,:,None,None,...,None].shape == (1,2,1,1,1)

# def scatter(x, dim, index, src, reduce='void'):
#     shape = index.shape
#     indexes = [ jt.index(shape, i) for i in range(dim) ]
#     indexes.append(index)
#     print(indexes)
#     return x.setitem(tuple(indexes), src, reduce)

# def scatter_(x, dim, index, src, reduce='void'):
#     return x.assign(x.scatter(dim, index, src, reduce))

# jt.Var.scatter = scatter
# jt.Var.scatter_ = scatter_

if __name__ == "__main__":
    unittest.main()
@@ -62,6 +62,13 @@ class TestUnaryOp(unittest.TestCase):
        jda = jt.grad(jb, ja)
        assert (np.allclose(jda.data, da)), (jda.data,da,op)

    def test_sigmoid(self):
        a = np.arange(-150,150, 10).astype("float32")
        # a = np.array([-150.0, -140.0, -130.0]).astype("float32")
        b = jt.array(a, dtype='float32')
        b1 = b.sigmoid().numpy()
        assert np.isnan(b1).any() == False

class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
    pass
@@ -81,8 +81,8 @@ void ArrayOp::jit_prepare(JK& jk) {

    // fill or find cbuffer for const var pass
    if (output->dtype().dsize() == 4) {
        auto x = abs(ptr<int32>()[0]);
        auto y = abs(ptr<float32>()[0]);
        auto x = std::abs(ptr<int32>()[0]);
        auto y = std::abs(ptr<float32>()[0]);
        auto z = ptr<uint32>()[0];
        if ((x<=2) || (y==1.0f || y==2.0f))
            jk << _CS("[o:") << z << ']';
@@ -117,6 +117,10 @@ void GetitemOp::infer_slices(
            auto& v = s.slice.start;
            if (v<0) v += in_shape_i;
            CHECK(v>=0 && v<in_shape_i) << "slice overflow, " << v << "not in [0,">>in_shape_i>>")";
        } else
        if (s.is_str()) {
            i_to_vs[i] = vid++;
            i_to_o[i] = -1;
        } else {
            // slice
            auto& slice = s.slice;
@@ -146,6 +150,13 @@ void GetitemOp::infer_slices(
            }
        }
    }
    while (vid < vs.n) {
        auto& s = vs.slices[vid++];
        if (s.is_none()) {
            out_shape.push_back(1);
        } else
            CHECK(s.is_ellipsis()) << "Too many slices" << vs << "shape:" << in->shape;
    }
}

void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims) {
@@ -401,6 +412,10 @@ void GetitemOp::jit_prepare(JK& jk) {
        if (iv>=0 && io==-1) {
            if (v.is_int()) {
                jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
            } else
            if (v.is_str()) {
                jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
                jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
            } else {
                ASSERT(v.is_var());
                auto var = v.var;
@@ -498,9 +513,10 @@ void GetitemOp::jit_run() {
        @if(IV@d==-2, 0,
        @if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
        @if(VS@d==-1, vi@d,
        @if(VS@d==-5, VSS@d,
        @if(VS@d>=0,
            index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
        , ??? )))));
        , ??? ))))));
    )
    auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
    op[oid] = ip[iid];
@@ -206,6 +206,10 @@ void SetitemOp::jit_prepare(JK& jk) {
        if (iv>=0 && io==-1) {
            if (v.is_int()) {
                jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
            } else
            if (v.is_str()) {
                jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
                jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
            } else {
                ASSERT(v.is_var());
                auto var = v.var;
@@ -323,9 +327,10 @@ void SetitemOp::jit_run() {
        @if(IV@d==-2, 0,
        @if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
        @if(VS@d==-1, vi@d,
        @if(VS@d==-5, VSS@d,
        @if(VS@d>=0,
            index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
        , ??? )))));
        , ??? ))))));
    )
    auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
@@ -37,7 +37,7 @@ namespace jittor {
#define tanh(T,x) ((T) ::tanhf((x)))
#define atanh(T,x) ((T) ::atanhf((x)))

#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf(-(x)))))
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf((::min(T(-(x)), T(@if(@strcmp(@T,float32)==0,30,300))))))))

#define erf(T,x) ((T) ::erff((x)))
@@ -65,7 +65,7 @@ namespace jittor {
#define tanh(T,x) ((T) std::tanh((x)))
#define atanh(T,x) ((T) std::atanh((x)))

#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(-(x)))))
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(std::min(T(-(x)), T(@if(@strcmp(@T,float32)==0,30,300)))))))

#define erf(T,x) ((T) std::erf((x)))
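The two `sigmoid` macro changes above cap the argument of `exp` (at 30 for float32, 300 for float64) so the denominator can never overflow to inf; a numpy sketch of the idea, which is what `test_sigmoid` earlier in this diff checks:

```python
import numpy as np

def sigmoid_naive(x):
    return 1.0 / (1.0 + np.exp(-x))           # exp overflows for x << 0

def sigmoid_clamped(x, cap=30.0):
    # mirror of the C++ fix: clamp the exp argument so it cannot overflow
    return 1.0 / (1.0 + np.exp(np.minimum(-x, cap)))

x = np.float32(-150.0)
print(sigmoid_clamped(x))   # ~9.4e-14, with finite intermediates throughout
```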
@@ -719,6 +719,7 @@ DEF_IS(VarSlices, bool) is_type(PyObject* obj) {
        PySlice_Check(obj) ||
        (Py_TYPE(obj) == &PyEllipsis_Type) ||
        obj == Py_None ||
        PyUnicode_CheckExact(obj) ||
        is_type<VarHolder*>(obj);
}
@@ -733,6 +734,9 @@ void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>&
    if (Py_TYPE(obj) == &PyEllipsis_Type) {
        var_slice->set_ellipsis();
    } else
    if (PyUnicode_CheckExact(obj)) {
        var_slice->set_str(from_py_object<string>(obj));
    } else
    if (obj == Py_None) {
        var_slice->set_none();
    } else
@@ -745,7 +749,9 @@ void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>&
    } else {
        holders.emplace_back();
        auto* vh = from_py_object<VarHolder*>(obj, holders.back());
        auto vv = (Var**)vh;
        auto vv = (decltype(var_slice->var)*)vh;
        CHECK(vv[0]->dtype() != ns_bool) << "Please convert bool slice into jt.array, example:\n"
            "a[[True,False,False]] ---> a[jt.array([True,False,False])]";
        var_slice->set_var(vv[0]);
    }
}
@@ -21,6 +21,7 @@ std::ostream& operator<<(std::ostream& os, const VarSlice& s) {
    if (s.is_ellipsis()) return os << "...";
    if (s.is_slice()) return os << s.slice;
    if (s.is_int()) return os << s.i;
    if (s.is_str()) return os << (const char*)&s;
    return os << "-";
}
@@ -7,6 +7,7 @@
#pragma once
#include "common.h"
#include "misc/nano_vector.h"
#include "var.h"

namespace jittor {
@@ -20,11 +21,21 @@ union VarSlice {
    inline bool is_ellipsis() const { return slice.mask == -2; }
    inline bool is_none() const { return slice.mask == -3; }
    inline bool is_int() const { return slice.mask == -4; }
    inline bool is_str() const { return slice.mask == -5; }
    inline bool is_slice() const { return slice.mask >= 0; }
    inline void set_var(Var* v) { slice.mask = -1; var = v; }
    inline void set_ellipsis() { slice.mask = -2; }
    inline void set_none() { slice.mask = -3; }
    inline void set_int(int64 v) { slice.mask = -4; i = v; }
    inline void set_str(const string& s) {
        slice.mask = -5;
        CHECK(s.size() < 16) << "String slice too long" << s;
        auto v = (int64*)s.c_str();
        slice.start = v[0];
        slice.stop = v[1];
        slice.step = s.size();
    }
    inline char* get_str() {return (char*)this;}
};

struct VarSlices {
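A Python illustration (a hypothetical helper, not part of the codebase) of what `set_str` does: the bytes of a short string are packed into the slice's two int64 fields, with `step` recording the length:

```python
import struct

def pack_str(s: str):
    # mimics VarSlice::set_str: pad to the 16 bytes that set_str reads,
    # then reinterpret them as two little-endian int64 values
    assert len(s) < 16, "String slice too long"
    raw = s.encode().ljust(16, b'\0')
    start, stop = struct.unpack('<qq', raw)
    return start, stop, len(s)   # -> slice.start, slice.stop, slice.step

print(pack_str("ab"))
```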