mirror of https://github.com/Jittor/Jittor
python jit and auto_parallel
This commit is contained in:
parent
96545765ec
commit
37aafe431a
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.14'
|
||||
__version__ = '1.2.2.15'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -722,8 +722,100 @@ def triu_(x,diagonal=0):
|
|||
|
||||
jt.Var.triu_ = triu_
|
||||
|
||||
def python_pass_warper(mod_func, args, kw):
    """Resolve ``module.function`` by dotted name and call it.

    This is the Python-side dispatcher looked up by the C++ ``py_caller``
    helper when the JIT precompiler meets ``@python.module.function(args)``.

    :param mod_func: dotted path whose last component is the function name
        and whose prefix is an importable module, e.g. ``"jittor.auto_parallel"``.
    :param args: sequence of strings; each one is the *source text* of a
        positional argument and is evaluated as a Python expression.
    :param kw: dict of keyword arguments forwarded to the function as ``**kw``.
    :return: whatever the resolved function returns.
    :raises ValueError: if ``mod_func`` contains no ``"."``.
    """
    import importlib
    mod_name, func_name = mod_func.rsplit(".", 1)
    # ``func`` and ``kw`` must stay local names: the evaluated call text
    # below refers to them.
    func = getattr(importlib.import_module(mod_name), func_name)
    arg_exprs = tuple(args)
    # An empty call such as ``@python.mod.func()`` may arrive as a single
    # blank argument; drop it so that the call text becomes ``func(**kw)``
    # instead of the syntactically invalid ``func(,**kw)``.
    if len(arg_exprs) == 1 and not arg_exprs[0].strip():
        arg_exprs = ()
    call_args = ",".join(arg_exprs + ("**kw",))
    # NOTE(review): eval() executes the argument text as Python code. This is
    # only acceptable while the text comes from trusted JIT source; never
    # route externally supplied strings through this function.
    return eval(f"func({call_args})")
|
||||
|
||||
def auto_parallel(n, src, **kw):
    """
    Auto parallel (CPU and GPU) an n-d for loop around a function, like below:

    Before:

        void inner_func(int n0, int i0, int n1, int i1) {
            ...
        }

        for (int i0=0; i0<n0; i0++)
            for (int i1=0; i1<n1; i1++)
                inner_func(n0, i0, n1, i1, ...);

    After:

        @python.jittor.auto_parallel(2)
        void inner_func(int n0, int i0, int n1, int i1) {
            ...
        }

        inner_func(n0, 0, n1, 0, ...);

    :param n: number of loop dimensions. The first ``2*n`` parameters of the
        function must be ``(extent, index)`` pairs, one pair per dimension.
    :param src: source text of one function definition, laid out as
        ``<qualifiers/return type> <name>(<params>)<body>``.
    :param kw: extra keyword arguments; accepted and not used.
    :return: generated source containing ``<name>_inner`` (the original
        function, renamed), a ``<name>_entry`` kernel under ``JIT_cuda``, and
        a wrapper that keeps the original name and parameter list. The index
        values passed to the wrapper are ignored; the wrapper iterates over
        every index itself.
    :raises AssertionError: if the function has fewer than ``2*n`` parameters.
    """
    # src = prev_func func_name(args)code
    head, rest = src.split('(', 1)
    _, func_name = head.rsplit(None, 1)
    arg_text, _ = rest.split(')', 1)
    args = arg_text.split(',')
    assert len(args) >= n*2, (args, n)
    # pargs: the (extent, index) pairs; oargs: every remaining parameter.
    pargs = args[:n*2]
    oargs = args[n*2:]
    # Extent declarations ("int n0") and the bare names taken from them ("n0").
    pnargs = pargs[0::2]
    pnargs2 = [ a.split()[-1] for a in pnargs ]
    oargs2 = [ a.split()[-1] for a in oargs ]
    entry_func_args_def = ",".join(["int tn"+str(i) for i in range(n)]
        + pnargs + oargs)
    entry_func_args = ",".join(["tn"+str(i) for i in range(n)]
        + pnargs2 + oargs2)

    # Split the flat thread id into one id per dimension. tn{i} is the number
    # of bits given to dimension i, so tnum{i} = 1<<tn{i} is the number of
    # threads along that dimension.
    tid_def = ""
    for i in reversed(range(n)):
        tid_def += f"\nauto tid{i} = tid & ((1<<tn{i})-1);"
        tid_def += f"\nauto tnum{i} = 1<<tn{i};"
        tid_def += f"\ntid = tid>>tn{i};"

    tid_loop = ""
    call_args = []
    for i in range(n):
        # Each thread starts at its own id and advances by the thread COUNT
        # (tnum), not by the bit width (tn): a stride of tn{i} is 0 when
        # tn{i}==0, which would never terminate, and otherwise makes threads
        # repeat each other's iterations.
        tid_loop += f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tnum{i})"
        call_args.append(pnargs2[i])
        call_args.append(f"i{i}")
    call_args += oargs2
    call_args = ",".join(call_args)

    # Pieces of the generated source, built up front to keep the template flat.
    inner_src = src.replace(func_name, func_name+"_inner", 1)
    wrapper_args_def = ",".join(pargs+oargs)
    tn_def = "\n".join(
        f"int tn{i} = NanoVector::get_nbits(std::min(thread_num, {pnargs2[i]})) - 2;thread_num >>= tn{i};"
        for i in reversed(range(n)))
    tn_sum = "+".join(f"tn{i}" for i in range(n))
    cpu_loops = "\n".join(
        f"for (int i{i}=0; i{i}<{pnargs2[i]}; i{i}++)" for i in range(n))

    new_src = f"""
#ifdef JIT_cuda
__device__
#endif
{inner_src}

#ifdef JIT_cuda
__global__ static void {func_name}_entry({entry_func_args_def}) {{
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    {tid_def}
    {tid_loop}
    {func_name}_inner({call_args});
}}
#endif

inline static void {func_name}({wrapper_args_def}) {{
#ifdef JIT_cuda
    int thread_num = 256*1024;
    {tn_def}
    thread_num = 1<<({tn_sum});
    int p1 = std::max(thread_num/1024, 1);
    int p2 = std::min(thread_num, 1024);
    {func_name}_entry<<<p1,p2>>>({entry_func_args});
#else
    {cpu_loops}
    {func_name}_inner({call_args});
#endif
}}
"""
    return new_src
|
||||
|
||||
def searchsorted(sorted, values, right=False):
|
||||
"""
|
||||
|
@ -748,11 +840,10 @@ Example::
|
|||
_searchsorted_header = f"""
|
||||
namespace jittor {{
|
||||
|
||||
#ifdef JIT_cuda
|
||||
__device__
|
||||
#endif
|
||||
inline static void searchsorted_kernel(int batch_id, int value_id,
|
||||
int value_num, int sorted_num, int batch_stride,
|
||||
@python.jittor.auto_parallel(2)
|
||||
inline static void searchsorted(
|
||||
int batch_num, int batch_id, int value_num, int value_id,
|
||||
int sorted_num, int batch_stride,
|
||||
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
|
||||
int32* __restrict__ index_p) {{
|
||||
int32 l = batch_id * batch_stride;
|
||||
|
@ -768,27 +859,6 @@ inline static void searchsorted_kernel(int batch_id, int value_id,
|
|||
index_p[batch_id * value_num + value_id] = l - batch_id * batch_stride;
|
||||
}}
|
||||
|
||||
#ifdef JIT_cuda
|
||||
__global__ void searchsorted(int tn0, int tn1, int batch_num,
|
||||
int value_num, int sorted_num, int batch_stride,
|
||||
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
|
||||
int32* __restrict__ index_p
|
||||
) {{
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto i1 = tid & ((1<<tn1)-1);
|
||||
auto i0 = tid >> tn1;
|
||||
for (int i=i0; i<batch_num; i+=1<<tn0)
|
||||
for (int j=i1; j<value_num; j+=1<<tn1)
|
||||
searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, sort_p, value_p, index_p);
|
||||
}}
|
||||
|
||||
inline static int get_thread_range_log(int& thread_num, int64 range) {{
|
||||
int nbits = NanoVector::get_nbits(std::min((int64)thread_num, range)) - 2;
|
||||
thread_num >>= nbits;
|
||||
return nbits;
|
||||
}}
|
||||
#endif
|
||||
|
||||
}}
|
||||
"""
|
||||
_searchsorted_src = """
|
||||
|
@ -799,19 +869,7 @@ inline static int get_thread_range_log(int& thread_num, int64 range) {{
|
|||
int32 batch_stride = batch_num == 1 ? 0 : sorted_num;
|
||||
CHECK(batch_num == batch_num2 || batch_num == 1);
|
||||
|
||||
#ifdef JIT_cuda
|
||||
int thread_num = 256*1024;
|
||||
auto tn1 = get_thread_range_log(thread_num, value_num);
|
||||
auto tn0 = get_thread_range_log(thread_num, batch_num2);
|
||||
thread_num = 1<<(tn0+tn1);
|
||||
int p1 = std::max(thread_num/1024, 1);
|
||||
int p2 = std::min(thread_num, 1024);
|
||||
searchsorted<<<p1,p2>>>(tn0, tn1, batch_num2, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
|
||||
#else
|
||||
for (int32 i=0; i<batch_num2; i++)
|
||||
for (int32 j=0; j<value_num; j++)
|
||||
searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
|
||||
#endif
|
||||
searchsorted(batch_num2, 0, value_num, 0, sorted_num, batch_stride, in0_p, in1_p, out0_p);
|
||||
"""
|
||||
return jt.code(values.shape, "int32", [sorted, values],
|
||||
cpu_header=_searchsorted_header,
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "ops/array_op.h"
|
||||
#include "lock.h"
|
||||
#include "opt/expr.h"
|
||||
#include "pyjt/py_caller.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -412,12 +413,17 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
size_t k=j+1;
|
||||
while (k<src.size() && isvar(src[k])) k++;
|
||||
string expr = src.substr(j, k-j);
|
||||
// syntax for @python.module.function(args)
|
||||
if (expr == "python") {
|
||||
while (k<src.size() && (isvar(src[k]) || src[k]=='.' )) k++;
|
||||
string full_expr = src.substr(j, k-j);
|
||||
}
|
||||
int presum = 1;
|
||||
vector<int> comma;
|
||||
vector<string> args;
|
||||
size_t l = k+1;
|
||||
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
|
||||
expr == "is_def" ||
|
||||
expr == "is_def" || expr == "python" ||
|
||||
(k<src.size() && src[k]=='(')) {
|
||||
ASSERT(src[k] == '(');
|
||||
comma.push_back(k);
|
||||
|
@ -435,6 +441,30 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
for (uint i=0; i+1<comma.size(); i++)
|
||||
args.push_back(src.substr(comma[i]+1, comma[i+1]-comma[i]-1));
|
||||
}
|
||||
if (expr == "python") {
|
||||
string full_expr = src.substr(j, k-j);
|
||||
LOGvvv << "python call" << full_expr << args;
|
||||
int presum = 0;
|
||||
auto ll = l;
|
||||
while (l<src.size()) {
|
||||
if (src[l] == '{')
|
||||
presum++;
|
||||
else if (src[l] == '}')
|
||||
presum--;
|
||||
if (presum==0 && (src[l] == '}' || src[l] == ';'))
|
||||
break;
|
||||
l++;
|
||||
}
|
||||
CHECK(l<src.size()) << "Jit error: braces are not matched.";
|
||||
auto full_src = src.substr(ll, l+1-ll);
|
||||
i = l;
|
||||
full_src = py_caller(
|
||||
full_expr.substr(7),
|
||||
args, {{"src",full_src}}
|
||||
);
|
||||
new_src += precompile(defs, full_src, macros);
|
||||
continue;
|
||||
}
|
||||
// syntax @for(i, l, r, ...)
|
||||
// ij k l
|
||||
if (expr == "for") {
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "pyjt/py_obj_holder.h"
|
||||
#include "pyjt/py_converter.h"
|
||||
#include "pyjt/py_caller.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
string py_caller(const string& mod_func, const vector<string>& args, const map<string,string>& kw) {
    // Fetch the Python-side dispatcher: jittor.python_pass_warper.
    PyObjHolder jittor_module(PyImport_ImportModule("jittor"));
    PyObjHolder dispatcher(PyObject_GetAttrString(jittor_module.obj, "python_pass_warper"));

    // Convert the three C++ arguments into Python objects.
    PyObjHolder name_obj(to_py_object<string>(mod_func));
    PyObjHolder args_obj(to_py_tuple(args));
    PyObjHolder kw_obj(to_py_object(kw));

    // Positional call dispatcher(mod_func, args, kw); the trailing nullptr
    // terminates the variadic argument list.
    PyObjHolder result(PyObject_CallFunctionObjArgs(
        dispatcher.obj, name_obj.obj, args_obj.obj, kw_obj.obj, nullptr));

    // The dispatcher's result must be a Python string.
    CHECK(is_type<string>(result.obj)) << "expect return type string.";
    return from_py_object<string>(result.obj);
}
|
||||
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Call the Python function `jittor.python_pass_warper(mod_func, args, kw)`
// and return its result, which must be a Python string.
//   mod_func: name passed through to the Python side as its first argument.
//   args:     argument strings, passed to Python as a tuple.
//   kw:       string-to-string map, passed to Python as the third argument.
string py_caller(const string& mod_func, const vector<string>& args, const map<string,string>& kw);
|
||||
|
||||
}
|
Loading…
Reference in New Issue