diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 1cea0614..8c6b66f0 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -8,7 +8,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.2.14' +__version__ = '1.2.2.15' from . import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 08458a87..2e9716ea 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -722,8 +722,100 @@ def triu_(x,diagonal=0): jt.Var.triu_ = triu_ +def python_pass_warper(mod_func, args, kw): + import importlib + mod, func = mod_func.rsplit(".", 1) + mod = importlib.import_module(mod) + func = getattr(mod, func) + args = args + ("**kw",) + args = ",".join(args) + return eval(f"func({args})") + +def auto_parallel(n, src, **kw): + """ + auto parallel(CPU and GPU) n-d for loop function like below: + + Before: + + void inner_func(int n0, int i0, int n1, int i1) { + ... + } + + for (int i0=0; i0= n*2, (args, n) + oargs = args[n*2:] + pargs = args[:n*2] + piargs = pargs[1::2] + pnargs = pargs[0::2] + pnargs2 = [ a.split()[-1] for a in pnargs ] + oargs2 = [ a.split()[-1] for a in oargs ] + entry_func_args_def = ",".join(["int tn"+str(i) for i in range(n)] + + pnargs + oargs) + entry_func_args = ",".join(["tn"+str(i) for i in range(n)] + + pnargs2 + oargs2) + tid_def = "" + tid_loop = "" + call_args = [] + for i in reversed(range(n)): + tid_def += f"\nauto tid{i} = tid & ((1<>tn{i};" + for i in range(n): + tid_loop += f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tn{i})" + call_args.append(pnargs2[i]) + call_args.append(f"i{i}") + call_args += oargs2 + call_args = ",".join(call_args) + xn = '\n' + new_src = f""" +#ifdef JIT_cuda +__device__ +#endif +{src.replace(func_name, func_name+"_inner", 1)} + +#ifdef JIT_cuda +__global__ static void {func_name}_entry({entry_func_args_def}) {{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + {tid_def} + {tid_loop} + {func_name}_inner({call_args}); +}} +#endif + +inline static void {func_name}({",".join(pargs+oargs)}) {{ +#ifdef JIT_cuda + int thread_num = 256*1024; + {xn.join([f"int tn{i} = NanoVector::get_nbits(std::min(thread_num, {pnargs2[i]})) - 2;thread_num >>= tn{i};" for i in reversed(range(n))])} + thread_num = 1<<({"+".join([f"tn{i}" for i in range(n)])}); + int p1 = std::max(thread_num/1024, 1); + int p2 = std::min(thread_num, 1024); + {func_name}_entry<<>>({entry_func_args}); +#else + {xn.join([f"for (int i{i}=0; i{i}<{pnargs2[i]}; i{i}++)" for i in range(n)])} + {func_name}_inner({call_args}); +#endif +}} +""" + return new_src def searchsorted(sorted, values, right=False): """ @@ -748,11 +840,10 @@ Example:: _searchsorted_header = f""" namespace jittor {{ -#ifdef JIT_cuda -__device__ -#endif -inline static void searchsorted_kernel(int batch_id, int value_id, - int value_num, int sorted_num, int batch_stride, +@python.jittor.auto_parallel(2) +inline static void searchsorted( + int batch_num, int batch_id, int value_num, int value_id, + int sorted_num, int batch_stride, {sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p, int32* __restrict__ index_p) {{ int32 l = batch_id * batch_stride; @@ -768,27 +859,6 @@ inline static void searchsorted_kernel(int batch_id, int value_id, index_p[batch_id * value_num + value_id] = l - batch_id * batch_stride; }} -#ifdef JIT_cuda -__global__ void searchsorted(int tn0, int tn1, int batch_num, - int value_num, int sorted_num, int batch_stride, - {sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p, - int32* __restrict__ index_p -) {{ - int tid = threadIdx.x + blockIdx.x * blockDim.x; - auto i1 = tid & ((1<> tn1; - for (int i=i0; i>= nbits; - return nbits; -}} -#endif - }} """ _searchsorted_src = """ @@ -799,19 +869,7 @@ inline static int get_thread_range_log(int& thread_num, int64 range) {{ int32 batch_stride = batch_num == 1 ? 0 : sorted_num; CHECK(batch_num == batch_num2 || batch_num == 1); - #ifdef JIT_cuda - int thread_num = 256*1024; - auto tn1 = get_thread_range_log(thread_num, value_num); - auto tn0 = get_thread_range_log(thread_num, batch_num2); - thread_num = 1<<(tn0+tn1); - int p1 = std::max(thread_num/1024, 1); - int p2 = std::min(thread_num, 1024); - searchsorted<<>>(tn0, tn1, batch_num2, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p); - #else - for (int32 i=0; i defs, string src, unordered_map comma; vector args; size_t l = k+1; if (expr == "for" || expr == "if" || expr == "expand_macro" || - expr == "is_def" || + expr == "is_def" || expr == "python" || (k defs, string src, unordered_map. +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "pyjt/py_obj_holder.h" +#include "pyjt/py_converter.h" +#include "pyjt/py_caller.h" + +namespace jittor { + +string py_caller(const string& mod_func, const vector& args, const map& kw) { + PyObjHolder mod(PyImport_ImportModule("jittor")); + PyObjHolder func(PyObject_GetAttrString(mod.obj, "python_pass_warper")); + PyObjHolder py_name(to_py_object(mod_func)); + PyObjHolder py_args(to_py_tuple(args)); + PyObjHolder py_kw(to_py_object(kw)); + PyObjHolder ret(PyObject_CallFunctionObjArgs(func.obj, py_name.obj, py_args.obj, py_kw.obj, nullptr)); + CHECK(is_type(ret.obj)) << "expect return type string."; + return from_py_object(ret.obj); +} + +} diff --git a/src/pyjt/py_caller.h b/src/pyjt/py_caller.h new file mode 100644 index 00000000..d6658b49 --- /dev/null +++ b/src/pyjt/py_caller.h @@ -0,0 +1,16 @@ +// *************************************************************** +// Copyright (c) 2020 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +string py_caller(const string& mod_func, const vector& args, const map& kw); + +}