python jit and auto_parallel

This commit is contained in:
Dun Liang 2021-01-07 16:22:26 +08:00
parent 96545765ec
commit 37aafe431a
5 changed files with 171 additions and 41 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.14'
__version__ = '1.2.2.15'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -722,8 +722,100 @@ def triu_(x,diagonal=0):
jt.Var.triu_ = triu_
def python_pass_warper(mod_func, args, kw):
    """Call the Python function named by ``mod_func`` with textual arguments.

    This is the Python side of the ``@python.module.function(args)`` syntax:
    the caller supplies the dotted path of the function, the argument
    expressions as source text, and extra keyword arguments.

    Args:
        mod_func: dotted path ``"package.module.function"``; everything
            before the last dot is imported as a module.
        args: sequence (tuple or list) of strings, each one a Python
            expression that is evaluated and passed positionally. Blank
            entries (e.g. from an empty ``()`` argument list) are ignored.
        kw: dict of keyword arguments forwarded unchanged as ``**kw``.

    Returns:
        Whatever the target function returns.

    Raises:
        ValueError: if ``mod_func`` is not of the form ``module.function``.
    """
    import importlib

    mod_name, dot, func_name = mod_func.rpartition(".")
    if not (dot and mod_name and func_name):
        raise ValueError(
            f"expect 'module.function', but got {mod_func!r}")
    func = getattr(importlib.import_module(mod_name), func_name)

    # Drop blank arguments so that "func()" does not become "func(,**kw)".
    call_args = [text.strip() for text in args if text.strip()]
    call_args.append("**kw")
    arg_list = ",".join(call_args)
    # NOTE(security): the argument texts are evaluated as Python code. This is
    # only acceptable because they come from trusted JIT source; never route
    # untrusted input through this function.
    # `func` and `kw` are resolved by eval() from this function's locals.
    return eval(f"func({arg_list})")
def auto_parallel(n, src, **kw):
    """
    auto parallel(CPU and GPU) n-d for loop function like below:

    Before:

        void inner_func(int n0, int i0, int n1, int i1) {
            ...
        }

        for (int i0=0; i0<n0; i0++)
            for (int i1=0; i1<n1; i1++)
                inner_func(n0, i0, n1, i1, ...);

    After:

        @python.jittor.auto_parallel(2)
        void inner_func(int n0, int i0, int n1, int i1) {
            ...
        }

        inner_func(n0, 0, n1, 0, ...);

    Args:
        n: number of loop dimensions. The first ``2*n`` parameters of the
            function in ``src`` must be ``(bound, index)`` pairs.
        src: C++ source of one function, ``prefix name(params) body``.
        **kw: ignored; accepted so extra keyword arguments do not fail.

    Returns:
        C++ source containing the original function renamed to
        ``<name>_inner``, a CUDA ``<name>_entry`` kernel (under
        ``#ifdef JIT_cuda``) and a host function ``<name>`` with the original
        parameter list that runs the whole n-d loop.
    """
    import re

    def arg_name(decl):
        # The parameter name is the trailing identifier of the declaration,
        # which also handles "float *p" and "float*p".
        found = re.search(r"[A-Za-z_]\w*\s*$", decl)
        if found is None:
            raise ValueError(f"can not find parameter name in {decl!r}")
        return found.group().strip()

    # src = prev_func func_name(args)code
    head, tail = src.split('(', 1)
    _, func_name = head.rsplit(None, 1)
    arg_text, _ = tail.split(')', 1)
    args = arg_text.split(',')
    assert len(args) >= n*2, (args, n)

    pargs = args[:n*2]       # (bound, index) pairs of the parallel dims
    oargs = args[n*2:]       # all remaining parameters
    pnargs = pargs[0::2]     # declarations of the loop bounds
    pnargs2 = [arg_name(a) for a in pnargs]
    oargs2 = [arg_name(a) for a in oargs]
    dims = range(n)

    # Rename exactly the function name token. A plain str.replace would hit
    # the first textual occurrence, which may be inside the return type
    # (e.g. "id" inside "void").
    name_pos = len(head.rstrip()) - len(func_name)
    inner_src = (src[:name_pos] + func_name + "_inner"
                 + src[name_pos+len(func_name):])

    entry_func_args_def = ",".join(
        [f"int tn{i}" for i in dims] + pnargs + oargs)
    entry_func_args = ",".join(
        [f"tn{i}" for i in dims] + pnargs2 + oargs2)
    host_args = ",".join(pargs + oargs)

    # Split the flat thread id into one id per dimension; tn{i} is the log2
    # of the number of threads of dimension i, tnum{i} the number itself.
    tid_def = "".join(
        f"\nauto tid{i} = tid & ((1<<tn{i})-1);"
        f"\nauto tnum{i} = 1<<tn{i};"
        f"\ntid = tid>>tn{i};"
        for i in reversed(dims))
    # Each thread strides by the thread count of the dimension (tnum{i}),
    # not by its log2 (tn{i}), which would revisit elements and never
    # terminate when tn{i} is 0.
    tid_loop = "".join(
        f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tnum{i})"
        for i in dims)

    call_args = []
    for i, bound in enumerate(pnargs2):
        call_args += [bound, f"i{i}"]
    call_args = ",".join(call_args + oargs2)

    tn_def = "\n".join(
        f"int tn{i} = NanoVector::get_nbits(std::min(thread_num, {pnargs2[i]})) - 2;"
        f"thread_num >>= tn{i};"
        for i in reversed(dims))
    tn_sum = "+".join(f"tn{i}" for i in dims)
    cpu_loop = "\n".join(
        f"for (int i{i}=0; i{i}<{pnargs2[i]}; i{i}++)" for i in dims)

    return f"""
#ifdef JIT_cuda
__device__
#endif
{inner_src}

#ifdef JIT_cuda
__global__ static void {func_name}_entry({entry_func_args_def}) {{
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    {tid_def}
    {tid_loop}
    {func_name}_inner({call_args});
}}
#endif

inline static void {func_name}({host_args}) {{
#ifdef JIT_cuda
    int thread_num = 256*1024;
    {tn_def}
    thread_num = 1<<({tn_sum});
    int p1 = std::max(thread_num/1024, 1);
    int p2 = std::min(thread_num, 1024);
    {func_name}_entry<<<p1,p2>>>({entry_func_args});
#else
    {cpu_loop}
    {func_name}_inner({call_args});
#endif
}}
"""
def searchsorted(sorted, values, right=False):
"""
@ -748,11 +840,10 @@ Example::
_searchsorted_header = f"""
namespace jittor {{
#ifdef JIT_cuda
__device__
#endif
inline static void searchsorted_kernel(int batch_id, int value_id,
int value_num, int sorted_num, int batch_stride,
@python.jittor.auto_parallel(2)
inline static void searchsorted(
int batch_num, int batch_id, int value_num, int value_id,
int sorted_num, int batch_stride,
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
int32* __restrict__ index_p) {{
int32 l = batch_id * batch_stride;
@ -768,27 +859,6 @@ inline static void searchsorted_kernel(int batch_id, int value_id,
index_p[batch_id * value_num + value_id] = l - batch_id * batch_stride;
}}
#ifdef JIT_cuda
__global__ void searchsorted(int tn0, int tn1, int batch_num,
int value_num, int sorted_num, int batch_stride,
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
int32* __restrict__ index_p
) {{
int tid = threadIdx.x + blockIdx.x * blockDim.x;
auto i1 = tid & ((1<<tn1)-1);
auto i0 = tid >> tn1;
for (int i=i0; i<batch_num; i+=1<<tn0)
for (int j=i1; j<value_num; j+=1<<tn1)
searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, sort_p, value_p, index_p);
}}
inline static int get_thread_range_log(int& thread_num, int64 range) {{
int nbits = NanoVector::get_nbits(std::min((int64)thread_num, range)) - 2;
thread_num >>= nbits;
return nbits;
}}
#endif
}}
"""
_searchsorted_src = """
@ -799,19 +869,7 @@ inline static int get_thread_range_log(int& thread_num, int64 range) {{
int32 batch_stride = batch_num == 1 ? 0 : sorted_num;
CHECK(batch_num == batch_num2 || batch_num == 1);
#ifdef JIT_cuda
int thread_num = 256*1024;
auto tn1 = get_thread_range_log(thread_num, value_num);
auto tn0 = get_thread_range_log(thread_num, batch_num2);
thread_num = 1<<(tn0+tn1);
int p1 = std::max(thread_num/1024, 1);
int p2 = std::min(thread_num, 1024);
searchsorted<<<p1,p2>>>(tn0, tn1, batch_num2, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
#else
for (int32 i=0; i<batch_num2; i++)
for (int32 j=0; j<value_num; j++)
searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
#endif
searchsorted(batch_num2, 0, value_num, 0, sorted_num, batch_stride, in0_p, in1_p, out0_p);
"""
return jt.code(values.shape, "int32", [sorted, values],
cpu_header=_searchsorted_header,

View File

@ -17,6 +17,7 @@
#include "ops/array_op.h"
#include "lock.h"
#include "opt/expr.h"
#include "pyjt/py_caller.h"
namespace jittor {
@ -412,12 +413,17 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
size_t k=j+1;
while (k<src.size() && isvar(src[k])) k++;
string expr = src.substr(j, k-j);
// syntax for @python.module.function(args)
if (expr == "python") {
while (k<src.size() && (isvar(src[k]) || src[k]=='.' )) k++;
string full_expr = src.substr(j, k-j);
}
int presum = 1;
vector<int> comma;
vector<string> args;
size_t l = k+1;
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
expr == "is_def" ||
expr == "is_def" || expr == "python" ||
(k<src.size() && src[k]=='(')) {
ASSERT(src[k] == '(');
comma.push_back(k);
@ -435,6 +441,30 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
for (uint i=0; i+1<comma.size(); i++)
args.push_back(src.substr(comma[i]+1, comma[i+1]-comma[i]-1));
}
if (expr == "python") {
string full_expr = src.substr(j, k-j);
LOGvvv << "python call" << full_expr << args;
int presum = 0;
auto ll = l;
while (l<src.size()) {
if (src[l] == '{')
presum++;
else if (src[l] == '}')
presum--;
if (presum==0 && (src[l] == '}' || src[l] == ';'))
break;
l++;
}
CHECK(l<src.size()) << "Jit error: braces are not matched.";
auto full_src = src.substr(ll, l+1-ll);
i = l;
full_src = py_caller(
full_expr.substr(7),
args, {{"src",full_src}}
);
new_src += precompile(defs, full_src, macros);
continue;
}
// syntax @for(i, l, r, ...)
// ij k l
if (expr == "for") {

26
src/pyjt/py_caller.cc Normal file
View File

@ -0,0 +1,26 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "pyjt/py_obj_holder.h"
#include "pyjt/py_converter.h"
#include "pyjt/py_caller.h"
namespace jittor {
// Forward a call to the Python function `jittor.python_pass_warper`, passing
// the target name `mod_func`, the textual arguments `args` (as a tuple) and
// the keyword map `kw`. The Python result must be a string and is returned
// as a C++ string.
string py_caller(const string& mod_func, const vector<string>& args, const map<string,string>& kw) {
    // Look up the Python-side dispatcher.
    PyObjHolder jittor_module(PyImport_ImportModule("jittor"));
    PyObjHolder dispatcher(PyObject_GetAttrString(jittor_module.obj, "python_pass_warper"));

    // Convert the three C++ arguments into Python objects.
    PyObjHolder name_obj(to_py_object<string>(mod_func));
    PyObjHolder args_obj(to_py_tuple(args));
    PyObjHolder kw_obj(to_py_object(kw));

    // Equivalent to python_pass_warper(mod_func, args, kw); the trailing
    // nullptr terminates the variadic argument list.
    PyObjHolder result(PyObject_CallFunctionObjArgs(
        dispatcher.obj, name_obj.obj, args_obj.obj, kw_obj.obj, nullptr));
    CHECK(is_type<string>(result.obj)) << "expect return type string.";
    return from_py_object<string>(result.obj);
}
}

16
src/pyjt/py_caller.h Normal file
View File

@ -0,0 +1,16 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
// Call a Python function from C++.
//   mod_func: name of the Python function to call, passed through to the
//             Python side as a string.
//   args:     positional arguments, passed to Python as a tuple of strings.
//   kw:       keyword arguments, passed to Python as a string-to-string map.
// Returns the string produced by the Python call.
string py_caller(const string& mod_func, const vector<string>& args, const map<string,string>& kw);
}