mirror of https://github.com/Jittor/Jittor
add search_sorted
This commit is contained in:
parent
7d2eefc581
commit
96545765ec
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.13'
|
||||
__version__ = '1.2.2.14'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -721,3 +721,100 @@ def triu_(x,diagonal=0):
|
|||
return x.reindex(x.shape,indexs,overflow_conditions=overflow_conditions,overflow_value=0)
|
||||
|
||||
jt.Var.triu_ = triu_
|
||||
|
||||
|
||||
|
||||
|
||||
def searchsorted(sorted, values, right=False):
|
||||
"""
|
||||
Find the indices from the innermost dimension of `sorted` for each `values`.
|
||||
|
||||
Example::
|
||||
|
||||
sorted = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
|
||||
values = jt.array([[3, 6, 9], [3, 6, 9]])
|
||||
ret = jt.searchsorted(sorted, values)
|
||||
assert (ret == [[1, 3, 4], [1, 2, 4]]).all(), ret
|
||||
|
||||
ret = jt.searchsorted(sorted, values, right=True)
|
||||
assert (ret == [[2, 3, 5], [1, 3, 4]]).all(), ret
|
||||
|
||||
sorted_1d = jt.array([1, 3, 5, 7, 9])
|
||||
ret = jt.searchsorted(sorted_1d, values)
|
||||
assert (ret == [[1, 3, 4], [1, 3, 4]]).all(), ret
|
||||
|
||||
|
||||
"""
|
||||
_searchsorted_header = f"""
|
||||
namespace jittor {{
|
||||
|
||||
#ifdef JIT_cuda
|
||||
__device__
|
||||
#endif
|
||||
inline static void searchsorted_kernel(int batch_id, int value_id,
|
||||
int value_num, int sorted_num, int batch_stride,
|
||||
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
|
||||
int32* __restrict__ index_p) {{
|
||||
int32 l = batch_id * batch_stride;
|
||||
int32 r = l + sorted_num;
|
||||
auto v = value_p[batch_id * value_num + value_id];
|
||||
while (l<r) {{
|
||||
int32 m = (l+r)/2;
|
||||
if (sort_p[m] {"<=" if right else "<"} v)
|
||||
l = m+1;
|
||||
else
|
||||
r = m;
|
||||
}}
|
||||
index_p[batch_id * value_num + value_id] = l - batch_id * batch_stride;
|
||||
}}
|
||||
|
||||
#ifdef JIT_cuda
|
||||
__global__ void searchsorted(int tn0, int tn1, int batch_num,
|
||||
int value_num, int sorted_num, int batch_stride,
|
||||
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
|
||||
int32* __restrict__ index_p
|
||||
) {{
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto i1 = tid & ((1<<tn1)-1);
|
||||
auto i0 = tid >> tn1;
|
||||
for (int i=i0; i<batch_num; i+=1<<tn0)
|
||||
for (int j=i1; j<value_num; j+=1<<tn1)
|
||||
searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, sort_p, value_p, index_p);
|
||||
}}
|
||||
|
||||
inline static int get_thread_range_log(int& thread_num, int64 range) {{
|
||||
int nbits = NanoVector::get_nbits(std::min((int64)thread_num, range)) - 2;
|
||||
thread_num >>= nbits;
|
||||
return nbits;
|
||||
}}
|
||||
#endif
|
||||
|
||||
}}
|
||||
"""
|
||||
_searchsorted_src = """
|
||||
int value_num = in1->shape[in1->shape.size()-1];
|
||||
int sorted_num = in0->shape[in0->shape.size()-1];
|
||||
int32 batch_num = in0->num / sorted_num;
|
||||
int32 batch_num2 = in1->num / value_num;
|
||||
int32 batch_stride = batch_num == 1 ? 0 : sorted_num;
|
||||
CHECK(batch_num == batch_num2 || batch_num == 1);
|
||||
|
||||
#ifdef JIT_cuda
|
||||
int thread_num = 256*1024;
|
||||
auto tn1 = get_thread_range_log(thread_num, value_num);
|
||||
auto tn0 = get_thread_range_log(thread_num, batch_num2);
|
||||
thread_num = 1<<(tn0+tn1);
|
||||
int p1 = std::max(thread_num/1024, 1);
|
||||
int p2 = std::min(thread_num, 1024);
|
||||
searchsorted<<<p1,p2>>>(tn0, tn1, batch_num2, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
|
||||
#else
|
||||
for (int32 i=0; i<batch_num2; i++)
|
||||
for (int32 j=0; j<value_num; j++)
|
||||
searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
|
||||
#endif
|
||||
"""
|
||||
return jt.code(values.shape, "int32", [sorted, values],
|
||||
cpu_header=_searchsorted_header,
|
||||
cpu_src=_searchsorted_src,
|
||||
cuda_header=_searchsorted_header,
|
||||
cuda_src=_searchsorted_src)
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
|
||||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import jittor.nn as jnn
|
||||
|
||||
skip_this_test = False
|
||||
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
import torch.nn as tnn
|
||||
import torchvision
|
||||
except:
|
||||
torch = None
|
||||
tnn = None
|
||||
torchvision = None
|
||||
skip_this_test = True
|
||||
|
||||
# TODO: more test
|
||||
# @unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestSearchSorted(unittest.TestCase):
|
||||
def test_origin(self):
|
||||
sorted = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
|
||||
values = jt.array([[3, 6, 9], [3, 6, 9]])
|
||||
ret = jt.searchsorted(sorted, values)
|
||||
assert (ret == [[1, 3, 4], [1, 2, 4]]).all(), ret
|
||||
|
||||
ret = jt.searchsorted(sorted, values, right=True)
|
||||
assert (ret == [[2, 3, 5], [1, 3, 4]]).all(), ret
|
||||
|
||||
sorted_1d = jt.array([1, 3, 5, 7, 9])
|
||||
ret = jt.searchsorted(sorted_1d, values)
|
||||
assert (ret == [[1, 3, 4], [1, 3, 4]]).all(), ret
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda(self):
|
||||
self.test_origin()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -127,6 +127,7 @@ ArrayArgs VarHolder::fetch_sync() {
|
|||
|
||||
ItemData VarHolder::item() {
|
||||
sync();
|
||||
CHECK(var->num==1) << "Item var size should be 1, but got" << var->num;
|
||||
ItemData data;
|
||||
data.dtype = var->dtype();
|
||||
auto dsize = data.dtype.dsize();
|
||||
|
|
Loading…
Reference in New Issue