code op: add more features:

* Code ops can now have multiple outputs
* Outputs may be dynamically sized vars
* Comments can now be written inside code op source
* Aliases are now supported: use @alias to name inputs and outputs (a sketch combining all four features follows below)
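
A minimal sketch combining these features, adapted from the TestCodeOp cases added in this commit: a negative output shape declares the maximum size of a dynamically sized dimension, and the kernel shrinks each output with set_shape once the real size is known.

```python
import jittor as jt

a = jt.array([5,-4,3,-2,1])
# Two outputs; the negative shape (-5,) marks a dynamically sized
# dimension whose maximum size is 5.
b, c = jt.code([(-5,), (-5,)], [a.dtype, a.dtype], [a],
    cpu_src="""
        // comments are now allowed inside code op source
        @alias(a, in0)
        @alias(b, out0)
        @alias(c, out1)
        int num_b=0, num_c=0;
        for (int i=0; i<a_shape0; i++) {
            if (@a(i)>0)
                @b(num_b++) = @a(i);   // positives go to b
            else
                @c(num_c++) = @a(i);   // non-positives go to c
        }
        // shrink each output to the number of elements actually written
        b->set_shape({num_b});
        c->set_shape({num_c});
    """
)
assert (b.data == [5,3,1]).all()
assert (c.data == [-4,-2]).all()
```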
Dun Liang 2020-04-26 20:46:05 +08:00
parent 6c08f24cff
commit ab533a5f45
13 changed files with 495 additions and 194 deletions

View File

@ -60,7 +60,8 @@ for mdname in all_md:
"metadata": {
},
}
ipynb_name = mdname[:-2]+"ipynb"
ipynb_name = os.path.basename(mdname[:-2])+"ipynb"
ipynb_name = os.path.join(notebook_dir, ipynb_name)
print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name)
with open(os.path.join(notebook_dir, ipynb_name), "w") as f:
with open(ipynb_name, "w") as f:
f.write(json.dumps(ipynb))

View File

@ -17,36 +17,36 @@ def argmax_pool(x, size, stride, padding=0):
y = jt.code(y_shape, x.dtype, [x],
cpu_src=f'''
for (int i=0; i<outshape0; i++)
for (int j=0; j<outshape1; j++)
for (int k=0; k<outshape2; k++)
for (int l=0; l<outshape3; l++) {{
for (int i=0; i<out_shape0; i++)
for (int j=0; j<out_shape1; j++)
for (int k=0; k<out_shape2; k++)
for (int l=0; l<out_shape3; l++) {{
int kx=k*{stride}+{size}/2-{padding};
int ky=l*{stride}+{size}/2-{padding};
@out(i,j,k,l) = @in0(i,j,kx,ky);
for (int p=kx-{size}/2;p<=kx+{size}/2;p++)
for (int q=ky-{size}/2;q<=ky+{size}/2;q++)
if (p>=0 && q>=0 && p<in0shape2 && q<in0shape3)
if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
if (@out(i,j,k,l) < @in0(i,j,p,q))
@out(i,j,k,l) = @in0(i,j,p,q);
}}
''',
cpu_grad_src = [f'''
for (int i=0; i<outshape0; i++)
for (int j=0; j<outshape1; j++)
for (int k=0; k<outshape2; k++)
for (int l=0; l<outshape3; l++) @out(i,j,k,l) = 0;
for (int i=0; i<out_shape0; i++)
for (int j=0; j<out_shape1; j++)
for (int k=0; k<out_shape2; k++)
for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;
for (int i=0; i<poutshape0; i++)
for (int j=0; j<poutshape1; j++)
for (int k=0; k<poutshape2; k++)
for (int l=0; l<poutshape3; l++) {{
for (int i=0; i<pout_shape0; i++)
for (int j=0; j<pout_shape1; j++)
for (int k=0; k<pout_shape2; k++)
for (int l=0; l<pout_shape3; l++) {{
int kx=k*{stride}+{size}/2-{padding};
int ky=l*{stride}+{size}/2-{padding};
int bo=1;
for (int p=kx-{size}/2;p<=kx+{size}/2 && bo;p++)
for (int q=ky-{size}/2;q<=ky+{size}/2 && bo;q++)
if (p>=0 && q>=0 && p<in0shape2 && q<in0shape3)
if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
if (@pout(i,j,k,l) == @in0(i,j,p,q)) {{
@out(i,j,p,q) += @dout(i,j,k,l);
bo=0;

View File

@ -48,12 +48,12 @@ class Pool(Module):
int s2 = blockDim.y * gridDim.x;
int i1 = blockIdx.y;
int i0 = blockIdx.z;
for (int i3 = p3; i3 < outshape3; i3 += s3)
for (int i2 = p2; i2 < outshape2; i2 += s2) {{
for (int i3 = p3; i3 < out_shape3; i3 += s3)
for (int i2 = p2; i2 < out_shape2; i2 += s2) {{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0shape3);
int k2_ = min(k2 + {self.kernel_size}, in0shape2);
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
@out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
@ -62,11 +62,11 @@ class Pool(Module):
@out(i0, i1, i2, i3) = {op}(@out(i0, i1, i2, i3), @in0(i0, i1, p, q));
}}
}}
int tx = min(1024, outshape3);
int ty = min(1024 / tx, outshape2);
int bx = (outshape2 - 1) / ty + 1;
int by = outshape1;
int bz = outshape0;
int tx = min(1024, out_shape3);
int ty = min(1024 / tx, out_shape2);
int bx = (out_shape2 - 1) / ty + 1;
int by = out_shape1;
int bz = out_shape0;
dim3 s1(bx, by, bz);
dim3 s2(tx, ty);
kernel1<<<s1, s2>>>(@ARGS);
@ -80,12 +80,12 @@ class Pool(Module):
int s2 = blockDim.y * gridDim.x;
int i1 = blockIdx.y;
int i0 = blockIdx.z;
for (int i3 = p3; i3 < poutshape3; i3 += s3)
for (int i2 = p2; i2 < poutshape2; i2 += s2) {{
for (int i3 = p3; i3 < pout_shape3; i3 += s3)
for (int i2 = p2; i2 < pout_shape2; i2 += s2) {{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0shape3);
int k2_ = min(k2 + {self.kernel_size}, in0shape2);
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
int bo=1;
@ -98,25 +98,25 @@ class Pool(Module):
}}
}}
}}
cudaMemsetAsync(outp, 0, out->size);
int tx = min(1024, poutshape3);
int ty = min(1024 / tx, poutshape2);
int bx = (poutshape2 - 1) / ty + 1;
int by = poutshape1;
int bz = poutshape0;
cudaMemsetAsync(out_p, 0, out->size);
int tx = min(1024, pout_shape3);
int ty = min(1024 / tx, pout_shape2);
int bx = (pout_shape2 - 1) / ty + 1;
int by = pout_shape1;
int bz = pout_shape0;
dim3 s1_(bx, by, bz);
dim3 s2_(tx, ty);
kernel3<<<s1_, s2_>>>(@ARGS);
'''],
cpu_src=f'''
for (int i0=0; i0<outshape0; i0++)
for (int i1=0; i1<outshape1; i1++)
for (int i2=0; i2<outshape2; i2++)
for (int i3=0; i3<outshape3; i3++) {{
for (int i0=0; i0<out_shape0; i0++)
for (int i1=0; i1<out_shape1; i1++)
for (int i2=0; i2<out_shape2; i2++)
for (int i3=0; i3<out_shape3; i3++) {{
int k2 = i2*{self.stride}-{self.padding};
int k3 = i3*{self.stride}-{self.padding};
int k2_ = std::min(k2 + {self.kernel_size}, in0shape2);
int k3_ = std::min(k3 + {self.kernel_size}, in0shape3);
int k2_ = std::min(k2 + {self.kernel_size}, in0_shape2);
int k3_ = std::min(k3 + {self.kernel_size}, in0_shape3);
k2 = std::max(0, k2);
k3 = std::max(0, k3);
@out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
@ -126,19 +126,19 @@ class Pool(Module):
}}
''',
cpu_grad_src = [f'''
for (int i=0; i<outshape0; i++)
for (int j=0; j<outshape1; j++)
for (int k=0; k<outshape2; k++)
for (int l=0; l<outshape3; l++) @out(i,j,k,l) = 0;
for (int i=0; i<out_shape0; i++)
for (int j=0; j<out_shape1; j++)
for (int k=0; k<out_shape2; k++)
for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;
for (int i0=0; i0<poutshape0; i0++)
for (int i1=0; i1<poutshape1; i1++)
for (int i2=0; i2<poutshape2; i2++)
for (int i3=0; i3<poutshape3; i3++) {{
for (int i0=0; i0<pout_shape0; i0++)
for (int i1=0; i1<pout_shape1; i1++)
for (int i2=0; i2<pout_shape2; i2++)
for (int i3=0; i3<pout_shape3; i3++) {{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = std::min(k3 + {self.kernel_size}, in0shape3);
int k2_ = std::min(k2 + {self.kernel_size}, in0shape2);
int k3_ = std::min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = std::min(k2 + {self.kernel_size}, in0_shape2);
k3 = std::max(0, k3);
k2 = std::max(0, k2);
int bo=1;

View File

@ -106,7 +106,7 @@ class TestArray(unittest.TestCase):
with jt.flag_scope(use_cuda=1):
a = jt.array(np.float32([1,2,3]))
b = jt.code(a.shape, a.dtype, [a], cpu_src="""
for (int i=0; i<in0shape0; i++)
for (int i=0; i<in0_shape0; i++)
@out(i) = @in0(i)*@in0(i)*2;
""")
assert (b.data==[2,8,18]).all()

View File

@ -12,11 +12,11 @@ class TestCodeOp(unittest.TestCase):
a = jt.random([10])
b = jt.code(a.shape, a.dtype, [a],
cpu_src='''
for (int i=0; i<in0shape0; i++)
for (int i=0; i<in0_shape0; i++)
@out(i) = @in0(i)*@in0(i)*2;
''',
cpu_grad_src = ['''
for (int i=0; i<in0shape0; i++) {
for (int i=0; i<in0_shape0; i++) {
@out(i) = @dout(i)*@in0(i)*4;
}
'''])
@ -32,15 +32,15 @@ class TestCodeOp(unittest.TestCase):
b = jt.random([10])
c = jt.code(a.shape, a.dtype, [a,b],
cpu_src='''
for (int i=0; i<in0shape0; i++)
for (int i=0; i<in0_shape0; i++)
@out(i) = @in0(i)*@in1(i);
''',
cpu_grad_src = ['''
for (int i=0; i<in0shape0; i++) {
for (int i=0; i<in0_shape0; i++) {
@out(i) = @dout(i)*@in1(i);
}
''', '''
for (int i=0; i<in0shape0; i++) {
for (int i=0; i<in0_shape0; i++) {
@out(i) = @dout(i)*@in0(i);
}
'''])
@ -52,11 +52,102 @@ class TestCodeOp(unittest.TestCase):
def test_header(self):
a = jt.array([3,2,1])
b = jt.code(a.shape, a.dtype, [a],
cpu_header='#include <algorithm>',
cpu_header="""
#include <algorithm>
@alias(a, in0)
@alias(b, out)
""",
cpu_src="""
for (int i=0; i<in0shape0; i++)
@out(i) = @in0(i);
std::sort(&@out(0), &@out(in0shape0));
for (int i=0; i<a_shape0; i++)
@b(i) = @a(i);
std::sort(&@b(0), &@b(in0_shape0));
"""
)
assert (b.data==[1,2,3]).all()
def test_multi_output(self):
a = jt.array([3,2,1])
b,c = jt.code([[2],[4]], ["float32", "float64"], [a],
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
for (int i=0; i<a_shape0; i++) {
if (i<b_shape0) @b(i) = @a(i);
if (i<c_shape0) @c(i) = @a(i);
}
"""
)
assert b.shape == [2]
assert c.shape == [4]
assert b.dtype == "float32"
assert c.dtype == "float64"
assert (b.data == [3,2]).all()
assert (c.data[:3] == [3,2,1]).all()
def test_multi_output2(self):
a = jt.array([3,2,1])
b,c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a],
cpu_header="""
#include <iostream>
using namespace std;
""",
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
@b(0) = @c(0) = @a(0);
for (int i=0; i<a_shape0; i++) {
@b(0) = std::min(@b(0), @a(i));
@c(0) = std::max(@c(0), @a(i));
}
cout << "min:" << @b(0) << " max:" << @c(0) << endl;
"""
)
assert b.data == 1, b
assert c.data == 3, c
def test_vary_shape(self):
a = jt.array([5,-4,3,-2,1])
# negative shape for max size of vary dimension
b,c = jt.code([(-5,), (-5,)], [a.dtype, a.dtype], [a],
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
int num_b=0, num_c=0;
for (int i=0; i<a_shape0; i++) {
if (@a(i)>0)
@b(num_b++) = @a(i);
else
@c(num_c++) = @a(i);
}
b->set_shape({num_b});
c->set_shape({num_c});
"""
)
assert (b.data == [5,3,1]).all()
assert (c.data == [-4,-2]).all()
def test_comment(self):
a = jt.array([3,2,1])
b = jt.code(a.shape, a.dtype, [a],
cpu_header='''
#include <algorithm>
// asd
/* asd
*/
''',
cpu_src="""
// test comment
/*
multi line
*/
@alias(a, in0)
for (int i=0; i<a_shape0; i++)
@out(i) = @a(i);
std::sort(&@out(0), &@out(a_shape0));
"""
)
assert (b.data==[1,2,3]).all()
@ -72,29 +163,29 @@ class TestCodeOp(unittest.TestCase):
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
for (; i<in0_shape0; i+=stride)
@out(i) = @in0(i)*@in1(i);
}
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in1(i);
}
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in0(i);
}
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
'''])
da, db = jt.grad(c, [a, b])
assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
@ -110,8 +201,8 @@ class TestCodeOp(unittest.TestCase):
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
kernel1<<<32, 32>>>(@ARGS);
@ -119,8 +210,8 @@ class TestCodeOp(unittest.TestCase):
cuda_grad_src = ['''
__global__ static void kernel(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
kernel<<<32, 32>>>(@ARGS);
@ -128,8 +219,8 @@ class TestCodeOp(unittest.TestCase):
__global__ static void kernel(@ARGS_DEF) {
@PRECALC
@pout(0,0);
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
kernel<<<32, 32>>>(@ARGS);

View File

@ -11,6 +11,7 @@ import jittor as jt
import numpy as np
import jittor.models as jtmodels
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
@ -18,10 +19,7 @@ try:
from torch import nn
except:
torch = None
skip_this_test = False
skip_this_test = True
@unittest.skipIf(skip_this_test, "skip_this_test")
class test_models(unittest.TestCase):

View File

@ -6,13 +6,16 @@
import unittest, os
import jittor as jt
from jittor import LOG
import json
import sys
from pathlib import Path
dirname = os.path.join(jt.flags.jittor_path, "notebook")
notebook_dir = os.path.join(jt.flags.cache_path, "notebook")
notebook_dir = os.path.join(str(Path.home()), ".cache","jittor","notebook")
tests = []
for mdname in os.listdir(dirname):
if not mdname.endswith(".src.md"): continue
# temporarily disable model_test
if "LSGAN" in mdname: continue
tests.append(mdname[:-3])
try:
@ -27,7 +30,9 @@ def test(name):
jt.compiler.run_cmd("ipython "+ipynb_name)
def init():
jt.compiler.run_cmd("python3 "+os.path.join(dirname, "md_to_ipynb.py"))
cmd = sys.executable+" "+os.path.join(dirname, "md_to_ipynb.py")
LOG.i("init notebooks:", cmd)
jt.compiler.run_cmd(cmd)
src = """class TestNodebooks(unittest.TestCase):
@classmethod

View File

@ -16,7 +16,9 @@ def check(op, *args):
x = convert(x)
y = convert(y)
# str match nan and inf
assert x.dtype == y.dtype and x.shape == y.shape and str(x)==str(y), f"{x}\n{y}"
assert x.dtype == y.dtype and x.shape == y.shape
for a,b in zip(x.flatten(), y.flatten()):
assert str(a)[:5] == str(b)[:5], (a,b)
class TestUnaryOp(unittest.TestCase):
def test_unary_op(self):
@ -54,7 +56,7 @@ class TestUnaryOp(unittest.TestCase):
ja = jt.array(b)
jb = eval(f"jt.{op}(ja)")
jda = jt.grad(jb, ja)
assert (np.abs(jda.data-da)<1e-5).all(), (jda.data,da,op)
assert (np.allclose(jda.data, da)), (jda.data,da,op)
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
pass

View File

@ -24,4 +24,6 @@ bool endswith(const string& a, const string& b);
// max_split: maximum split number (inclusive)
vector<string> split(const string& s, const string& sep, int max_split=0);
string strip(const string& s);
} // jittor

View File

@ -125,13 +125,18 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
presum++;
k++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
CHECK(presum==0) << "Jit error: braces are not matched.";
new_expr += S(eval(expr.substr(j+1, k-j-2), vars));
i = k-1;
continue;
} else {
if (expr[j] == '@') {
// syntax @@
i = j;
continue;
}
// syntax: @x
ASSERT(isvar(expr[j]));
CHECK(isvar(expr[j])) << expr[j] << "is not var";
size_t k=j+1;
while (k<expr.size() && isvar(expr[k])) k++;
if (k<expr.size() && expr[k]=='(') {
@ -200,7 +205,9 @@ void load_macros(const string& src, unordered_map<string,string>& macros) {
auto r=k;
while (r<l && src[r] != '(') r++;
auto body = q>p ? src.substr(p,q-p) : "";
body = (r<l?src.substr(r,l-r):"()") + body;
auto args = "<"+ (r+1<l?src.substr(r+1,l-r-2):"") + ">";
// header <args>body
body = args + body;
auto header = src.substr(k,r-k);
LOGvvvv << "header:" << header << "body:" << body;
macros[header] = body;
@ -211,9 +218,13 @@ void load_macros(const string& src, unordered_map<string,string>& macros) {
void expand_macro(const string& macro, const vector<string>& args, string& new_src) {
LOGvvvv << "expand_macro" << macro << "args:" << args;
auto i = macro.find(")");
if (macro.size() == 0 || macro[0] != '<') {
new_src += macro;
return;
}
auto i = macro.find(">");
ASSERT(i != string::npos);
// (a1, a2, ...)body
// <a1, a2, ...>body
// j k i
unordered_map<string, int> args_map;
for (uint j=1, l=0; j<i; l++) {
@ -373,7 +384,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
presum++;
k++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
CHECK(presum==0) << "Jit error: braces are not matched.";
new_src += S(OpCompiler::eval(src.substr(j+1, k-j-2), defs));
i = k-1;
continue;
@ -389,7 +400,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
presum++;
k++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
CHECK(presum==0) << "Jit error: braces are not matched.";
new_src += precompile(defs, src.substr(j+1, k-j-2), macros);
i = k-1;
continue;
@ -414,7 +425,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
comma.push_back(l);
l++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
CHECK(presum==0) << "Jit error: braces are not matched.";
comma.push_back(l-1);
for (uint i=0; i+1<comma.size(); i++)
args.push_back(src.substr(comma[i]+1, comma[i+1]-comma[i]-1));
@ -516,7 +527,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
if (expr == "strcmp") {
// syntax: @strcmp(s1,s2)
// ij k l
ASSERT(args.size()==2u)
CHECK(args.size()==2u)
<< "Jit error: strcmp wrong arguments.";
auto s1 = precompile(defs, args[0], macros);
auto s2 = precompile(defs, args[1], macros);
@ -526,27 +537,76 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
i = l-1;
continue;
} else
if (expr == "alias") {
// syntax: @alias(s1,s2)
// ij k l
// alias(a,b)
// a->b
// a_type->b_type
// a_dim -> b_dim
// for i in a_dim:
// a_shapei -> b_shapei
// a_stridei -> b_stridei
CHECK(args.size()==2u)
<< "Jit error: alias wrong arguments.";
auto key = strip(precompile(defs, args[0], macros));
auto value = strip(precompile(defs, args[1], macros));
CHECK(defs.count(value+"_dim")) << '"' >> value >> '"' << "not exist";
int dim = std::stoi(defs.at(value+"_dim"));
vector<string> keys = {"", "p", "dim", "type"};
for (int i=0; i<dim; i++) {
keys.push_back("stride"+S(i));
keys.push_back("shape"+S(i));
}
new_src += '\n';
for (auto& s : keys) {
string from = value+"_"+s;
string to = key+"_"+s;
if (!s.size()) {
from = value;
to = key;
}
if (defs.count(from))
from = defs.at(from);
else if (macros.count(from))
from = macros.at(from);
defs[to] = from;
macros[to] = from;
new_src += "#define "+to+" "+from+"\n";
}
i = l-1;
continue;
} else
if (args.size()) {
// syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
// syntax: @e0(i0,i1,...,in) -> e0_p[i0*e0_stride0+i1*e0_stride1+...]
ASSERT(expr.size());
int nid=(int)expr.size();
while (nid && isdigit(expr[nid-1])) nid--;
// xyz123 ---> prefix: xyz; suffix: 123
string prefix = expr.substr(0, nid);
string suffix = expr.substr(nid);
string up_prefix = prefix;
for (auto& c : up_prefix)
if (c>='a' && c<='z') c = c-'a'+'A';
string dim = up_prefix + "DIM" + suffix;
if (prefix == "e") prefix = "extras";
ASSERT(defs.count(dim)) << dim;
ASSERTop(defs.at(dim),==,S(args.size()));
expr = prefix + suffix; // e0 ->extras0
string dim;
if (expr == "x" && defs.count("XDIM")) {
dim = "XDIM";
prefix = "x";
} else
if (prefix == "e") {
// TODO: unify interface
prefix = "extras" + suffix;
dim = "EDIM" + suffix;
} else {
prefix = expr+"_";
dim = prefix + "dim";
}
CHECK(macros.count(dim)) << expr << "not exist" << macros;
CHECKop(macros.at(dim),==,S(args.size())) << expr << "dimension not matched";
std::stringstream ss;
ss << expr << "p[";
ss << prefix << "p[";
for (uint ii=0; ii<args.size(); ii++) {
string arg = precompile(defs, args[ii], macros);
if (ii) ss << "+";
ss << '(' << arg << ")*" << expr << "stride" << ii;
ss << '(' << arg << ")*" << prefix << "stride" << ii;
}
ss << ']';
new_src += ss.str();
@ -568,10 +628,10 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
} else
new_src += src[i];
} catch (std::exception& e) {
uint il = i, ir = i;
while (il && src[il] != '\n') il--;
while (ir<src.size() && src[ir] != '\n') ir++;
string this_line = src.substr(il+1, ir-il-1);
int il = i, ir = i;
while (il>0 && src[il-1] != '\n') il--;
while (ir+1<src.size() && src[ir+1] != '\n') ir++;
string this_line = src.substr(il, ir-il+1);
LOGf << e.what() >> "\nJit compiler error:\n" >> this_line;
}
}
@ -579,7 +639,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
}
string OpCompiler::precompile(const unordered_map<string,string>& defs, const string& src) {
unordered_map<string, string> macros;
unordered_map<string, string> macros = defs;
return jittor::precompile(defs, src, macros);
}
@ -600,6 +660,10 @@ string OpCompiler::get_jit_src(Op* op) {
string after_include_src = "";
auto jit_define = op->get_jit_define();
for (auto &t : jit_define) {
// don't add CODE in define
// this allows comments to exist in CODE
if (t.first == "CODE" || t.first == "HEADER")
continue;
string src = "#define " + t.first + " ";
for (char c : t.second) {
if (c=='\n') src += '\\';
@ -798,7 +862,7 @@ string OpCompiler::__get_fused_src(
presum++;
k++;
}
ASSERT(presum==0) << "Jit error: braces are not matched.";
CHECK(presum==0) << "Jit error: braces are not matched.";
for (;j < k-2; j++) {
if (isvar(src[j])) {
uint l=j;

View File

@ -16,50 +16,100 @@ namespace jittor {
static auto make_code = get_op_info("code")
.get_constructor<VarPtr, NanoVector, NanoString, vector<Var*>&&, string&&, vector<string>&&, string&&, string&&, vector<string>&&, string&&>();
static inline void check_vary_shape(NanoVector v) {
ASSERT(v.size()) << "Vary shape should not be zero dimension";
for (int i=0; i<v.size(); i++)
ASSERT((i == 0) ^ (v[i] >= 0))
<< "Vary shape should only occur in the first dimension:" << v;
}
CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
: in(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
: _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
{
flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
out = create_output(shape, dtype);
ASSERTop(inputs.size(),<=,10);
_outputs.push_back(create_output(shape, dtype));
CHECKop(_inputs.size(),<=,10);
if (_outputs[0]->num < 0) {
flags.set(NodeFlags::_vary_shape);
check_vary_shape(_outputs[0]->shape);
}
}
CodeOp::CodeOp(
vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs,
string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
: _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
{
flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same";
_outputs.resize(shapes.size());
CHECKop(_inputs.size(),<=,10);
CHECKop(_outputs.size(),<=,10);
CHECKop(_outputs.size(),>,0);
for (int i=0; i<shapes.size(); i++) {
_outputs[i] = create_output(shapes[i], dtypes[i]);
if (_outputs[i]->num < 0) {
flags.set(NodeFlags::_vary_shape);
check_vary_shape(_outputs[i]->shape);
}
}
}
VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
// Do not have grad to extras input
string cpu_src = v_index < cpu_grad_src.size() ? cpu_grad_src[v_index] : "";
string cuda_src = v_index < cuda_grad_src.size() ? cuda_grad_src[v_index] : "";
if (!cuda_src.size() && !cpu_src.size()) return nullptr;
auto inputs = clone(in);
inputs.push_back(out);
auto inputs = clone(_inputs);
// TODO: remove unused deps
// dout -> dout
std::stringstream new_alias;
new_alias << "\n@alias(dout,in" << inputs.size() << ")\n";
inputs.push_back(dout);
// _outputs[i] -> poutj
for (int i=0; i<_outputs.size(); i++) {
new_alias << "\n@alias(pout" << i << ",in" << inputs.size() << ")\n";
if (_outputs[i] == out)
new_alias << "\n@alias(pout,in" << inputs.size() << ")\n";
inputs.push_back(_outputs[i]);
}
auto alias = new_alias.str();
return make_code(
in[v_index]->shape,
in[v_index]->dtype(),
_inputs[v_index]->shape,
_inputs[v_index]->dtype(),
move(inputs),
move(cpu_src), {}, clone(cpu_header),
move(cuda_src), {}, clone(cuda_header)
move(cpu_src), {}, alias+cpu_header,
move(cuda_src), {}, alias+cuda_header
);
}
void CodeOp::jit_prepare() {
add_jit_define("Tout", out->dtype());
add_jit_define("OUTDIM", JK::hex1(out->shape.size()));
if (in.size()>=2) {
auto pout = in.rbegin()[1];
auto dout = in.rbegin()[0];
add_jit_define("Tpout", pout->dtype());
add_jit_define("POUTDIM", JK::hex1(pout->shape.size()));
add_jit_define("Tdout", dout->dtype());
add_jit_define("DOUTDIM", JK::hex1(dout->shape.size()));
// forward: in0 in1 in2 -> out0 out1
// backward: in0 in1 in2 in3(pout0) in4(pout1)
add_jit_define("IN_SIZE", JK::hex1(_inputs.size()));
for (uint i=0; i<_inputs.size(); i++) {
jk << JK::key << "in" << JK::hex1(i) << "_dim" <<
JK::val << JK::hex1(_inputs[i]->shape.size()) << JK::end;
jk << JK::key << "in" << JK::hex1(i) << "_type" <<
JK::val << _inputs[i]->dtype() << JK::end;
}
add_jit_define("INSIZE", JK::hex1(in.size()));
for (uint i=0; i<in.size(); i++) {
add_jit_define("INDIM", JK::hex1(i), JK::hex1(in[i]->shape.size()));
add_jit_define("Tin", JK::hex1(i), in[i]->dtype());
add_jit_define("OUT_SIZE", JK::hex1(_outputs.size()));
for (uint i=0; i<_outputs.size(); i++) {
jk << JK::key << "out" << JK::hex1(i) << "_dim" <<
JK::val << JK::hex1(_outputs[i]->shape.size()) << JK::end;
jk << JK::key << "out" << JK::hex1(i) << "_type" <<
JK::val << _outputs[i]->dtype() << JK::end;
}
if (flags.get(NodeFlags::_cuda)) {
jk << JK::key << "HEADER" << JK::val << cuda_header;
@ -90,7 +140,8 @@ void CodeOp::jit_prepare() {
jk << JK::end;
} else {
add_jit_define("HEADER", cpu_header);
add_jit_define("CODE", cpu_src);
jk << JK::key << "CODE" << JK::val;
jk << cpu_src << JK::end;
ASSERT(cpu_src.size());
}
}
@ -101,49 +152,47 @@ void CodeOp::jit_prepare() {
#pragma GCC diagnostic ignored "-Wunused-variable"
@for(i, 0, INSIZE,
@define(in@i@@stride@{INDIM@i-1},1)
@for(i, 0, IN_SIZE,
@define(in@i@@_stride@{in@i@@_dim-1},1)
)
@for(i, 0, OUT_SIZE,
@define(out@i@@_stride@{out@i@@_dim-1},1)
)
@define(outstride@{OUTDIM-1},1)
@if(INSIZE>=2,
@define(poutstride@{POUTDIM-1},1)
@define(doutstride@{DOUTDIM-1},1)
,)
@define(ARGS_DEF,
@for(i, 0, INSIZE, @(
Tin@i* __restrict__ in@i@@p,
@for(j, 0, INDIM@i, @(index_t in@i@@shape@j,))
@for(i, 0, IN_SIZE, @(
in@i@@_type* __restrict__ in@i@@_p,
@for(j, 0, in@i@@_dim, @(index_t in@i@@_shape@j,))
))
@for(i, 0, OUTDIM, @(index_t outshape@i,))
Tout* __restrict__ outp
@for(i, 0, OUT_SIZE, @(
out@i@@_type* __restrict__ out@i@@_p,
@for(j, 0, out@i@@_dim, @(index_t out@i@@_shape@j,))
))
int __tmp
)
@define(ARGS,
@for(i, 0, INSIZE, @(
in@i@@p,
@for(j, 0, INDIM@i, @(in@i@@shape@j,))
@for(i, 0, IN_SIZE, @(
in@i@@_p,
@for(j, 0, in@i@@_dim, @(in@i@@_shape@j,))
))
@for(i, 0, OUTDIM, @(outshape@i,))
outp
@for(i, 0, OUT_SIZE, @(
out@i@@_p,
@for(j, 0, out@i@@_dim, @(out@i@@_shape@j,))
))
0
)
@define(PRECALC,
@for(i, 0, INSIZE,
@for(j, INDIM@i-2, -1, -1, auto in@i@@stride@j = in@i@@stride@{j+1} * in@i@@shape@{j+1};)
@for(i, 0, IN_SIZE,
@for(j, in@i@@_dim-2, -1, -1, auto in@i@@_stride@j = in@i@@_stride@{j+1} * in@i@@_shape@{j+1};)
)
@for(i, 0, OUT_SIZE,
@for(j, out@i@@_dim-2, -1, -1, auto out@i@@_stride@j = out@i@@_stride@{j+1} * out@i@@_shape@{j+1};)
)
@for(i, OUTDIM-2, -1, -1, auto outstride@i = outstride@{i+1} * outshape@{i+1};)
@if(INSIZE>=2,
auto* __restrict__ poutp = in@{INSIZE-2}@@p;
@for(i, 0, POUTDIM, index_t poutshape@i = in@{INSIZE-2}@@shape@i;)
@for(i, POUTDIM-2, -1, -1, auto poutstride@i = in@{INSIZE-2}@@stride@i;)
auto* __restrict__ doutp = in@{INSIZE-1}@@p;
@for(i, 0, DOUTDIM, index_t doutshape@i = in@{INSIZE-1}@@shape@i;)
@for(i, DOUTDIM-2, -1, -1, auto doutstride@i = in@{INSIZE-1}@@stride@i;)
,)
)
@alias(out, out0)
@HEADER
@ -151,14 +200,17 @@ namespace jittor {
void CodeOp::jit_run() {
// define inputs
@for(i, 0, INSIZE,
auto in@i = in[@i];
auto* __restrict__ in@i@@p = in[@i]->ptr<Tin@i>();
@for(j, 0, INDIM@i, index_t in@i@@shape@j = in[@i]->shape[@j];)
@for(i, 0, IN_SIZE,
auto in@i = _inputs[@i];
auto* __restrict__ in@i@@_p = _inputs[@i]->ptr<in@i@@_type>();
@for(j, 0, in@i@@_dim, index_t in@i@@_shape@j = _inputs[@i]->shape[@j];)
)
// define outputs
@for(i, 0, OUT_SIZE,
auto out@i = _outputs[@i];
auto* __restrict__ out@i@@_p = _outputs[@i]->ptr<out@i@@_type>();
@for(j, 0, out@i@@_dim, index_t out@i@@_shape@j = _outputs[@i]->shape[@j];)
)
// define out
auto* __restrict__ outp = out->ptr<Tout>();
@for(i, 0, OUTDIM, index_t outshape@i = out->shape[@i];)
@PRECALC

View File

@ -9,8 +9,8 @@
namespace jittor {
struct CodeOp : Op {
vector<Var*> in;
Var* out;
vector<Var*> _inputs;
vector<Var*> _outputs;
string cpu_src;
vector<string> cpu_grad_src;
string cpu_header;
@ -29,16 +29,19 @@ struct CodeOp : Op {
@param[in] inputs A list of input jittor Vars
@param[in] cpu_src cpu source code string, builtin values:
* in{x}, in{x}shape{y}, in{x}stride{y}, Tin{x}, in{x}p, @in0(...)
* out, outshape{y}, outstride{y}, Tout, outp, @out(...)
* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
@param[in] cpu_grad_src A list of strings,
cpu source code for the gradient, representing the gradient
for each input, builtin values:
* in{x}, in{x}shape{y}, in{x}stride{y}, Tin{x}, in{x}p, @in0(...)
* out, outshape{y}, outstride{y}, Tout, outp, @out(...)
* pout, poutshape{y}, poutstride{y}, Tpout, poutp, @pout(...)
* dout, doutshape{y}, doutstride{y}, Tdout, doutp, @dout(...)
* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
* pout{x}, pout{x}_shape{y}, pout{x}_stride{y}, pout{x}_type, pout{x}_p, @pout{x}(...)
* pout, pout_shape{y}, pout_stride{y}, pout_type, pout_p, @pout(...)
* dout, dout_shape{y}, dout_stride{y}, dout_type, dout_p, @dout(...)
@param[in] cpu_header cpu header code string.
@ -47,25 +50,96 @@ struct CodeOp : Op {
@param[in] cuda_grad_src A list of strings.
@param[in] cuda_header cuda header code string.
----------------
Example
Example-1:
```
a = jt.random([10])
b = jt.code(a.shape, a.dtype, [a],
cpu_src='''
for (int i=0; i<in0shape0; i++)
for (int i=0; i<in0_shape0; i++)
@out(i) = @in0(i)*@in0(i)*2;
''',
cpu_grad_src = ['''
for (int i=0; i<in0shape0; i++)
for (int i=0; i<in0_shape0; i++)
@out(i) = @dout(i)*@in0(i)*4;
'''])
```
Example2(CUDA):
Example-2:
```
a = jt.array([3,2,1])
b = jt.code(a.shape, a.dtype, [a],
cpu_header="""
#include <algorithm>
@alias(a, in0)
@alias(b, out)
"""",
cpu_src="""
for (int i=0; i<a_shape0; i++)
@b(i) = @a(i);
std::sort(&@b(0), &@b(in0_shape0));
"""
)
assert (b.data==[1,2,3]).all()
```
Example-3:
This example shows how to set multiple outputs in code op.
```
a = jt.array([3,2,1])
b,c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a],
cpu_header="""
#include <iostream>
using namespace std;
""",
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
@b(0) = @c(0) = @a(0);
for (int i=0; i<a_shape0; i++) {
@b(0) = std::min(@b(0), @a(i));
@c(0) = std::max(@c(0), @a(i));
}
cout << "min:" << @b(0) << " max:" << @c(0) << endl;
"""
)
assert b.data == 1, b
assert c.data == 3, c
```
Example-4:
This example shows how to use dynamic shape of jittor variables.
```
a = jt.array([5,-4,3,-2,1])
# negative shape for max size of vary dimension
b,c = jt.code([(-5,), (-5,)], [a.dtype, a.dtype], [a],
cpu_src="""
@alias(a, in0)
@alias(b, out0)
@alias(c, out1)
int num_b=0, num_c=0;
for (int i=0; i<a_shape0; i++) {
if (@a(i)>0)
@b(num_b++) = @a(i);
else
@c(num_c++) = @a(i);
}
b->set_shape({num_b});
c->set_shape({num_c});
"""
)
assert (b.data == [5,3,1]).all()
assert (c.data == [-4,-2]).all()
```
CUDA Example-1:
This example shows how to use CUDA in code op.
```
a = jt.random([100000])
b = jt.random([100000])
@ -75,33 +149,34 @@ struct CodeOp : Op {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
for (; i<in0_shape0; i+=stride)
@out(i) = @in0(i)*@in1(i);
}
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in1(i);
}
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in0(i);
}
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
'''])
```
Example3(CUDA):
CUDA Example-2:
This example shows how to use multi-dimensional data with CUDA.
```
a = jt.random((100,100))
b = jt.random((100,100))
@ -109,8 +184,8 @@ struct CodeOp : Op {
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
kernel1<<<32, 32>>>(@ARGS);
@ -118,16 +193,16 @@ struct CodeOp : Op {
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
kernel2<<<32, 32>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
kernel3<<<32, 32>>>(@ARGS);
@ -136,6 +211,9 @@ struct CodeOp : Op {
*/
CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
// @attrs(multiple_outputs)
CodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
const char* name() const override { return "code"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
DECLARE_jit_run;

View File

@ -57,6 +57,14 @@ vector<string> split(const string& s, const string& sep, int max_split) {
return ret;
}
string strip(const string& s) {
int i=0;
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n')) i++;
int j = s.size();
while (j>i && (s[j-1]==' ' || s[j-1]=='\t' || s[j-1]=='\n')) j--;
return s.substr(i,j-i);
}
void KernelIR::del_scope() {
if (father && (type=="define" || type=="func")) {
father->scope[attrs["lvalue"]].remove(this);