searchsorted & cumprod etc. ops

Gword 2020-12-17 20:31:26 +08:00
parent 3c95c6d100
commit 8a4ad9bca5
7 changed files with 230 additions and 4 deletions

View File

@@ -268,12 +268,12 @@ def std(x):
    return out
Var.std = std
def norm(x, k, dim):
def norm(x, k, dim, keepdim=False):
    assert k==1 or k==2
    if k==1:
        return x.abs().sum(dim)
        return x.abs().sum(dim, keepdim)
    if k==2:
        return (x.sqr()).sum(dim).maximum(1e-6).sqrt()
        return (x.sqr()).sum(dim, keepdim).maximum(1e-6).sqrt()
Var.norm = norm
origin_reshape = reshape
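For context, a minimal usage sketch of the new keepdim flag (assuming the underlying reduce op accepts the extra positional argument exactly as passed above):

import jittor as jt

x = jt.array([[3.0, 4.0], [6.0, 8.0]])
print(x.norm(2, dim=1).shape)                # (2,)   -- dim 1 reduced away
print(x.norm(2, dim=1, keepdim=True).shape)  # (2, 1) -- dim 1 kept as size 1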

View File

@@ -676,3 +676,114 @@ def triu_(x,diagonal=0):
    return x.reindex(x.shape,indexs,overflow_conditions=overflow_conditions,overflow_value=0)
jt.Var.triu_ = triu_

def searchsorted(s, v, right=False):
    class SearchsortedFunc(jt.Module):
        def __init__(self, right=False):
            self.side = "right" if right else "left"
        def forward_code(self, np, data):
            a, b = data["inputs"]
            c = data["outputs"][0]
            if len(a.shape)==1:
                out = np.searchsorted(a, b, side=self.side)
            else:
                # out = np.apply_along_axis(np.searchsorted, 1, a, b)
                # out = out.diagonal(0,0,1).T
                # TODO: support better 2-dims searchsorted
                outs = []
                for i in range(a.shape[0]):
                    outs.append(np.expand_dims(np.searchsorted(a[i], b[i], side=self.side),0))
                out = np.concatenate(outs, 0)
                # out = np.zeros(b.shape)
            np.copyto(c, out)
        def execute(self, s, v):
            return jt.numpy_code(
                v.shape,
                v.dtype,
                [s, v],
                self.forward_code,
            )
    assert len(s.shape)==len(v.shape) and v.shape[:-1]==s.shape[:-1]
    assert len(s.shape)==1 or len(s.shape)==2, "TODO: support n-dims searchsorted"
    func = SearchsortedFunc(right)
    return func(s, v)

def cumprod(a, dim):
    class CumprodFunc(jt.Function):
        def forward_code(self, np, data):
            a = data["inputs"][0]
            b = data["outputs"][0]
            out = np.cumprod(a, self.dim)
            np.copyto(b, out)
        def backward_code(self, np, data):
            a, b, dout = data["inputs"]
            out = data["outputs"][0]
            sdim = a.shape[self.dim]
            dim = (len(a.shape)+1)*[1]
            dim[self.dim+1] = sdim
            res = np.tile(np.expand_dims(b, self.dim+1), dim)
            dout = np.tile(np.expand_dims(dout, self.dim+1), dim)
            dim[self.dim]=sdim
            dim[self.dim+1]=1
            a = np.tile(np.expand_dims(a, self.dim), dim)
            res = res/a
            mask = np.tril(np.ones((sdim, sdim)))
            for i in range(self.dim):
                mask = np.expand_dims(mask, 0)
            for i in range(len(a.shape)-self.dim-2):
                mask = np.expand_dims(mask, -1)
            res = np.sum(mask*res*dout, self.dim)
            np.copyto(out, res)
        def execute(self, a, dim):
            self.save_vars = a
            self.dim = dim
            self.res = jt.numpy_code(
                a.shape,
                a.dtype,
                [a],
                self.forward_code,
            )
            return self.res
        def grad(self, grad_a):
            a = self.save_vars
            b = self.res
            return jt.numpy_code(
                a.shape,
                a.dtype,
                [a, b, grad_a],
                self.backward_code,
            )
    func = CumprodFunc()
    if dim<0:
        dim+=len(a.shape)
    return func(a, dim)

def linspace(start, end, steps):
    res = jt.index((steps,))[0]
    res = res*(end-start)/float(steps-1)+start
    return res

def randperm(n):
    # TODO: use jt.random
    idx = np.arange(n)
    return jt.array(np.random.permutation(idx))

def set_global_seed(seed):
    jt.set_seed(seed)
    np.random.seed(seed)
    try:
        import cupy
        cupy.random.seed(seed)
    except:
        pass
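searchsorted falls back to NumPy through jt.numpy_code and, per the asserts above, handles 1-D inputs or batched 2-D inputs whose leading dimensions match; each row of s must already be sorted. A small usage sketch (note that the output takes v's shape, and in this implementation also v's dtype):

import jittor as jt
import numpy as np

s = jt.array(np.array([[1., 3., 5., 7.],
                       [2., 4., 6., 8.]]))
v = jt.array(np.array([[3., 6.],
                       [4., 9.]]))
print(jt.searchsorted(s, v).numpy())              # left insertion points per row: (1, 3) and (1, 4)
print(jt.searchsorted(s, v, right=True).numpy())  # right insertion points per row: (2, 3) and (2, 4)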
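The masked backward pass of cumprod implements d(cumprod(a)_j)/d(a_i) = cumprod(a)_j / a_i for i <= j and 0 otherwise, so with upstream gradient g the result is grad_i = sum_{j >= i} g_j * cumprod(a)_j / a_i. A standalone NumPy sketch checking that identity against finite differences (illustrative only, not part of the patch; like the division above, it assumes all entries of a are nonzero):

import numpy as np

a = np.random.rand(5) + 0.5   # keep entries away from zero
g = np.random.rand(5)         # upstream gradient of a scalar loss
b = np.cumprod(a)

# analytic gradient: grad_i = sum_{j >= i} g_j * b_j / a_i
grad = np.array([np.sum(g[i:] * b[i:]) / a[i] for i in range(5)])

# numeric gradient of loss(a) = sum(g * cumprod(a))
eps = 1e-6
num = np.array([
    (np.sum(g * np.cumprod(a + eps * np.eye(5)[i])) -
     np.sum(g * np.cumprod(a - eps * np.eye(5)[i]))) / (2 * eps)
    for i in range(5)])
assert np.allclose(grad, num, atol=1e-4)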
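linspace is endpoint-inclusive and divides by steps-1, so steps must be at least 2; randperm builds the permutation in NumPy for now, as the TODO notes. A quick sketch, assuming both are exported at the top level like jt.cumprod:

import jittor as jt

print(jt.linspace(0.0, 1.0, 5).numpy())   # 0.0, 0.25, 0.5, 0.75, 1.0
print(jt.randperm(5).numpy())             # e.g. [3 0 4 1 2]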

View File

@@ -0,0 +1,51 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Guowei Yang <471184555@qq.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import ctypes
import sys
import torch
from torch.autograd import Variable

class TestCumprod(unittest.TestCase):
    def test_cumprod_cpu(self):
        jt.flags.use_cuda = 0
        for i in range(1,6):
            for j in range(i):
                x = np.random.rand(*((10,)*i))
                x_jt = jt.array(x)
                y_jt = jt.cumprod(x_jt, j).sqr()
                g_jt = jt.grad(y_jt.sum(), x_jt)
                x_tc = Variable(torch.from_numpy(x), requires_grad=True)
                y_tc = torch.cumprod(x_tc, j)**2
                y_tc.sum().backward()
                g_tc = x_tc.grad
                assert np.allclose(y_jt.numpy(), y_tc.data)
                assert np.allclose(g_jt.numpy(), g_tc.data)

    def test_cumprod_gpu(self):
        jt.flags.use_cuda = 1
        for i in range(1,6):
            for j in range(i):
                x = np.random.rand(*((10,)*i))
                x_jt = jt.array(x)
                y_jt = jt.cumprod(x_jt, j).sqr()
                g_jt = jt.grad(y_jt.sum(), x_jt)
                x_tc = Variable(torch.from_numpy(x), requires_grad=True)
                y_tc = torch.cumprod(x_tc, j)**2
                y_tc.sum().backward()
                g_tc = x_tc.grad
                assert np.allclose(y_jt.numpy(), y_tc.data)
                assert np.allclose(g_jt.numpy(), g_tc.data)

if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,55 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Guowei Yang <471184555@qq.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import ctypes
import sys
import torch
from torch.autograd import Variable

class TestSearchsorted(unittest.TestCase):
    def test_searchsorted_cpu(self):
        jt.flags.use_cuda = 0
        for i in range(1,3):
            s = np.sort(np.random.rand(*((10,)*i)),-1)
            v = np.random.rand(*((10,)*i))
            s_jt = jt.array(s)
            v_jt = jt.array(v)
            s_tc = torch.from_numpy(s)
            v_tc = torch.from_numpy(v)
            y_jt = jt.searchsorted(s_jt, v_jt, right=True)
            y_tc = torch.searchsorted(s_tc, v_tc, right=True)
            assert np.allclose(y_jt.numpy(), y_tc.data)
            y_jt = jt.searchsorted(s_jt, v_jt, right=False)
            y_tc = torch.searchsorted(s_tc, v_tc, right=False)
            assert np.allclose(y_jt.numpy(), y_tc.data)

    def test_searchsorted_gpu(self):
        jt.flags.use_cuda = 1
        for i in range(1,3):
            s = np.sort(np.random.rand(*((10,)*i)),-1)
            v = np.random.rand(*((10,)*i))
            s_jt = jt.array(s)
            v_jt = jt.array(v)
            s_tc = torch.from_numpy(s)
            v_tc = torch.from_numpy(v)
            y_jt = jt.searchsorted(s_jt, v_jt, right=True)
            y_tc = torch.searchsorted(s_tc, v_tc, right=True)
            assert np.allclose(y_jt.numpy(), y_tc.data)
            y_jt = jt.searchsorted(s_jt, v_jt, right=False)
            y_tc = torch.searchsorted(s_tc, v_tc, right=False)
            assert np.allclose(y_jt.numpy(), y_tc.data)

if __name__ == "__main__":
    unittest.main()

View File

@@ -18,6 +18,7 @@ unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
PyObject* (*PyArray_NewCopy)(PyObject *, int);
int (*PyArray_CopyInto)(PyObject *, PyObject *);
void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode);
tmp_data_t tmp_data;
@@ -36,6 +37,7 @@ void numpy_init() {
    fill(PyArray_SetBaseObject, 282);
    fill(PyArray_NewCopy, 85);
    fill(PyArray_CopyInto, 82);
    fill(PyArray_CastScalarToCtype, 63);
    ASSERT(PyArray_GetNDArrayCFeatureVersion()>=7);
}

View File

@@ -100,6 +100,7 @@ extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
extern PyObject* (*PyArray_NewCopy)(PyObject *, int);
extern int (*PyArray_CopyInto)(PyObject *, PyObject *);
extern void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode);
#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)

View File

@@ -734,7 +734,13 @@ void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>&
    } else
    if (obj == Py_None) {
        var_slice->set_none();
    }else {
    } else
    if (PyObject_TypeCheck(obj, PyNumberArrType_Type)) {
        PyArrayDescr_Proxy array_descr = {.type_num = 5}; // 5: int32
        int value;
        PyArray_CastScalarToCtype(obj, &value, &array_descr);
        var_slice->set_int(value);
    } else {
        holders.emplace_back();
        auto* vh = from_py_object<VarHolder*>(obj, holders.back());
        auto vv = (Var**)vh;
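On the Python side, the new branch means a NumPy integer scalar used as a slice index is cast straight to a C int (type_num 5 is NPY_INT, i.e. int32) instead of falling through to the Var conversion path. An illustrative sketch of the behaviour this enables:

import jittor as jt
import numpy as np

x = jt.array(np.arange(12).reshape(3, 4))
i = np.int32(1)                    # NumPy scalar index, not a Python int
print(x[i].numpy())                # row 1 -> [4 5 6 7]
print(x[i, np.int64(2)].numpy())   # single element -> 6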