mirror of https://github.com/Jittor/Jittor
Merge branch 'searchsort' into ygy4
This commit is contained in:
commit
463e530643
|
@ -298,12 +298,12 @@ def std(x):
|
|||
return out
|
||||
Var.std = std
|
||||
|
||||
def norm(x, k, dim):
|
||||
def norm(x, k, dim, keepdim=False):
|
||||
assert k==1 or k==2
|
||||
if k==1:
|
||||
return x.abs().sum(dim)
|
||||
return x.abs().sum(dim, keepdim)
|
||||
if k==2:
|
||||
return (x.sqr()).sum(dim).maximum(1e-6).sqrt()
|
||||
return (x.sqr()).sum(dim, keepdim).maximum(1e-6).sqrt()
|
||||
Var.norm = norm
|
||||
|
||||
origin_reshape = reshape
|
||||
|
|
|
@ -721,3 +721,114 @@ def triu_(x,diagonal=0):
|
|||
return x.reindex(x.shape,indexs,overflow_conditions=overflow_conditions,overflow_value=0)
|
||||
|
||||
jt.Var.triu_ = triu_
|
||||
|
||||
def searchsorted(s, v, right=False):
|
||||
class SearchsortedFunc(jt.Module):
|
||||
def __init__(self, right=False):
|
||||
self.side = "right" if right else "left"
|
||||
|
||||
def forward_code(self, np, data):
|
||||
a, b = data["inputs"]
|
||||
c = data["outputs"][0]
|
||||
if len(a.shape)==1:
|
||||
out = np.searchsorted(a, b, side=self.side)
|
||||
else:
|
||||
# out = np.apply_along_axis(np.searchsorted, 1, a, b)
|
||||
# out = out.diagonal(0,0,1).T
|
||||
|
||||
# TODO: support better 2-dims searchsorted
|
||||
outs = []
|
||||
for i in range(a.shape[0]):
|
||||
outs.append(np.expand_dims(np.searchsorted(a[i], b[i], side=self.side),0))
|
||||
out = np.concatenate(outs, 0)
|
||||
# out = np.zeros(b.shape)
|
||||
np.copyto(c, out)
|
||||
|
||||
def execute(self, s, v):
|
||||
return jt.numpy_code(
|
||||
v.shape,
|
||||
v.dtype,
|
||||
[s, v],
|
||||
self.forward_code,
|
||||
)
|
||||
assert len(s.shape)==len(v.shape) and v.shape[:-1]==s.shape[:-1]
|
||||
assert len(s.shape)==1 or len(s.shape)==2, "TODO: support n-dims searchsorted"
|
||||
func = SearchsortedFunc(right)
|
||||
return func(s, v)
|
||||
|
||||
def cumprod(a, dim):
|
||||
class CumprodFunc(jt.Function):
|
||||
def forward_code(self, np, data):
|
||||
a = data["inputs"][0]
|
||||
b = data["outputs"][0]
|
||||
out = np.cumprod(a, self.dim)
|
||||
np.copyto(b, out)
|
||||
|
||||
def backward_code(self, np, data):
|
||||
a, b, dout = data["inputs"]
|
||||
out = data["outputs"][0]
|
||||
|
||||
sdim = a.shape[self.dim]
|
||||
dim = (len(a.shape)+1)*[1]
|
||||
dim[self.dim+1] = sdim
|
||||
res = np.tile(np.expand_dims(b, self.dim+1), dim)
|
||||
dout = np.tile(np.expand_dims(dout, self.dim+1), dim)
|
||||
|
||||
dim[self.dim]=sdim
|
||||
dim[self.dim+1]=1
|
||||
a = np.tile(np.expand_dims(a, self.dim), dim)
|
||||
res = res/a
|
||||
|
||||
mask = np.tril(np.ones((sdim, sdim)))
|
||||
for i in range(self.dim):
|
||||
mask = np.expand_dims(mask, 0)
|
||||
for i in range(len(a.shape)-self.dim-2):
|
||||
mask = np.expand_dims(mask, -1)
|
||||
res = np.sum(mask*res*dout, self.dim)
|
||||
|
||||
np.copyto(out, res)
|
||||
|
||||
def execute(self, a, dim):
|
||||
self.save_vars = a
|
||||
self.dim = dim
|
||||
self.res = jt.numpy_code(
|
||||
a.shape,
|
||||
a.dtype,
|
||||
[a],
|
||||
self.forward_code,
|
||||
)
|
||||
return self.res
|
||||
|
||||
def grad(self, grad_a):
|
||||
a = self.save_vars
|
||||
b = self.res
|
||||
return jt.numpy_code(
|
||||
a.shape,
|
||||
a.dtype,
|
||||
[a, b, grad_a],
|
||||
self.backward_code,
|
||||
)
|
||||
|
||||
func = CumprodFunc()
|
||||
if dim<0:
|
||||
dim+=len(a.shape)
|
||||
return func(a, dim)
|
||||
|
||||
def linspace(start, end, steps):
|
||||
res = jt.index((steps,))[0]
|
||||
res = res*(end-start)/float(steps-1)+start
|
||||
return res
|
||||
|
||||
def randperm(n):
|
||||
# TODO: use jt.random
|
||||
idx = np.arange(n)
|
||||
return jt.array(np.random.permutation(idx))
|
||||
|
||||
def set_global_seed(seed):
|
||||
jt.set_seed(seed)
|
||||
np.random.seed(seed)
|
||||
try:
|
||||
import cupy
|
||||
cupy.random.seed(seed)
|
||||
except:
|
||||
pass
|
|
@ -0,0 +1,51 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import ctypes
|
||||
import sys
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
class TestCumprod(unittest.TestCase):
|
||||
def test_cumprod_cpu(self):
|
||||
jt.flags.use_cuda = 0
|
||||
|
||||
for i in range(1,6):
|
||||
for j in range(i):
|
||||
x = np.random.rand(*((10,)*i))
|
||||
x_jt = jt.array(x)
|
||||
y_jt = jt.cumprod(x_jt, j).sqr()
|
||||
g_jt = jt.grad(y_jt.sum(), x_jt)
|
||||
x_tc = Variable(torch.from_numpy(x), requires_grad=True)
|
||||
y_tc = torch.cumprod(x_tc, j)**2
|
||||
y_tc.sum().backward()
|
||||
g_tc = x_tc.grad
|
||||
assert np.allclose(y_jt.numpy(), y_tc.data)
|
||||
assert np.allclose(g_jt.numpy(), g_tc.data)
|
||||
|
||||
def test_cumprod_gpu(self):
|
||||
jt.flags.use_cuda = 1
|
||||
|
||||
for i in range(1,6):
|
||||
for j in range(i):
|
||||
x = np.random.rand(*((10,)*i))
|
||||
x_jt = jt.array(x)
|
||||
y_jt = jt.cumprod(x_jt, j).sqr()
|
||||
g_jt = jt.grad(y_jt.sum(), x_jt)
|
||||
x_tc = Variable(torch.from_numpy(x), requires_grad=True)
|
||||
y_tc = torch.cumprod(x_tc, j)**2
|
||||
y_tc.sum().backward()
|
||||
g_tc = x_tc.grad
|
||||
assert np.allclose(y_jt.numpy(), y_tc.data)
|
||||
assert np.allclose(g_jt.numpy(), g_tc.data)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,55 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import ctypes
|
||||
import sys
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
class TestSearchsorted(unittest.TestCase):
|
||||
def test_searchsorted_cpu(self):
|
||||
jt.flags.use_cuda = 0
|
||||
|
||||
for i in range(1,3):
|
||||
s = np.sort(np.random.rand(*((10,)*i)),-1)
|
||||
v = np.random.rand(*((10,)*i))
|
||||
s_jt = jt.array(s)
|
||||
v_jt = jt.array(v)
|
||||
s_tc = torch.from_numpy(s)
|
||||
v_tc = torch.from_numpy(v)
|
||||
|
||||
y_jt = jt.searchsorted(s_jt, v_jt, right=True)
|
||||
y_tc = torch.searchsorted(s_tc, v_tc, right=True)
|
||||
assert np.allclose(y_jt.numpy(), y_tc.data)
|
||||
y_jt = jt.searchsorted(s_jt, v_jt, right=False)
|
||||
y_tc = torch.searchsorted(s_tc, v_tc, right=False)
|
||||
assert np.allclose(y_jt.numpy(), y_tc.data)
|
||||
|
||||
def test_searchsorted_gpu(self):
|
||||
jt.flags.use_cuda = 1
|
||||
|
||||
for i in range(1,3):
|
||||
s = np.sort(np.random.rand(*((10,)*i)),-1)
|
||||
v = np.random.rand(*((10,)*i))
|
||||
s_jt = jt.array(s)
|
||||
v_jt = jt.array(v)
|
||||
s_tc = torch.from_numpy(s)
|
||||
v_tc = torch.from_numpy(v)
|
||||
|
||||
y_jt = jt.searchsorted(s_jt, v_jt, right=True)
|
||||
y_tc = torch.searchsorted(s_tc, v_tc, right=True)
|
||||
assert np.allclose(y_jt.numpy(), y_tc.data)
|
||||
y_jt = jt.searchsorted(s_jt, v_jt, right=False)
|
||||
y_tc = torch.searchsorted(s_tc, v_tc, right=False)
|
||||
assert np.allclose(y_jt.numpy(), y_tc.data)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue