mirror of https://github.com/Jittor/Jittor
python jit and auto_parallel
This commit is contained in:
parent
96545765ec
commit
37aafe431a
|
@ -8,7 +8,7 @@
|
||||||
# This file is subject to the terms and conditions defined in
|
# This file is subject to the terms and conditions defined in
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
__version__ = '1.2.2.14'
|
__version__ = '1.2.2.15'
|
||||||
from . import lock
|
from . import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -722,8 +722,100 @@ def triu_(x,diagonal=0):
|
||||||
|
|
||||||
jt.Var.triu_ = triu_
|
jt.Var.triu_ = triu_
|
||||||
|
|
||||||
|
def python_pass_warper(mod_func, args, kw):
|
||||||
|
import importlib
|
||||||
|
mod, func = mod_func.rsplit(".", 1)
|
||||||
|
mod = importlib.import_module(mod)
|
||||||
|
func = getattr(mod, func)
|
||||||
|
args = args + ("**kw",)
|
||||||
|
args = ",".join(args)
|
||||||
|
return eval(f"func({args})")
|
||||||
|
|
||||||
|
def auto_parallel(n, src, **kw):
|
||||||
|
"""
|
||||||
|
auto parallel(CPU and GPU) n-d for loop function like below:
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
void inner_func(int n0, int i0, int n1, int i1) {
|
||||||
|
...
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i0=0; i0<n0; i0++)
|
||||||
|
for (int i1=0; i1<n1; i1++)
|
||||||
|
inner_func(n0, i0, n1, i1, ...);
|
||||||
|
|
||||||
|
After:
|
||||||
|
|
||||||
|
@python.jittor.auto_parallel(2)
|
||||||
|
void inner_func(int n0, int i0, int n1, int i1) {
|
||||||
|
...
|
||||||
|
}
|
||||||
|
|
||||||
|
inner_func(n0, 0, n1, 0, ...);
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
# src = prev_func func_name(args)code
|
||||||
|
a, b = src.split('(', 1)
|
||||||
|
prev_func, func_name = a.rsplit(None, 1)
|
||||||
|
args, code = b.split(')', 1)
|
||||||
|
args = args.split(',')
|
||||||
|
assert len(args) >= n*2, (args, n)
|
||||||
|
oargs = args[n*2:]
|
||||||
|
pargs = args[:n*2]
|
||||||
|
piargs = pargs[1::2]
|
||||||
|
pnargs = pargs[0::2]
|
||||||
|
pnargs2 = [ a.split()[-1] for a in pnargs ]
|
||||||
|
oargs2 = [ a.split()[-1] for a in oargs ]
|
||||||
|
entry_func_args_def = ",".join(["int tn"+str(i) for i in range(n)]
|
||||||
|
+ pnargs + oargs)
|
||||||
|
entry_func_args = ",".join(["tn"+str(i) for i in range(n)]
|
||||||
|
+ pnargs2 + oargs2)
|
||||||
|
tid_def = ""
|
||||||
|
tid_loop = ""
|
||||||
|
call_args = []
|
||||||
|
for i in reversed(range(n)):
|
||||||
|
tid_def += f"\nauto tid{i} = tid & ((1<<tn{i})-1);"
|
||||||
|
tid_def += f"\nauto tnum{i} = 1<<tn{i};"
|
||||||
|
tid_def += f"\ntid = tid>>tn{i};"
|
||||||
|
for i in range(n):
|
||||||
|
tid_loop += f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tn{i})"
|
||||||
|
call_args.append(pnargs2[i])
|
||||||
|
call_args.append(f"i{i}")
|
||||||
|
call_args += oargs2
|
||||||
|
call_args = ",".join(call_args)
|
||||||
|
xn = '\n'
|
||||||
|
new_src = f"""
|
||||||
|
#ifdef JIT_cuda
|
||||||
|
__device__
|
||||||
|
#endif
|
||||||
|
{src.replace(func_name, func_name+"_inner", 1)}
|
||||||
|
|
||||||
|
#ifdef JIT_cuda
|
||||||
|
__global__ static void {func_name}_entry({entry_func_args_def}) {{
|
||||||
|
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
{tid_def}
|
||||||
|
{tid_loop}
|
||||||
|
{func_name}_inner({call_args});
|
||||||
|
}}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
inline static void {func_name}({",".join(pargs+oargs)}) {{
|
||||||
|
#ifdef JIT_cuda
|
||||||
|
int thread_num = 256*1024;
|
||||||
|
{xn.join([f"int tn{i} = NanoVector::get_nbits(std::min(thread_num, {pnargs2[i]})) - 2;thread_num >>= tn{i};" for i in reversed(range(n))])}
|
||||||
|
thread_num = 1<<({"+".join([f"tn{i}" for i in range(n)])});
|
||||||
|
int p1 = std::max(thread_num/1024, 1);
|
||||||
|
int p2 = std::min(thread_num, 1024);
|
||||||
|
{func_name}_entry<<<p1,p2>>>({entry_func_args});
|
||||||
|
#else
|
||||||
|
{xn.join([f"for (int i{i}=0; i{i}<{pnargs2[i]}; i{i}++)" for i in range(n)])}
|
||||||
|
{func_name}_inner({call_args});
|
||||||
|
#endif
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
return new_src
|
||||||
|
|
||||||
def searchsorted(sorted, values, right=False):
|
def searchsorted(sorted, values, right=False):
|
||||||
"""
|
"""
|
||||||
|
@ -748,11 +840,10 @@ Example::
|
||||||
_searchsorted_header = f"""
|
_searchsorted_header = f"""
|
||||||
namespace jittor {{
|
namespace jittor {{
|
||||||
|
|
||||||
#ifdef JIT_cuda
|
@python.jittor.auto_parallel(2)
|
||||||
__device__
|
inline static void searchsorted(
|
||||||
#endif
|
int batch_num, int batch_id, int value_num, int value_id,
|
||||||
inline static void searchsorted_kernel(int batch_id, int value_id,
|
int sorted_num, int batch_stride,
|
||||||
int value_num, int sorted_num, int batch_stride,
|
|
||||||
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
|
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
|
||||||
int32* __restrict__ index_p) {{
|
int32* __restrict__ index_p) {{
|
||||||
int32 l = batch_id * batch_stride;
|
int32 l = batch_id * batch_stride;
|
||||||
|
@ -768,27 +859,6 @@ inline static void searchsorted_kernel(int batch_id, int value_id,
|
||||||
index_p[batch_id * value_num + value_id] = l - batch_id * batch_stride;
|
index_p[batch_id * value_num + value_id] = l - batch_id * batch_stride;
|
||||||
}}
|
}}
|
||||||
|
|
||||||
#ifdef JIT_cuda
|
|
||||||
__global__ void searchsorted(int tn0, int tn1, int batch_num,
|
|
||||||
int value_num, int sorted_num, int batch_stride,
|
|
||||||
{sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p,
|
|
||||||
int32* __restrict__ index_p
|
|
||||||
) {{
|
|
||||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
auto i1 = tid & ((1<<tn1)-1);
|
|
||||||
auto i0 = tid >> tn1;
|
|
||||||
for (int i=i0; i<batch_num; i+=1<<tn0)
|
|
||||||
for (int j=i1; j<value_num; j+=1<<tn1)
|
|
||||||
searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, sort_p, value_p, index_p);
|
|
||||||
}}
|
|
||||||
|
|
||||||
inline static int get_thread_range_log(int& thread_num, int64 range) {{
|
|
||||||
int nbits = NanoVector::get_nbits(std::min((int64)thread_num, range)) - 2;
|
|
||||||
thread_num >>= nbits;
|
|
||||||
return nbits;
|
|
||||||
}}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
_searchsorted_src = """
|
_searchsorted_src = """
|
||||||
|
@ -799,19 +869,7 @@ inline static int get_thread_range_log(int& thread_num, int64 range) {{
|
||||||
int32 batch_stride = batch_num == 1 ? 0 : sorted_num;
|
int32 batch_stride = batch_num == 1 ? 0 : sorted_num;
|
||||||
CHECK(batch_num == batch_num2 || batch_num == 1);
|
CHECK(batch_num == batch_num2 || batch_num == 1);
|
||||||
|
|
||||||
#ifdef JIT_cuda
|
searchsorted(batch_num2, 0, value_num, 0, sorted_num, batch_stride, in0_p, in1_p, out0_p);
|
||||||
int thread_num = 256*1024;
|
|
||||||
auto tn1 = get_thread_range_log(thread_num, value_num);
|
|
||||||
auto tn0 = get_thread_range_log(thread_num, batch_num2);
|
|
||||||
thread_num = 1<<(tn0+tn1);
|
|
||||||
int p1 = std::max(thread_num/1024, 1);
|
|
||||||
int p2 = std::min(thread_num, 1024);
|
|
||||||
searchsorted<<<p1,p2>>>(tn0, tn1, batch_num2, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
|
|
||||||
#else
|
|
||||||
for (int32 i=0; i<batch_num2; i++)
|
|
||||||
for (int32 j=0; j<value_num; j++)
|
|
||||||
searchsorted_kernel(i, j, value_num, sorted_num, batch_stride, in0_p, in1_p, out0_p);
|
|
||||||
#endif
|
|
||||||
"""
|
"""
|
||||||
return jt.code(values.shape, "int32", [sorted, values],
|
return jt.code(values.shape, "int32", [sorted, values],
|
||||||
cpu_header=_searchsorted_header,
|
cpu_header=_searchsorted_header,
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "ops/array_op.h"
|
#include "ops/array_op.h"
|
||||||
#include "lock.h"
|
#include "lock.h"
|
||||||
#include "opt/expr.h"
|
#include "opt/expr.h"
|
||||||
|
#include "pyjt/py_caller.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
@ -412,12 +413,17 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
||||||
size_t k=j+1;
|
size_t k=j+1;
|
||||||
while (k<src.size() && isvar(src[k])) k++;
|
while (k<src.size() && isvar(src[k])) k++;
|
||||||
string expr = src.substr(j, k-j);
|
string expr = src.substr(j, k-j);
|
||||||
|
// syntax for @python.module.function(args)
|
||||||
|
if (expr == "python") {
|
||||||
|
while (k<src.size() && (isvar(src[k]) || src[k]=='.' )) k++;
|
||||||
|
string full_expr = src.substr(j, k-j);
|
||||||
|
}
|
||||||
int presum = 1;
|
int presum = 1;
|
||||||
vector<int> comma;
|
vector<int> comma;
|
||||||
vector<string> args;
|
vector<string> args;
|
||||||
size_t l = k+1;
|
size_t l = k+1;
|
||||||
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
|
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
|
||||||
expr == "is_def" ||
|
expr == "is_def" || expr == "python" ||
|
||||||
(k<src.size() && src[k]=='(')) {
|
(k<src.size() && src[k]=='(')) {
|
||||||
ASSERT(src[k] == '(');
|
ASSERT(src[k] == '(');
|
||||||
comma.push_back(k);
|
comma.push_back(k);
|
||||||
|
@ -435,6 +441,30 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
||||||
for (uint i=0; i+1<comma.size(); i++)
|
for (uint i=0; i+1<comma.size(); i++)
|
||||||
args.push_back(src.substr(comma[i]+1, comma[i+1]-comma[i]-1));
|
args.push_back(src.substr(comma[i]+1, comma[i+1]-comma[i]-1));
|
||||||
}
|
}
|
||||||
|
if (expr == "python") {
|
||||||
|
string full_expr = src.substr(j, k-j);
|
||||||
|
LOGvvv << "python call" << full_expr << args;
|
||||||
|
int presum = 0;
|
||||||
|
auto ll = l;
|
||||||
|
while (l<src.size()) {
|
||||||
|
if (src[l] == '{')
|
||||||
|
presum++;
|
||||||
|
else if (src[l] == '}')
|
||||||
|
presum--;
|
||||||
|
if (presum==0 && (src[l] == '}' || src[l] == ';'))
|
||||||
|
break;
|
||||||
|
l++;
|
||||||
|
}
|
||||||
|
CHECK(l<src.size()) << "Jit error: braces are not matched.";
|
||||||
|
auto full_src = src.substr(ll, l+1-ll);
|
||||||
|
i = l;
|
||||||
|
full_src = py_caller(
|
||||||
|
full_expr.substr(7),
|
||||||
|
args, {{"src",full_src}}
|
||||||
|
);
|
||||||
|
new_src += precompile(defs, full_src, macros);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
// syntax @for(i, l, r, ...)
|
// syntax @for(i, l, r, ...)
|
||||||
// ij k l
|
// ij k l
|
||||||
if (expr == "for") {
|
if (expr == "for") {
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers:
|
||||||
|
// Dun Liang <randonlang@gmail.com>.
|
||||||
|
//
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#include "pyjt/py_obj_holder.h"
|
||||||
|
#include "pyjt/py_converter.h"
|
||||||
|
#include "pyjt/py_caller.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
string py_caller(const string& mod_func, const vector<string>& args, const map<string,string>& kw) {
|
||||||
|
PyObjHolder mod(PyImport_ImportModule("jittor"));
|
||||||
|
PyObjHolder func(PyObject_GetAttrString(mod.obj, "python_pass_warper"));
|
||||||
|
PyObjHolder py_name(to_py_object<string>(mod_func));
|
||||||
|
PyObjHolder py_args(to_py_tuple(args));
|
||||||
|
PyObjHolder py_kw(to_py_object(kw));
|
||||||
|
PyObjHolder ret(PyObject_CallFunctionObjArgs(func.obj, py_name.obj, py_args.obj, py_kw.obj, nullptr));
|
||||||
|
CHECK(is_type<string>(ret.obj)) << "expect return type string.";
|
||||||
|
return from_py_object<string>(ret.obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers:
|
||||||
|
// Dun Liang <randonlang@gmail.com>.
|
||||||
|
//
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#pragma once
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
string py_caller(const string& mod_func, const vector<string>& args, const map<string,string>& kw);
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue