Merge branch 'master' of https://github.com/Jittor/jittor into ygy4

commit 56cc2f50f4 (Gword, 2021-01-06 13:24:53 +08:00)
6 changed files with 188 additions and 35 deletions

View File

@@ -8,7 +8,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.2.11'
+__version__ = '1.2.2.14'
 from . import lock
 with lock.lock_scope():
     ori_int = int
@@ -40,7 +40,9 @@ import traceback
 def safepickle(obj, path):
-    s = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
+    # Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations.
+    # ref: <https://docs.python.org/3/library/pickle.html>
+    s = pickle.dumps(obj, 4)
     checksum = hashlib.sha1(s).digest()
     s += bytes(checksum)
     s += b"HCAJSLHD"
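
The layout safepickle writes is: pickled payload, then the 20-byte SHA-1 digest of that payload, then the 8-byte magic suffix b"HCAJSLHD". A minimal matching reader might look like the sketch below (safeunpickle_sketch is a hypothetical name for illustration, not the library's API):

import hashlib, pickle

def safeunpickle_sketch(path):
    # layout: [payload][sha1(payload), 20 bytes][b"HCAJSLHD", 8 bytes]
    with open(path, "rb") as f:
        s = f.read()
    assert s[-8:] == b"HCAJSLHD", "not a safepickle file"
    payload, checksum = s[:-28], s[-28:-8]
    assert hashlib.sha1(payload).digest() == checksum, "checksum mismatch"
    return pickle.loads(payload)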

View File

@@ -722,39 +722,39 @@ def triu_(x,diagonal=0):
 jt.Var.triu_ = triu_
-def searchsorted(s, v, right=False):
-    class SearchsortedFunc(jt.Module):
-        def __init__(self, right=False):
-            self.side = "right" if right else "left"
+# def searchsorted(s, v, right=False):
+#     class SearchsortedFunc(jt.Module):
+#         def __init__(self, right=False):
+#             self.side = "right" if right else "left"
-        def forward_code(self, np, data):
-            a, b = data["inputs"]
-            c = data["outputs"][0]
-            if len(a.shape)==1:
-                out = np.searchsorted(a, b, side=self.side)
-            else:
-                # out = np.apply_along_axis(np.searchsorted, 1, a, b)
-                # out = out.diagonal(0,0,1).T
+#         def forward_code(self, np, data):
+#             a, b = data["inputs"]
+#             c = data["outputs"][0]
+#             if len(a.shape)==1:
+#                 out = np.searchsorted(a, b, side=self.side)
+#             else:
+#                 # out = np.apply_along_axis(np.searchsorted, 1, a, b)
+#                 # out = out.diagonal(0,0,1).T
-                # TODO: support better 2-dims searchsorted
-                outs = []
-                for i in range(a.shape[0]):
-                    outs.append(np.expand_dims(np.searchsorted(a[i], b[i], side=self.side),0))
-                out = np.concatenate(outs, 0)
-                # out = np.zeros(b.shape)
-            np.copyto(c, out)
+#                 # TODO: support better 2-dims searchsorted
+#                 outs = []
+#                 for i in range(a.shape[0]):
+#                     outs.append(np.expand_dims(np.searchsorted(a[i], b[i], side=self.side),0))
+#                 out = np.concatenate(outs, 0)
+#                 # out = np.zeros(b.shape)
+#             np.copyto(c, out)
-        def execute(self, s, v):
-            return jt.numpy_code(
-                v.shape,
-                v.dtype,
-                [s, v],
-                self.forward_code,
-            )
-    assert len(s.shape)==len(v.shape) and v.shape[:-1]==s.shape[:-1]
-    assert len(s.shape)==1 or len(s.shape)==2, "TODO: support n-dims searchsorted"
-    func = SearchsortedFunc(right)
-    return func(s, v)
+#         def execute(self, s, v):
+#             return jt.numpy_code(
+#                 v.shape,
+#                 v.dtype,
+#                 [s, v],
+#                 self.forward_code,
+#             )
+#     assert len(s.shape)==len(v.shape) and v.shape[:-1]==s.shape[:-1]
+#     assert len(s.shape)==1 or len(s.shape)==2, "TODO: support n-dims searchsorted"
+#     func = SearchsortedFunc(right)
+#     return func(s, v)
 def cumprod(a, dim):
     class CumprodFunc(jt.Function):
@@ -831,4 +831,100 @@ def set_global_seed(seed):
         import cupy
         cupy.random.seed(seed)
     except:
-        pass
+        pass
+
+def searchsorted(sorted, values, right=False):
+    """
+    Find the indices from the innermost dimension of `sorted` for each `values`.
+
+    Example::
+
+        sorted = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
+        values = jt.array([[3, 6, 9], [3, 6, 9]])
+        ret = jt.searchsorted(sorted, values)
+        assert (ret == [[1, 3, 4], [1, 2, 4]]).all(), ret
+
+        ret = jt.searchsorted(sorted, values, right=True)
+        assert (ret == [[2, 3, 5], [1, 3, 4]]).all(), ret
+
+        sorted_1d = jt.array([1, 3, 5, 7, 9])
+        ret = jt.searchsorted(sorted_1d, values)
+        assert (ret == [[1, 3, 4], [1, 3, 4]]).all(), ret
+    """
+    _searchsorted_header = f"""
+namespace jittor {{
+
+#ifdef JIT_cuda
+__device__
+#endif
+inline static void searchsorted_kernel(int batch_id, int value_id,
+    int value_num, int sorted_num, int batch_stride,
+    {sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
+    int32* __restrict__ index_p) {{
+    int32 l = batch_id * batch_stride;
+    int32 r = l + sorted_num;
+    auto v = value_p[batch_id * value_num + value_id];
+    while (l<r) {{
+        int32 m = (l+r)/2;
+        if (sort_p[m] {"<=" if right else "<"} v)
+            l = m+1;
+        else
+            r = m;
+    }}
+    index_p[batch_id * value_num + value_id] = l - batch_id * batch_stride;
+}}
+
+#ifdef JIT_cuda
+__global__ void searchsorted(int tn0, int tn1, int batch_num,
+    int value_num, int sorted_num, int batch_stride,
+    {sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
+    int32* __restrict__ index_p
+) {{
+    int tid = threadIdx.x + blockIdx.x * blockDim.x;
+    auto i1 = tid & ((1<<tn1)-1);
+    auto i0 = tid >> tn1;
+    for (int i=i0; i<batch_num; i+=1<<tn0)
+        for (int j=i1; j<value_num; j+=1<<tn1)
+            searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, sort_p, value_p, index_p);
+}}
+
+inline static int get_thread_range_log(int& thread_num, int64 range) {{
+    int nbits = NanoVector::get_nbits(std::min((int64)thread_num, range)) - 2;
+    thread_num >>= nbits;
+    return nbits;
+}}
+#endif
+
+}}
+"""
+    _searchsorted_src = """
+    int value_num = in1->shape[in1->shape.size()-1];
+    int sorted_num = in0->shape[in0->shape.size()-1];
+    int32 batch_num = in0->num / sorted_num;
+    int32 batch_num2 = in1->num / value_num;
+    int32 batch_stride = batch_num == 1 ? 0 : sorted_num;
+    CHECK(batch_num == batch_num2 || batch_num == 1);
+
+    #ifdef JIT_cuda
+    int thread_num = 256*1024;
+    auto tn1 = get_thread_range_log(thread_num, value_num);
+    auto tn0 = get_thread_range_log(thread_num, batch_num2);
+    thread_num = 1<<(tn0+tn1);
+    int p1 = std::max(thread_num/1024, 1);
+    int p2 = std::min(thread_num, 1024);
+    searchsorted<<<p1,p2>>>(tn0, tn1, batch_num2, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
+    #else
+    for (int32 i=0; i<batch_num2; i++)
+        for (int32 j=0; j<value_num; j++)
+            searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
+    #endif
+"""
+    return jt.code(values.shape, "int32", [sorted, values],
+        cpu_header=_searchsorted_header,
+        cpu_src=_searchsorted_src,
+        cuda_header=_searchsorted_header,
+        cuda_src=_searchsorted_src)
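
The kernel's binary search is an ordinary lower bound (upper bound when right=True). A pure-Python mirror of searchsorted_kernel's loop, for reference only (searchsorted_ref is an illustrative name, not part of the commit):

def searchsorted_ref(sorted_row, v, right=False):
    # same loop as searchsorted_kernel: halve [l, r) until it collapses
    # onto the insertion index that keeps sorted_row sorted
    l, r = 0, len(sorted_row)
    while l < r:
        m = (l + r) // 2
        if (sorted_row[m] <= v) if right else (sorted_row[m] < v):
            l = m + 1
        else:
            r = m
    return l

assert searchsorted_ref([1, 3, 5, 7, 9], 3) == 1
assert searchsorted_ref([1, 3, 5, 7, 9], 3, right=True) == 2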

View File

@@ -193,7 +193,9 @@ def Resnet101(**kwargs):
     y = model(x) # [10, 1000]
     """
-    return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
+    model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
+    if pretrained: model.load("jittorhub://resnet101.pkl")
+    return model
 resnet101 = Resnet101
 def Resnet152(pretrained=False, **kwargs):
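
The hunk context still shows the old signature def Resnet101(**kwargs); for the new `if pretrained:` line to work, the signature presumably gains a pretrained flag like Resnet152's below. Assuming that, usage would be:

import jittor.models as models

model = models.Resnet101(pretrained=True)  # loads jittorhub://resnet101.pkl
model.eval()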

View File

@@ -963,7 +963,7 @@ class Softplus(Module):
         self.threshold = threshold
     def execute(self, x):
-        return 1 / self.beta * jt.log(1 + (self.beta * x).exp())
+        return softplus(x, self.beta, self.threshold)
 class Resize(Module):
     def __init__(self, size, mode="nearest", align_corners=False):
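
The functional softplus that Softplus now delegates to is not shown in this diff. A plausible sketch of its usual form, with the threshold trick that falls back to the identity once beta*x exceeds the threshold (an assumed implementation, not necessarily the one in nn.py):

import jittor as jt

def softplus_sketch(x, beta=1.0, threshold=20.0):
    # softplus(x) = 1/beta * log(1 + exp(beta*x)); where beta*x > threshold
    # the function is numerically equal to x, so select x there instead of
    # the (potentially overflowing) exp-based branch
    return jt.ternary(x * beta > threshold, x,
                      1 / beta * jt.log(1 + (beta * x).exp()))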

View File

@@ -0,0 +1,52 @@
+# ***************************************************************
+# Copyright (c) 2020 Jittor. All Rights Reserved.
+# Maintainers:
+#     Dun Liang <randonlang@gmail.com>.
+#
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+import jittor as jt
+import numpy as np
+import jittor.nn as jnn
+
+skip_this_test = False
+
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch
+    import torch.nn as tnn
+    import torchvision
+except:
+    torch = None
+    tnn = None
+    torchvision = None
+    skip_this_test = True
+
+# TODO: more test
+# @unittest.skipIf(skip_this_test, "No Torch found")
+class TestSearchSorted(unittest.TestCase):
+    def test_origin(self):
+        sorted = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
+        values = jt.array([[3, 6, 9], [3, 6, 9]])
+        ret = jt.searchsorted(sorted, values)
+        assert (ret == [[1, 3, 4], [1, 2, 4]]).all(), ret
+
+        ret = jt.searchsorted(sorted, values, right=True)
+        assert (ret == [[2, 3, 5], [1, 3, 4]]).all(), ret
+
+        sorted_1d = jt.array([1, 3, 5, 7, 9])
+        ret = jt.searchsorted(sorted_1d, values)
+        assert (ret == [[1, 3, 4], [1, 3, 4]]).all(), ret
+
+    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
+    @jt.flag_scope(use_cuda=1)
+    def test_cuda(self):
+        self.test_origin()
+
+if __name__ == "__main__":
+    unittest.main()
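
A possible extra cross-check against NumPy's searchsorted in the same style (a sketch, not part of the commit):

import jittor as jt
import numpy as np

a = np.array([1, 3, 5, 7, 9])
v = np.array([3, 6, 9])
ret = jt.searchsorted(jt.array(a), jt.array(v))
assert (ret.numpy() == np.searchsorted(a, v)).all(), ret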

View File

@@ -127,6 +127,7 @@ ArrayArgs VarHolder::fetch_sync() {
 ItemData VarHolder::item() {
     sync();
+    CHECK(var->num==1) << "Item var size should be 1, but got " << var->num;
     ItemData data;
     data.dtype = var->dtype();
     auto dsize = data.dtype.dsize();
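
The added CHECK guards the Python-side scalar conversion; roughly, usage looks like this (illustrative, not from the diff):

import jittor as jt

x = jt.array([3.14])
print(x.item())  # 3.14 as a plain Python scalar

# a Var with more than one element now trips the CHECK, failing with
# a message along the lines of "Item var size should be 1, but got 2":
# jt.array([1, 2]).item()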