mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of github.com:Jittor/jittor
Commit: 41fef77707
@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

-__version__ = '1.3.3.10'
+__version__ = '1.3.3.14'
from jittor_utils import lock
with lock.lock_scope():
    ori_int = int

@@ -560,7 +560,13 @@ def setup_mpi():
        if k == "mpi_test": continue
        setattr(core.Var, k, wrapper(mpi_ops.__dict__[k]))

-if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
+in_mpi = inside_mpi()
+FIX_TORCH_ERROR = 0
+if os.name != 'nt' and not in_mpi:
+    FIX_TORCH_ERROR = 1
+if "FIX_TORCH_ERROR" in os.environ:
+    FIX_TORCH_ERROR = os.environ["FIX_TORCH_ERROR"] != "0"
+if FIX_TORCH_ERROR:
    try:
        import torch
        from jittor_utils import dirty_fix_pytorch_runtime_error

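In effect, the PyTorch-compat workaround now defaults to on for non-Windows runs outside MPI, and the FIX_TORCH_ERROR environment variable still overrides it in either direction. A minimal sketch of the resulting decision logic (hypothetical helper, not part of Jittor):

def fix_torch_error_enabled(os_name, in_mpi, env):
    # Mirrors the new default above: on for non-Windows, non-MPI runs...
    enabled = os_name != 'nt' and not in_mpi
    # ...unless the environment variable forces it either way.
    if "FIX_TORCH_ERROR" in env:
        enabled = env["FIX_TORCH_ERROR"] != "0"
    return enabled

assert fix_torch_error_enabled('posix', False, {}) is True
assert fix_torch_error_enabled('nt', False, {}) is False
assert fix_torch_error_enabled('posix', True, {}) is False
assert fix_torch_error_enabled('posix', True, {"FIX_TORCH_ERROR": "1"}) is True
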
@@ -570,7 +576,6 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":

cudnn = cublas = curand = cufft = None
setup_mpi()
-in_mpi = inside_mpi()
rank = mpi.world_rank() if in_mpi else 0
world_size = mpi.world_size() if in_mpi else 1
setup_nccl()

@@ -1033,8 +1033,10 @@ cc_flags += " -lstdc++ -ldl -shared "
if platform.system() == 'Darwin':
    # TODO: if not using Apple clang, there is no need to add -lomp
    cc_flags += "-undefined dynamic_lookup -lomp "
+    if os.environ.get('CONDA_PREFIX', None):
+        cc_flags += f" -L{os.path.join(os.environ['CONDA_PREFIX'], 'lib')} "
    if platform.machine() == "arm64":
        cc_flags += " -L/opt/homebrew/lib "

opt_flags = ""

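A hedged sketch of what the new CONDA_PREFIX branch contributes on macOS (path illustrative; presumably so the linker can find conda-provided libraries such as libomp):

import os
os.environ['CONDA_PREFIX'] = '/opt/miniconda3/envs/jt'  # illustrative value
flag = f" -L{os.path.join(os.environ['CONDA_PREFIX'], 'lib')} "
print(flag)  # -> " -L/opt/miniconda3/envs/jt/lib "
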
@@ -115,6 +115,7 @@ void CudnnRnnOp::grads(Var** dout, VarPtr* dins) {
    VarPtr dy = dout[0];
    VarPtr dhy = dout[1];
    VarPtr dcy = cx ? dout[2] : nullptr;
    if (!dy.ptr) dy = make_number(0.0, y);
    if (!dhy.ptr) dhy = make_number(0.0, hy);
+    if (!dcy.ptr && cx) dcy = make_number(0.0, cy);

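A hedged Python paraphrase of the gradient-defaulting pattern above (illustrative only, not Jittor's API): when an upstream gradient slot is null, a zero tensor shaped like the corresponding forward output is substituted, and the added line extends this to the cell-state gradient dcy, which only exists when a cell state cx does (i.e., for LSTM).

import numpy as np

def default_grads(dout, y, hy, cy, has_cell_state):
    # If an upstream gradient is missing, fall back to zeros shaped like the output.
    dy, dhy = dout[0], dout[1]
    dcy = dout[2] if has_cell_state else None
    if dy is None: dy = np.zeros_like(y)
    if dhy is None: dhy = np.zeros_like(hy)
    if dcy is None and has_cell_state: dcy = np.zeros_like(cy)  # mirrors the added line
    return dy, dhy, dcy
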
@@ -13,6 +13,78 @@ import numpy as np
import math
from collections.abc import Sequence,Iterable

def knn(unknown, known, k):
    ''' Find the k nearest neighbors in `known` for every point in `unknown`.

    Args:
        unknown (var): query points, shape [b, n, c]
        known (var): reference points, shape [b, m, c]
        k (int): number of neighbors to return

    Note: the kernel below hard-codes c == 3 (xyz coordinates).
    '''
    b, n, c = unknown.shape
    _, m, _ = known.shape
    dists2 = jt.empty((b, n, k), dtype="float")
    idx = jt.empty((b, n, k), dtype="int")
    src = '''
__inline_static__
@python.jittor.auto_parallel(2, block_num=256)
void knn_kernel(int b, int batch_index, int n, int index, int m,
                const float *__restrict__ unknown,
                const float *__restrict__ known,
                float *__restrict__ dist2,
                int *__restrict__ idx) {

#define K %s
    unknown += batch_index * n * 3;
    known += batch_index * m * 3;
    dist2 += batch_index * n * K;
    idx += batch_index * n * K;
    int j = index;
    {
        float ux = unknown[j * 3 + 0];
        float uy = unknown[j * 3 + 1];
        float uz = unknown[j * 3 + 2];

        float tmp_dist[K];
        int tmp_idx[K];
        #pragma unroll
        for (int i=0; i<K; i++) tmp_dist[i] = 1e30;
        for (int k = 0; k < m; ++k) {
            float x = known[k * 3 + 0];
            float y = known[k * 3 + 1];
            float z = known[k * 3 + 2];
            float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);

            // find the insertion position in the sorted top-K list
            int first = -1;
            #pragma unroll
            for (int i=0; i<K; i++)
                if (first == -1 && d<tmp_dist[i])
                    first = i;
            if (first == -1) continue;
            // shift worse entries down, then insert the new candidate
            #pragma unroll
            for (int i=0; i<K; i++)
                if (K-1-i > first) {
                    tmp_dist[K-1-i] = tmp_dist[K-2-i];
                    tmp_idx[K-1-i] = tmp_idx[K-2-i];
                }
            tmp_dist[first] = d;
            tmp_idx[first] = k;
        }
        #pragma unroll
        for (int i=0; i<K; i++) {
            dist2[j * K + i] = tmp_dist[i];
            idx[j * K + i] = tmp_idx[i];
        }
    }
}
knn_kernel(in0->shape[0], 0, in0->shape[1], 0, in1->shape[1], in0_p, in1_p, out0_p, out1_p);
''' % k
    return jt.code([unknown, known], [dists2, idx],
        cpu_src=src,
        cuda_src=src)

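For context, a hedged usage sketch of the new helper (shapes illustrative; note the c == 3 assumption):

import jittor as jt
unknown = jt.randn(2, 128, 3)  # b=2 batches of n=128 query points
known = jt.randn(2, 256, 3)    # m=256 reference points
dists2, idx = jt.misc.knn(unknown, known, k=8)
# dists2: [2, 128, 8] squared distances; idx: [2, 128, 8] indices into `known`
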
def index_add_(x, dim, index, tensor):
    """ For each i, add the i-th slice of `tensor` along dimension `dim` into `x` at position `index[i]` (in-place).

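A hedged usage sketch of index_add_ (illustrative; assumes torch-like in-place semantics and that it is exposed under jt.misc like knn above):

import jittor as jt
x = jt.zeros((3, 4))
index = jt.array([0, 2])
t = jt.ones((2, 4))
jt.misc.index_add_(x, 0, index, t)  # rows 0 and 2 of x become ones; row 1 stays zero
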
@@ -950,7 +1022,7 @@ def python_pass_wrapper(mod_func, args, kw):
    args = ",".join(args)
    return eval(f"func({args})")

-def auto_parallel(n, src, **kw):
+def auto_parallel(n, src, block_num=1024, **kw):
    """
    auto parallel (CPU and GPU) an n-d for-loop function like the one below:

@@ -1022,11 +1094,11 @@ __global__ static void {func_name}_entry({entry_func_args_def}) {{

inline static void {func_name}({",".join(pargs+oargs)}) {{
#ifdef JIT_cuda
-    int thread_num = 256*1024;
+    int thread_num = 256*{block_num};
    {xn.join([f"int tn{i} = NanoVector::get_nbits(std::min(thread_num, {pnargs2[i]})) - 2;thread_num >>= tn{i};" for i in reversed(range(n))])}
    thread_num = 1<<({"+".join([f"tn{i}" for i in range(n)])});
-    int p1 = std::max(thread_num/1024, 1);
-    int p2 = std::min(thread_num, 1024);
+    int p1 = std::max(thread_num/{block_num}, 1);
+    int p2 = std::min(thread_num, {block_num});
    {func_name}_entry<<<p1,p2>>>({entry_func_args});
#else
    {xn.join([f"for (int i{i}=0; i{i}<{pnargs2[i]}; i{i}++)" for i in range(n)])}

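To see what the new block_num parameter changes, here is a hedged Python sketch of the grid/block split computed by the generated CUDA launcher (helper name hypothetical):

def launch_dims(thread_num, block_num=1024):
    # p1 CUDA blocks of p2 threads each, with p2 capped by block_num.
    p1 = max(thread_num // block_num, 1)
    p2 = min(thread_num, block_num)
    return p1, p2

print(launch_dims(256 * 1024))        # old default: (256, 1024)
print(launch_dims(256 * 256, 256))    # knn kernel above with block_num=256: (256, 256)
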
@@ -42,7 +42,12 @@ unordered_map<string,string> common_op_type_cuda_map = {
    {"erf", "(($1) ::erff(($2)))"},
    {"erfinv", "(($1) ::erfinvf(($1)($2)))"},
    {"cast", "(($1)($2))"},
+#ifdef _WIN32
+    // Windows has no pow(float, int) overload, which caused an undefined reference; cast both arguments to the result type to fix it
+    {"pow", "::pow(($1)($2),($1)($4))"},
+#else
    {"pow", "::pow(($2),($4))"},
+#endif
    {"maximum", "::max($1($2), $1($4))"},
    {"minimum", "::min($1($2), $1($4))"},
    {"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"},

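A hedged illustration of how such an entry expands ($1 appears to be the result type and $2/$4 the two operands, judging from the neighboring entries; toy expander, not the real one):

def expand(template, t, a, b):
    # Toy substitution for entries like the pow ones above.
    return template.replace("$1", t).replace("$2", a).replace("$4", b)

print(expand("::pow(($1)($2),($1)($4))", "float32", "x", "2"))  # Windows: ::pow((float32)(x),(float32)(2))
print(expand("::pow(($2),($4))", "float32", "x", "2"))          # elsewhere: ::pow((x),(2))
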
@@ -0,0 +1,56 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
#     Zheng-Ning Liu <lzhengning@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

import unittest
import jittor as jt
import numpy as np

def topk(input, k, dim=None, largest=True, sorted=True):
    # Reference topk: move `dim` to the front, argsort, take the first k, move back.
    if dim is None:
        dim = -1
    if dim < 0:
        dim += input.ndim

    transpose_dims = [i for i in range(input.ndim)]
    transpose_dims[0] = dim
    transpose_dims[dim] = 0
    input = input.transpose(transpose_dims)
    index, values = jt.argsort(input, dim=0, descending=largest)
    indices = index[:k]
    values = values[:k]
    indices = indices.transpose(transpose_dims)
    values = values.transpose(transpose_dims)
    return [values, indices]

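A quick hedged sanity check of the reference topk above (values illustrative, not part of the test file):

import jittor as jt
v, i = topk(jt.array([[3.0, 1.0, 2.0]]), k=2, dim=-1)
# expected: v -> [[3.0, 2.0]], i -> [[0, 2]]
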
def knn(x, k):
    # Reference kNN: for points x of shape [b, c, n], `distance` holds
    # -||xi - xj||^2, so the largest-k entries are the nearest neighbors.
    inner = -2 * jt.nn.bmm(x.transpose(0, 2, 1), x)
    xx = jt.sum(x ** 2, dim=1, keepdims=True)
    distance = -xx - inner - xx.transpose(0, 2, 1)
    return topk(distance, k=k, dim=-1)

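The helper leans on the expansion ||xi - xj||^2 = ||xi||^2 - 2*xi.xj + ||xj||^2; a hedged NumPy check of that identity:

import numpy as np
x = np.random.randn(3, 5)                  # c=3 coordinates, 5 points (illustrative)
inner = -2 * x.T @ x
xx = np.sum(x ** 2, axis=0, keepdims=True)
distance = -xx - inner - xx.T              # entry (i, j) = -||xi - xj||^2
ref = -((x.T[:, None, :] - x.T[None, :, :]) ** 2).sum(-1)
np.testing.assert_allclose(distance, ref, atol=1e-10)
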
class TestKnnOp(unittest.TestCase):
    def test_knn(self):
        jt_a = jt.randn(32, 512, 3)
        a1, b1 = jt.misc.knn(jt_a, jt_a, 16)
        a2, b2 = knn(jt_a.transpose(0, 2, 1), 16)
        a2 *= -1  # reference distances are negated; flip the sign before comparing
        np.testing.assert_allclose(a1.data, a2.data, atol=1e-4)

        if jt.has_cuda:
            with jt.flag_scope(use_cuda=1):
                jt_a = jt.randn(32, 512, 3)
                a1, b1 = jt.misc.knn(jt_a, jt_a, 16)
                a2, b2 = knn(jt_a.transpose(0, 2, 1), 16)
                a2 *= -1
                np.testing.assert_allclose(a1.data, a2.data, atol=1e-4)

if __name__ == "__main__":
    unittest.main()