mirror of https://github.com/Jittor/Jittor
code op增加更多特性 :
* 可以有多个输出 * 允许输出动态大小的var * code op内部可以写注释了 * code op可以使用别名了,通过@alias为inputs和outputs增加别名
This commit is contained in:
parent
6c08f24cff
commit
ab533a5f45
|
@ -60,7 +60,8 @@ for mdname in all_md:
|
|||
"metadata": {
|
||||
},
|
||||
}
|
||||
ipynb_name = mdname[:-2]+"ipynb"
|
||||
ipynb_name = os.path.basename(mdname[:-2])+"ipynb"
|
||||
ipynb_name = os.path.join(notebook_dir, ipynb_name)
|
||||
print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name)
|
||||
with open(os.path.join(notebook_dir, ipynb_name), "w") as f:
|
||||
with open(ipynb_name, "w") as f:
|
||||
f.write(json.dumps(ipynb))
|
|
@ -17,36 +17,36 @@ def argmax_pool(x, size, stride, padding=0):
|
|||
|
||||
y = jt.code(y_shape, x.dtype, [x],
|
||||
cpu_src=f'''
|
||||
for (int i=0; i<outshape0; i++)
|
||||
for (int j=0; j<outshape1; j++)
|
||||
for (int k=0; k<outshape2; k++)
|
||||
for (int l=0; l<outshape3; l++) {{
|
||||
for (int i=0; i<out_shape0; i++)
|
||||
for (int j=0; j<out_shape1; j++)
|
||||
for (int k=0; k<out_shape2; k++)
|
||||
for (int l=0; l<out_shape3; l++) {{
|
||||
int kx=k*{stride}+{size}/2-{padding};
|
||||
int ky=l*{stride}+{size}/2-{padding};
|
||||
@out(i,j,k,l) = @in0(i,j,kx,ky);
|
||||
for (int p=kx-{size}/2;p<=kx+{size}/2;p++)
|
||||
for (int q=ky-{size}/2;q<=ky+{size}/2;q++)
|
||||
if (p>=0 && q>=0 && p<in0shape2 && q<in0shape3)
|
||||
if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
|
||||
if (@out(i,j,k,l) < @in0(i,j,p,q))
|
||||
@out(i,j,k,l) = @in0(i,j,p,q);
|
||||
}}
|
||||
''',
|
||||
cpu_grad_src = [f'''
|
||||
for (int i=0; i<outshape0; i++)
|
||||
for (int j=0; j<outshape1; j++)
|
||||
for (int k=0; k<outshape2; k++)
|
||||
for (int l=0; l<outshape3; l++) @out(i,j,k,l) = 0;
|
||||
for (int i=0; i<out_shape0; i++)
|
||||
for (int j=0; j<out_shape1; j++)
|
||||
for (int k=0; k<out_shape2; k++)
|
||||
for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;
|
||||
|
||||
for (int i=0; i<poutshape0; i++)
|
||||
for (int j=0; j<poutshape1; j++)
|
||||
for (int k=0; k<poutshape2; k++)
|
||||
for (int l=0; l<poutshape3; l++) {{
|
||||
for (int i=0; i<pout_shape0; i++)
|
||||
for (int j=0; j<pout_shape1; j++)
|
||||
for (int k=0; k<pout_shape2; k++)
|
||||
for (int l=0; l<pout_shape3; l++) {{
|
||||
int kx=k*{stride}+{size}/2-{padding};
|
||||
int ky=l*{stride}+{size}/2-{padding};
|
||||
int bo=1;
|
||||
for (int p=kx-{size}/2;p<=kx+{size}/2 && bo;p++)
|
||||
for (int q=ky-{size}/2;q<=ky+{size}/2 && bo;q++)
|
||||
if (p>=0 && q>=0 && p<in0shape2 && q<in0shape3)
|
||||
if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
|
||||
if (@pout(i,j,k,l) == @in0(i,j,p,q)) {{
|
||||
@out(i,j,p,q) += @dout(i,j,k,l);
|
||||
bo=0;
|
||||
|
|
|
@ -48,12 +48,12 @@ class Pool(Module):
|
|||
int s2 = blockDim.y * gridDim.x;
|
||||
int i1 = blockIdx.y;
|
||||
int i0 = blockIdx.z;
|
||||
for (int i3 = p3; i3 < outshape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < outshape2; i2 += s2) {{
|
||||
for (int i3 = p3; i3 < out_shape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < out_shape2; i2 += s2) {{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0shape2);
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
@out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
|
||||
|
@ -62,11 +62,11 @@ class Pool(Module):
|
|||
@out(i0, i1, i2, i3) = {op}(@out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
||||
}}
|
||||
}}
|
||||
int tx = min(1024, outshape3);
|
||||
int ty = min(1024 / tx, outshape2);
|
||||
int bx = (outshape2 - 1) / ty + 1;
|
||||
int by = outshape1;
|
||||
int bz = outshape0;
|
||||
int tx = min(1024, out_shape3);
|
||||
int ty = min(1024 / tx, out_shape2);
|
||||
int bx = (out_shape2 - 1) / ty + 1;
|
||||
int by = out_shape1;
|
||||
int bz = out_shape0;
|
||||
dim3 s1(bx, by, bz);
|
||||
dim3 s2(tx, ty);
|
||||
kernel1<<<s1, s2>>>(@ARGS);
|
||||
|
@ -80,12 +80,12 @@ class Pool(Module):
|
|||
int s2 = blockDim.y * gridDim.x;
|
||||
int i1 = blockIdx.y;
|
||||
int i0 = blockIdx.z;
|
||||
for (int i3 = p3; i3 < poutshape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < poutshape2; i2 += s2) {{
|
||||
for (int i3 = p3; i3 < pout_shape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < pout_shape2; i2 += s2) {{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0shape2);
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
int bo=1;
|
||||
|
@ -98,25 +98,25 @@ class Pool(Module):
|
|||
}}
|
||||
}}
|
||||
}}
|
||||
cudaMemsetAsync(outp, 0, out->size);
|
||||
int tx = min(1024, poutshape3);
|
||||
int ty = min(1024 / tx, poutshape2);
|
||||
int bx = (poutshape2 - 1) / ty + 1;
|
||||
int by = poutshape1;
|
||||
int bz = poutshape0;
|
||||
cudaMemsetAsync(out_p, 0, out->size);
|
||||
int tx = min(1024, pout_shape3);
|
||||
int ty = min(1024 / tx, pout_shape2);
|
||||
int bx = (pout_shape2 - 1) / ty + 1;
|
||||
int by = pout_shape1;
|
||||
int bz = pout_shape0;
|
||||
dim3 s1_(bx, by, bz);
|
||||
dim3 s2_(tx, ty);
|
||||
kernel3<<<s1_, s2_>>>(@ARGS);
|
||||
'''],
|
||||
cpu_src=f'''
|
||||
for (int i0=0; i0<outshape0; i0++)
|
||||
for (int i1=0; i1<outshape1; i1++)
|
||||
for (int i2=0; i2<outshape2; i2++)
|
||||
for (int i3=0; i3<outshape3; i3++) {{
|
||||
for (int i0=0; i0<out_shape0; i0++)
|
||||
for (int i1=0; i1<out_shape1; i1++)
|
||||
for (int i2=0; i2<out_shape2; i2++)
|
||||
for (int i3=0; i3<out_shape3; i3++) {{
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2_ = std::min(k2 + {self.kernel_size}, in0shape2);
|
||||
int k3_ = std::min(k3 + {self.kernel_size}, in0shape3);
|
||||
int k2_ = std::min(k2 + {self.kernel_size}, in0_shape2);
|
||||
int k3_ = std::min(k3 + {self.kernel_size}, in0_shape3);
|
||||
k2 = std::max(0, k2);
|
||||
k3 = std::max(0, k3);
|
||||
@out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
|
||||
|
@ -126,19 +126,19 @@ class Pool(Module):
|
|||
}}
|
||||
''',
|
||||
cpu_grad_src = [f'''
|
||||
for (int i=0; i<outshape0; i++)
|
||||
for (int j=0; j<outshape1; j++)
|
||||
for (int k=0; k<outshape2; k++)
|
||||
for (int l=0; l<outshape3; l++) @out(i,j,k,l) = 0;
|
||||
for (int i=0; i<out_shape0; i++)
|
||||
for (int j=0; j<out_shape1; j++)
|
||||
for (int k=0; k<out_shape2; k++)
|
||||
for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;
|
||||
|
||||
for (int i0=0; i0<poutshape0; i0++)
|
||||
for (int i1=0; i1<poutshape1; i1++)
|
||||
for (int i2=0; i2<poutshape2; i2++)
|
||||
for (int i3=0; i3<poutshape3; i3++) {{
|
||||
for (int i0=0; i0<pout_shape0; i0++)
|
||||
for (int i1=0; i1<pout_shape1; i1++)
|
||||
for (int i2=0; i2<pout_shape2; i2++)
|
||||
for (int i3=0; i3<pout_shape3; i3++) {{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = std::min(k3 + {self.kernel_size}, in0shape3);
|
||||
int k2_ = std::min(k2 + {self.kernel_size}, in0shape2);
|
||||
int k3_ = std::min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = std::min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = std::max(0, k3);
|
||||
k2 = std::max(0, k2);
|
||||
int bo=1;
|
||||
|
|
|
@ -106,7 +106,7 @@ class TestArray(unittest.TestCase):
|
|||
with jt.flag_scope(use_cuda=1):
|
||||
a = jt.array(np.float32([1,2,3]))
|
||||
b = jt.code(a.shape, a.dtype, [a], cpu_src="""
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
for (int i=0; i<in0_shape0; i++)
|
||||
@out(i) = @in0(i)*@in0(i)*2;
|
||||
""")
|
||||
assert (b.data==[2,8,18]).all()
|
||||
|
|
|
@ -12,11 +12,11 @@ class TestCodeOp(unittest.TestCase):
|
|||
a = jt.random([10])
|
||||
b = jt.code(a.shape, a.dtype, [a],
|
||||
cpu_src='''
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
for (int i=0; i<in0_shape0; i++)
|
||||
@out(i) = @in0(i)*@in0(i)*2;
|
||||
''',
|
||||
cpu_grad_src = ['''
|
||||
for (int i=0; i<in0shape0; i++) {
|
||||
for (int i=0; i<in0_shape0; i++) {
|
||||
@out(i) = @dout(i)*@in0(i)*4;
|
||||
}
|
||||
'''])
|
||||
|
@ -32,15 +32,15 @@ class TestCodeOp(unittest.TestCase):
|
|||
b = jt.random([10])
|
||||
c = jt.code(a.shape, a.dtype, [a,b],
|
||||
cpu_src='''
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
for (int i=0; i<in0_shape0; i++)
|
||||
@out(i) = @in0(i)*@in1(i);
|
||||
''',
|
||||
cpu_grad_src = ['''
|
||||
for (int i=0; i<in0shape0; i++) {
|
||||
for (int i=0; i<in0_shape0; i++) {
|
||||
@out(i) = @dout(i)*@in1(i);
|
||||
}
|
||||
''', '''
|
||||
for (int i=0; i<in0shape0; i++) {
|
||||
for (int i=0; i<in0_shape0; i++) {
|
||||
@out(i) = @dout(i)*@in0(i);
|
||||
}
|
||||
'''])
|
||||
|
@ -52,11 +52,102 @@ class TestCodeOp(unittest.TestCase):
|
|||
def test_header(self):
|
||||
a = jt.array([3,2,1])
|
||||
b = jt.code(a.shape, a.dtype, [a],
|
||||
cpu_header='#include <algorithm>',
|
||||
cpu_header="""
|
||||
#include <algorithm>
|
||||
@alias(a, in0)
|
||||
@alias(b, out)
|
||||
""",
|
||||
cpu_src="""
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
@out(i) = @in0(i);
|
||||
std::sort(&@out(0), &@out(in0shape0));
|
||||
for (int i=0; i<a_shape0; i++)
|
||||
@b(i) = @a(i);
|
||||
std::sort(&@b(0), &@b(in0_shape0));
|
||||
"""
|
||||
)
|
||||
assert (b.data==[1,2,3]).all()
|
||||
|
||||
def test_multi_output(self):
    # A single code op produces two outputs whose shapes and dtypes differ:
    # out0 is float32 with 2 elements, out1 is float64 with 4 elements.
    src = jt.array([3,2,1])
    out_shapes = [[2],[4]]
    out_dtypes = ["float32", "float64"]
    # Each output is only written where the index is inside its own extent.
    kernel = """
        @alias(a, in0)
        @alias(b, out0)
        @alias(c, out1)
        for (int i=0; i<a_shape0; i++) {
            if (i<b_shape0) @b(i) = @a(i);
            if (i<c_shape0) @c(i) = @a(i);
        }
    """
    first, second = jt.code(out_shapes, out_dtypes, [src], cpu_src=kernel)
    assert first.shape == [2]
    assert second.shape == [4]
    assert first.dtype == "float32"
    assert second.dtype == "float64"
    assert (first.data == [3,2]).all()
    # The input has only three elements, so the kernel never writes the fourth
    # slot of the second output; compare just the written prefix.
    assert (second.data[:3] == [3,2,1]).all()
|
||||
|
||||
def test_multi_output2(self):
    # Two one-element outputs hold the minimum and maximum of the input;
    # the header is used to pull <iostream> into the generated code.
    src = jt.array([3,2,1])
    header = """
        #include <iostream>
        using namespace std;
    """
    kernel = """
        @alias(a, in0)
        @alias(b, out0)
        @alias(c, out1)
        @b(0) = @c(0) = @a(0);
        for (int i=0; i<a_shape0; i++) {
            @b(0) = std::min(@b(0), @a(i));
            @c(0) = std::max(@c(0), @a(i));
        }
        cout << "min:" << @b(0) << " max:" << @c(0) << endl;
    """
    smallest, largest = jt.code(
        [(1,), (1,)], [src.dtype, src.dtype], [src],
        cpu_header=header,
        cpu_src=kernel,
    )
    assert smallest.data == 1, smallest
    assert largest.data == 3, largest
|
||||
|
||||
def test_vary_shape(self):
    # Split the input into positive and non-positive values; how many land in
    # each output is only known after the kernel has run.
    src = jt.array([5,-4,3,-2,1])

    # A negative extent declares a varying dimension; its absolute value is
    # the maximum size. The kernel reports the real size through set_shape.
    kernel = """
        @alias(a, in0)
        @alias(b, out0)
        @alias(c, out1)
        int num_b=0, num_c=0;
        for (int i=0; i<a_shape0; i++) {
            if (@a(i)>0)
                @b(num_b++) = @a(i);
            else
                @c(num_c++) = @a(i);
        }
        b->set_shape({num_b});
        c->set_shape({num_c});
    """
    positives, others = jt.code(
        [(-5,), (-5,)], [src.dtype, src.dtype], [src],
        cpu_src=kernel,
    )
    assert (positives.data == [5,3,1]).all()
    assert (others.data == [-4,-2]).all()
|
||||
|
||||
def test_comment(self):
    # Line and block comments inside both the header and the kernel source
    # must not break the code op; the kernel copies the input and sorts it.
    src = jt.array([3,2,1])
    header = '''
        #include <algorithm>
        // asd
        /* asd
        */
    '''
    kernel = """
        // test comment
        /*
        multi line
        */
        @alias(a, in0)
        for (int i=0; i<a_shape0; i++)
            @out(i) = @a(i);
        std::sort(&@out(0), &@out(a_shape0));
    """
    result = jt.code(src.shape, src.dtype, [src],
        cpu_header=header,
        cpu_src=kernel,
    )
    assert (result.data==[1,2,3]).all()
|
||||
|
@ -72,29 +163,29 @@ class TestCodeOp(unittest.TestCase):
|
|||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
for (; i<in0_shape0; i+=stride)
|
||||
@out(i) = @in0(i)*@in1(i);
|
||||
}
|
||||
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
for (; i<in0_shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in1(i);
|
||||
}
|
||||
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''', '''
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
for (; i<in0_shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in0(i);
|
||||
}
|
||||
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
'''])
|
||||
da, db = jt.grad(c, [a, b])
|
||||
assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
|
||||
|
@ -110,8 +201,8 @@ class TestCodeOp(unittest.TestCase):
|
|||
cuda_src='''
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
|
||||
@out(i,j) = @in0(i,j)*@in1(i,j);
|
||||
}
|
||||
kernel1<<<32, 32>>>(@ARGS);
|
||||
|
@ -119,8 +210,8 @@ class TestCodeOp(unittest.TestCase):
|
|||
cuda_grad_src = ['''
|
||||
__global__ static void kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in1(i,j);
|
||||
}
|
||||
kernel<<<32, 32>>>(@ARGS);
|
||||
|
@ -128,8 +219,8 @@ class TestCodeOp(unittest.TestCase):
|
|||
__global__ static void kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@pout(0,0);
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in0(i,j);
|
||||
}
|
||||
kernel<<<32, 32>>>(@ARGS);
|
||||
|
|
|
@ -11,6 +11,7 @@ import jittor as jt
|
|||
import numpy as np
|
||||
import jittor.models as jtmodels
|
||||
|
||||
skip_this_test = False
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
|
@ -18,10 +19,7 @@ try:
|
|||
from torch import nn
|
||||
except:
|
||||
torch = None
|
||||
|
||||
|
||||
skip_this_test = False
|
||||
|
||||
skip_this_test = True
|
||||
|
||||
@unittest.skipIf(skip_this_test, "skip_this_test")
|
||||
class test_models(unittest.TestCase):
|
||||
|
|
|
@ -6,13 +6,16 @@
|
|||
import unittest, os
|
||||
import jittor as jt
|
||||
from jittor import LOG
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
dirname = os.path.join(jt.flags.jittor_path, "notebook")
|
||||
notebook_dir = os.path.join(jt.flags.cache_path, "notebook")
|
||||
notebook_dir = os.path.join(str(Path.home()), ".cache","jittor","notebook")
|
||||
tests = []
|
||||
for mdname in os.listdir(dirname):
|
||||
if not mdname.endswith(".src.md"): continue
|
||||
# temporary disable model_test
|
||||
if "LSGAN" in mdname: continue
|
||||
tests.append(mdname[:-3])
|
||||
|
||||
try:
|
||||
|
@ -27,7 +30,9 @@ def test(name):
|
|||
jt.compiler.run_cmd("ipython "+ipynb_name)
|
||||
|
||||
def init():
|
||||
jt.compiler.run_cmd("python3 "+os.path.join(dirname, "md_to_ipynb.py"))
|
||||
cmd = sys.executable+" "+os.path.join(dirname, "md_to_ipynb.py")
|
||||
LOG.i("init notebooks:", cmd)
|
||||
jt.compiler.run_cmd(cmd)
|
||||
|
||||
src = """class TestNodebooks(unittest.TestCase):
|
||||
@classmethod
|
||||
|
|
|
@ -16,7 +16,9 @@ def check(op, *args):
|
|||
x = convert(x)
|
||||
y = convert(y)
|
||||
# str match nan and inf
|
||||
assert x.dtype == y.dtype and x.shape == y.shape and str(x)==str(y), f"{x}\n{y}"
|
||||
assert x.dtype == y.dtype and x.shape == y.shape
|
||||
for a,b in zip(x.flatten(), y.flatten()):
|
||||
assert str(a)[:5] == str(b)[:5], (a,b)
|
||||
|
||||
class TestUnaryOp(unittest.TestCase):
|
||||
def test_unary_op(self):
|
||||
|
@ -54,7 +56,7 @@ class TestUnaryOp(unittest.TestCase):
|
|||
ja = jt.array(b)
|
||||
jb = eval(f"jt.{op}(ja)")
|
||||
jda = jt.grad(jb, ja)
|
||||
assert (np.abs(jda.data-da)<1e-5).all(), (jda.data,da,op)
|
||||
assert (np.allclose(jda.data, da)), (jda.data,da,op)
|
||||
|
||||
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
|
||||
pass
|
||||
|
|
|
@ -24,4 +24,6 @@ bool endswith(const string& a, const string& b);
|
|||
// max_split: maximun split number(include)
|
||||
vector<string> split(const string& s, const string& sep, int max_split=0);
|
||||
|
||||
string strip(const string& s);
|
||||
|
||||
} // jittor
|
|
@ -125,13 +125,18 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
|
|||
presum++;
|
||||
k++;
|
||||
}
|
||||
ASSERT(presum==0) << "Jit error: braces are not matched.";
|
||||
CHECK(presum==0) << "Jit error: braces are not matched.";
|
||||
new_expr += S(eval(expr.substr(j+1, k-j-2), vars));
|
||||
i = k-1;
|
||||
continue;
|
||||
} else {
|
||||
if (expr[j] == '@') {
|
||||
// syntax @@
|
||||
i = j;
|
||||
continue;
|
||||
}
|
||||
// syntax: @x
|
||||
ASSERT(isvar(expr[j]));
|
||||
CHECK(isvar(expr[j])) << expr[j] << "is not var";
|
||||
size_t k=j+1;
|
||||
while (k<expr.size() && isvar(expr[k])) k++;
|
||||
if (k<expr.size() && expr[k]=='(') {
|
||||
|
@ -200,7 +205,9 @@ void load_macros(const string& src, unordered_map<string,string>& macros) {
|
|||
auto r=k;
|
||||
while (r<l && src[r] != '(') r++;
|
||||
auto body = q>p ? src.substr(p,q-p) : "";
|
||||
body = (r<l?src.substr(r,l-r):"()") + body;
|
||||
auto args = "<"+ (r+1<l?src.substr(r+1,l-r-2):"") + ">";
|
||||
// header <args>body
|
||||
body = args + body;
|
||||
auto header = src.substr(k,r-k);
|
||||
LOGvvvv << "header:" << header << "body:" << body;
|
||||
macros[header] = body;
|
||||
|
@ -211,9 +218,13 @@ void load_macros(const string& src, unordered_map<string,string>& macros) {
|
|||
|
||||
void expand_macro(const string& macro, const vector<string>& args, string& new_src) {
|
||||
LOGvvvv << "expand_macro" << macro << "args:" << args;
|
||||
auto i = macro.find(")");
|
||||
if (macro.size() == 0 || macro[0] != '<') {
|
||||
new_src += macro;
|
||||
return;
|
||||
}
|
||||
auto i = macro.find(">");
|
||||
ASSERT(i != string::npos);
|
||||
// (a1, a2, ...)body
|
||||
// <a1, a2, ...>body
|
||||
// j k i
|
||||
unordered_map<string, int> args_map;
|
||||
for (uint j=1, l=0; j<i; l++) {
|
||||
|
@ -373,7 +384,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
presum++;
|
||||
k++;
|
||||
}
|
||||
ASSERT(presum==0) << "Jit error: braces are not matched.";
|
||||
CHECK(presum==0) << "Jit error: braces are not matched.";
|
||||
new_src += S(OpCompiler::eval(src.substr(j+1, k-j-2), defs));
|
||||
i = k-1;
|
||||
continue;
|
||||
|
@ -389,7 +400,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
presum++;
|
||||
k++;
|
||||
}
|
||||
ASSERT(presum==0) << "Jit error: braces are not matched.";
|
||||
CHECK(presum==0) << "Jit error: braces are not matched.";
|
||||
new_src += precompile(defs, src.substr(j+1, k-j-2), macros);
|
||||
i = k-1;
|
||||
continue;
|
||||
|
@ -414,7 +425,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
comma.push_back(l);
|
||||
l++;
|
||||
}
|
||||
ASSERT(presum==0) << "Jit error: braces are not matched.";
|
||||
CHECK(presum==0) << "Jit error: braces are not matched.";
|
||||
comma.push_back(l-1);
|
||||
for (uint i=0; i+1<comma.size(); i++)
|
||||
args.push_back(src.substr(comma[i]+1, comma[i+1]-comma[i]-1));
|
||||
|
@ -516,7 +527,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
if (expr == "strcmp") {
|
||||
// syntax: @strcmp(s1,s2)
|
||||
// ij k l
|
||||
ASSERT(args.size()==2u)
|
||||
CHECK(args.size()==2u)
|
||||
<< "Jit error: strcmp wrong arguments.";
|
||||
auto s1 = precompile(defs, args[0], macros);
|
||||
auto s2 = precompile(defs, args[1], macros);
|
||||
|
@ -526,27 +537,76 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (expr == "alias") {
|
||||
// syntax: @alias(s1,s2)
|
||||
// ij k l
|
||||
|
||||
// alias(a,b)
|
||||
// a->b
|
||||
// a_type->b_type
|
||||
// a_dim -> b_dim
|
||||
// for i in a_dim:
|
||||
// a_shapei -> b_shapei
|
||||
// a_stridei -> b_stridei
|
||||
CHECK(args.size()==2u)
|
||||
<< "Jit error: alias wrong arguments.";
|
||||
auto key = strip(precompile(defs, args[0], macros));
|
||||
auto value = strip(precompile(defs, args[1], macros));
|
||||
CHECK(defs.count(value+"_dim")) << '"' >> value >> '"' << "not exsit";
|
||||
int dim = std::stoi(defs.at(value+"_dim"));
|
||||
vector<string> keys = {"", "p", "dim", "type"};
|
||||
for (int i=0; i<dim; i++) {
|
||||
keys.push_back("stride"+S(i));
|
||||
keys.push_back("shape"+S(i));
|
||||
}
|
||||
new_src += '\n';
|
||||
for (auto& s : keys) {
|
||||
string from = value+"_"+s;
|
||||
string to = key+"_"+s;
|
||||
if (!s.size()) {
|
||||
from = value;
|
||||
to = key;
|
||||
}
|
||||
if (defs.count(from))
|
||||
from = defs.at(from);
|
||||
else if (macros.count(from))
|
||||
from = macros.at(from);
|
||||
defs[to] = from;
|
||||
macros[to] = from;
|
||||
new_src += "#define "+to+" "+from+"\n";
|
||||
}
|
||||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (args.size()) {
|
||||
// syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
|
||||
// syntax: @e0(i0,i1,...,in) -> e0_p[i0*e0_stride0+i1*e0_stride1+...]
|
||||
ASSERT(expr.size());
|
||||
|
||||
int nid=(int)expr.size();
|
||||
while (nid && isdigit(expr[nid-1])) nid--;
|
||||
// xyz123 ---> prefix: xxx; suffix: 123
|
||||
string prefix = expr.substr(0, nid);
|
||||
string suffix = expr.substr(nid);
|
||||
string up_prefix = prefix;
|
||||
for (auto& c : up_prefix)
|
||||
if (c>='a' && c<='z') c = c-'a'+'A';
|
||||
string dim = up_prefix + "DIM" + suffix;
|
||||
if (prefix == "e") prefix = "extras";
|
||||
ASSERT(defs.count(dim)) << dim;
|
||||
ASSERTop(defs.at(dim),==,S(args.size()));
|
||||
expr = prefix + suffix; // e0 ->extras0
|
||||
string dim;
|
||||
if (expr == "x" && defs.count("XDIM")) {
|
||||
dim = "XDIM";
|
||||
prefix = "x";
|
||||
} else
|
||||
if (prefix == "e") {
|
||||
// TODO: unify interface
|
||||
prefix = "extras" + suffix;
|
||||
dim = "EDIM" + suffix;
|
||||
} else {
|
||||
prefix = expr+"_";
|
||||
dim = prefix + "dim";
|
||||
}
|
||||
CHECK(macros.count(dim)) << expr << "not exsit" << macros;
|
||||
CHECKop(macros.at(dim),==,S(args.size())) << expr << "dimension not matched";
|
||||
std::stringstream ss;
|
||||
ss << expr << "p[";
|
||||
ss << prefix << "p[";
|
||||
for (uint ii=0; ii<args.size(); ii++) {
|
||||
string arg = precompile(defs, args[ii], macros);
|
||||
if (ii) ss << "+";
|
||||
ss << '(' << arg << ")*" << expr << "stride" << ii;
|
||||
ss << '(' << arg << ")*" << prefix << "stride" << ii;
|
||||
}
|
||||
ss << ']';
|
||||
new_src += ss.str();
|
||||
|
@ -568,10 +628,10 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
} else
|
||||
new_src += src[i];
|
||||
} catch (std::exception& e) {
|
||||
uint il = i, ir = i;
|
||||
while (il && src[il] != '\n') il--;
|
||||
while (ir<src.size() && src[ir] != '\n') ir++;
|
||||
string this_line = src.substr(il+1, ir-il-1);
|
||||
int il = i, ir = i;
|
||||
while (il>0 && src[il-1] != '\n') il--;
|
||||
while (ir+1<src.size() && src[ir+1] != '\n') ir++;
|
||||
string this_line = src.substr(il, ir-il+1);
|
||||
LOGf << e.what() >> "\nJit compiler error:\n" >> this_line;
|
||||
}
|
||||
}
|
||||
|
@ -579,7 +639,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
}
|
||||
|
||||
string OpCompiler::precompile(const unordered_map<string,string>& defs, const string& src) {
|
||||
unordered_map<string, string> macros;
|
||||
unordered_map<string, string> macros = defs;
|
||||
return jittor::precompile(defs, src, macros);
|
||||
}
|
||||
|
||||
|
@ -600,6 +660,10 @@ string OpCompiler::get_jit_src(Op* op) {
|
|||
string after_include_src = "";
|
||||
auto jit_define = op->get_jit_define();
|
||||
for (auto &t : jit_define) {
|
||||
// don't add CODE in define
|
||||
// this allowed comment exsit in CODE
|
||||
if (t.first == "CODE" || t.first == "HEADER")
|
||||
continue;
|
||||
string src = "#define " + t.first + " ";
|
||||
for (char c : t.second) {
|
||||
if (c=='\n') src += '\\';
|
||||
|
@ -798,7 +862,7 @@ string OpCompiler::__get_fused_src(
|
|||
presum++;
|
||||
k++;
|
||||
}
|
||||
ASSERT(presum==0) << "Jit error: braces are not matched.";
|
||||
CHECK(presum==0) << "Jit error: braces are not matched.";
|
||||
for (;j < k-2; j++) {
|
||||
if (isvar(src[j])) {
|
||||
uint l=j;
|
||||
|
|
|
@ -16,50 +16,100 @@ namespace jittor {
|
|||
static auto make_code = get_op_info("code")
|
||||
.get_constructor<VarPtr, NanoVector, NanoString, vector<Var*>&&, string&&, vector<string>&&, string&&, string&&, vector<string>&&, string&&>();
|
||||
|
||||
// Validate the declared shape of a variable-size ("vary shape") output.
//
// The CodeOp constructors call this for an output whose element count is
// negative, i.e. whose shape contains a negative extent. Such a shape is only
// accepted when:
//   * it has at least one dimension, and
//   * the first dimension is the negative (varying) one while every other
//     dimension is non-negative.
// Any violation aborts through ASSERT with a message that includes the shape.
static inline void check_vary_shape(NanoVector v) {
    ASSERT(v.size()) << "Vary shape should not be zero dimension";
    for (int i=0; i<v.size(); i++)
        // XOR: exactly one of "this is dimension 0" and "this extent is
        // non-negative" must hold. The extent tested has to be v[i]; testing
        // v[0] on every iteration made the conditions for i==0 and i>0
        // contradict each other, so any shape of rank >= 2 was rejected and
        // the later dimensions were never actually inspected.
        ASSERT((i == 0) ^ (v[i] >= 0))
            << "Vary shape should only occur in the first dimension:" << v;
}
|
||||
|
||||
CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
|
||||
string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
|
||||
string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
|
||||
: in(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
|
||||
: _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
|
||||
cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
|
||||
{
|
||||
flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
|
||||
flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
|
||||
out = create_output(shape, dtype);
|
||||
ASSERTop(inputs.size(),<=,10);
|
||||
_outputs.push_back(create_output(shape, dtype));
|
||||
CHECKop(_inputs.size(),<=,10);
|
||||
|
||||
if (_outputs[0]->num < 0) {
|
||||
flags.set(NodeFlags::_vary_shape);
|
||||
check_vary_shape(_outputs[0]->shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Multi-output constructor: creates one output per (shape, dtype) pair.
//
// shapes / dtypes  parallel lists describing the outputs; they must have the
//                  same length, between 1 and 10 entries.
// inputs           input vars, at most 10.
// cpu_* / cuda_*   source, per-input gradient sources and header for each
//                  backend; a backend flag is set only when its source is
//                  non-empty.
// An output with a negative element count marks the op as vary-shape and has
// its shape validated by check_vary_shape.
CodeOp::CodeOp(
    vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs,
    string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
    : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
{
    // The parameters were moved from above, so query the members.
    const bool has_cpu_code = !this->cpu_src.empty();
    const bool has_cuda_code = !this->cuda_src.empty();
    flags.set(NodeFlags::_cpu, has_cpu_code);
    flags.set(NodeFlags::_cuda, has_cuda_code);

    CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same";
    _outputs.resize(shapes.size());
    CHECKop(_inputs.size(),<=,10);
    CHECKop(_outputs.size(),<=,10);
    CHECKop(_outputs.size(),>,0);

    for (int idx=0; idx<shapes.size(); idx++) {
        auto& output = _outputs[idx];
        output = create_output(shapes[idx], dtypes[idx]);
        // Only outputs with a negative element count need the vary-shape handling.
        if (output->num >= 0)
            continue;
        flags.set(NodeFlags::_vary_shape);
        check_vary_shape(output->shape);
    }
}
|
||||
|
||||
|
||||
VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
// Do not have grad to extras input
|
||||
string cpu_src = v_index < cpu_grad_src.size() ? cpu_grad_src[v_index] : "";
|
||||
string cuda_src = v_index < cuda_grad_src.size() ? cuda_grad_src[v_index] : "";
|
||||
if (!cuda_src.size() && !cpu_src.size()) return nullptr;
|
||||
auto inputs = clone(in);
|
||||
inputs.push_back(out);
|
||||
auto inputs = clone(_inputs);
|
||||
// TODO: remove unused deps
|
||||
// dout -> dout
|
||||
std::stringstream new_alias;
|
||||
new_alias << "\n@alias(dout,in" << inputs.size() << ")\n";
|
||||
inputs.push_back(dout);
|
||||
// _outputs[i] -> poutj
|
||||
for (int i=0; i<_outputs.size(); i++) {
|
||||
new_alias << "\n@alias(pout" << i << ",in" << inputs.size() << ")\n";
|
||||
if (_outputs[i] == out)
|
||||
new_alias << "\n@alias(pout,in" << inputs.size() << ")\n";
|
||||
inputs.push_back(_outputs[i]);
|
||||
}
|
||||
auto alias = new_alias.str();
|
||||
return make_code(
|
||||
in[v_index]->shape,
|
||||
in[v_index]->dtype(),
|
||||
_inputs[v_index]->shape,
|
||||
_inputs[v_index]->dtype(),
|
||||
move(inputs),
|
||||
move(cpu_src), {}, clone(cpu_header),
|
||||
move(cuda_src), {}, clone(cuda_header)
|
||||
move(cpu_src), {}, alias+cpu_header,
|
||||
move(cuda_src), {}, alias+cuda_header
|
||||
);
|
||||
}
|
||||
|
||||
void CodeOp::jit_prepare() {
|
||||
add_jit_define("Tout", out->dtype());
|
||||
add_jit_define("OUTDIM", JK::hex1(out->shape.size()));
|
||||
if (in.size()>=2) {
|
||||
auto pout = in.rbegin()[1];
|
||||
auto dout = in.rbegin()[0];
|
||||
add_jit_define("Tpout", pout->dtype());
|
||||
add_jit_define("POUTDIM", JK::hex1(pout->shape.size()));
|
||||
add_jit_define("Tdout", dout->dtype());
|
||||
add_jit_define("DOUTDIM", JK::hex1(dout->shape.size()));
|
||||
|
||||
// forward: in0 in1 in2 -> out0 out1
|
||||
// backward: in0 in1 in2 in3(pout0) in4(pout1)
|
||||
add_jit_define("IN_SIZE", JK::hex1(_inputs.size()));
|
||||
for (uint i=0; i<_inputs.size(); i++) {
|
||||
jk << JK::key << "in" << JK::hex1(i) << "_dim" <<
|
||||
JK::val << JK::hex1(_inputs[i]->shape.size()) << JK::end;
|
||||
jk << JK::key << "in" << JK::hex1(i) << "_type" <<
|
||||
JK::val << _inputs[i]->dtype() << JK::end;
|
||||
}
|
||||
add_jit_define("INSIZE", JK::hex1(in.size()));
|
||||
for (uint i=0; i<in.size(); i++) {
|
||||
add_jit_define("INDIM", JK::hex1(i), JK::hex1(in[i]->shape.size()));
|
||||
add_jit_define("Tin", JK::hex1(i), in[i]->dtype());
|
||||
add_jit_define("OUT_SIZE", JK::hex1(_outputs.size()));
|
||||
for (uint i=0; i<_outputs.size(); i++) {
|
||||
jk << JK::key << "out" << JK::hex1(i) << "_dim" <<
|
||||
JK::val << JK::hex1(_outputs[i]->shape.size()) << JK::end;
|
||||
jk << JK::key << "out" << JK::hex1(i) << "_type" <<
|
||||
JK::val << _outputs[i]->dtype() << JK::end;
|
||||
}
|
||||
if (flags.get(NodeFlags::_cuda)) {
|
||||
jk << JK::key << "HEADER" << JK::val << cuda_header;
|
||||
|
@ -90,7 +140,8 @@ void CodeOp::jit_prepare() {
|
|||
jk << JK::end;
|
||||
} else {
|
||||
add_jit_define("HEADER", cpu_header);
|
||||
add_jit_define("CODE", cpu_src);
|
||||
jk << JK::key << "CODE" << JK::val;
|
||||
jk << cpu_src << JK::end;
|
||||
ASSERT(cpu_src.size());
|
||||
}
|
||||
}
|
||||
|
@ -101,49 +152,47 @@ void CodeOp::jit_prepare() {
|
|||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
|
||||
@for(i, 0, INSIZE,
|
||||
@define(in@i@@stride@{INDIM@i-1},1)
|
||||
@for(i, 0, IN_SIZE,
|
||||
@define(in@i@@_stride@{in@i@@_dim-1},1)
|
||||
)
|
||||
@for(i, 0, OUT_SIZE,
|
||||
@define(out@i@@_stride@{out@i@@_dim-1},1)
|
||||
)
|
||||
@define(outstride@{OUTDIM-1},1)
|
||||
@if(INSIZE>=2,
|
||||
@define(poutstride@{POUTDIM-1},1)
|
||||
@define(doutstride@{DOUTDIM-1},1)
|
||||
,)
|
||||
|
||||
@define(ARGS_DEF,
|
||||
@for(i, 0, INSIZE, @(
|
||||
Tin@i* __restrict__ in@i@@p,
|
||||
@for(j, 0, INDIM@i, @(index_t in@i@@shape@j,))
|
||||
@for(i, 0, IN_SIZE, @(
|
||||
in@i@@_type* __restrict__ in@i@@_p,
|
||||
@for(j, 0, in@i@@_dim, @(index_t in@i@@_shape@j,))
|
||||
))
|
||||
@for(i, 0, OUTDIM, @(index_t outshape@i,))
|
||||
Tout* __restrict__ outp
|
||||
@for(i, 0, OUT_SIZE, @(
|
||||
out@i@@_type* __restrict__ out@i@@_p,
|
||||
@for(j, 0, out@i@@_dim, @(index_t out@i@@_shape@j,))
|
||||
))
|
||||
int __tmp
|
||||
)
|
||||
|
||||
@define(ARGS,
|
||||
@for(i, 0, INSIZE, @(
|
||||
in@i@@p,
|
||||
@for(j, 0, INDIM@i, @(in@i@@shape@j,))
|
||||
@for(i, 0, IN_SIZE, @(
|
||||
in@i@@_p,
|
||||
@for(j, 0, in@i@@_dim, @(in@i@@_shape@j,))
|
||||
))
|
||||
@for(i, 0, OUTDIM, @(outshape@i,))
|
||||
outp
|
||||
@for(i, 0, OUT_SIZE, @(
|
||||
out@i@@_p,
|
||||
@for(j, 0, out@i@@_dim, @(out@i@@_shape@j,))
|
||||
))
|
||||
0
|
||||
)
|
||||
|
||||
@define(PRECALC,
|
||||
@for(i, 0, INSIZE,
|
||||
@for(j, INDIM@i-2, -1, -1, auto in@i@@stride@j = in@i@@stride@{j+1} * in@i@@shape@{j+1};)
|
||||
@for(i, 0, IN_SIZE,
|
||||
@for(j, in@i@@_dim-2, -1, -1, auto in@i@@_stride@j = in@i@@_stride@{j+1} * in@i@@_shape@{j+1};)
|
||||
)
|
||||
@for(i, 0, OUT_SIZE,
|
||||
@for(j, out@i@@_dim-2, -1, -1, auto out@i@@_stride@j = out@i@@_stride@{j+1} * out@i@@_shape@{j+1};)
|
||||
)
|
||||
@for(i, OUTDIM-2, -1, -1, auto outstride@i = outstride@{i+1} * outshape@{i+1};)
|
||||
@if(INSIZE>=2,
|
||||
auto* __restrict__ poutp = in@{INSIZE-2}@@p;
|
||||
@for(i, 0, POUTDIM, index_t poutshape@i = in@{INSIZE-2}@@shape@i;)
|
||||
@for(i, POUTDIM-2, -1, -1, auto poutstride@i = in@{INSIZE-2}@@stride@i;)
|
||||
|
||||
auto* __restrict__ doutp = in@{INSIZE-1}@@p;
|
||||
@for(i, 0, DOUTDIM, index_t doutshape@i = in@{INSIZE-1}@@shape@i;)
|
||||
@for(i, DOUTDIM-2, -1, -1, auto doutstride@i = in@{INSIZE-1}@@stride@i;)
|
||||
,)
|
||||
)
|
||||
|
||||
@alias(out, out0)
|
||||
|
||||
@HEADER
|
||||
|
||||
|
@ -151,14 +200,17 @@ namespace jittor {
|
|||
|
||||
void CodeOp::jit_run() {
|
||||
// define inputs
|
||||
@for(i, 0, INSIZE,
|
||||
auto in@i = in[@i];
|
||||
auto* __restrict__ in@i@@p = in[@i]->ptr<Tin@i>();
|
||||
@for(j, 0, INDIM@i, index_t in@i@@shape@j = in[@i]->shape[@j];)
|
||||
@for(i, 0, IN_SIZE,
|
||||
auto in@i = _inputs[@i];
|
||||
auto* __restrict__ in@i@@_p = _inputs[@i]->ptr<in@i@@_type>();
|
||||
@for(j, 0, in@i@@_dim, index_t in@i@@_shape@j = _inputs[@i]->shape[@j];)
|
||||
)
|
||||
// define outputs
|
||||
@for(i, 0, OUT_SIZE,
|
||||
auto out@i = _outputs[@i];
|
||||
auto* __restrict__ out@i@@_p = _outputs[@i]->ptr<out@i@@_type>();
|
||||
@for(j, 0, out@i@@_dim, index_t out@i@@_shape@j = _outputs[@i]->shape[@j];)
|
||||
)
|
||||
// define out
|
||||
auto* __restrict__ outp = out->ptr<Tout>();
|
||||
@for(i, 0, OUTDIM, index_t outshape@i = out->shape[@i];)
|
||||
|
||||
@PRECALC
|
||||
|
||||
|
|
|
@ -9,8 +9,8 @@
|
|||
namespace jittor {
|
||||
|
||||
struct CodeOp : Op {
|
||||
vector<Var*> in;
|
||||
Var* out;
|
||||
vector<Var*> _inputs;
|
||||
vector<Var*> _outputs;
|
||||
string cpu_src;
|
||||
vector<string> cpu_grad_src;
|
||||
string cpu_header;
|
||||
|
@ -29,16 +29,19 @@ struct CodeOp : Op {
|
|||
@param[in] inputs A list of input jittor Vars
|
||||
|
||||
@param[in] cpu_src cpu source code string, built-in values:
|
||||
* in{x}, in{x}shape{y}, in{x}stride{y}, Tin{x}, in{x}p, @in0(...)
|
||||
* out, outshape{y}, outstride{y}, Tout, outp, @out(...)
|
||||
* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
|
||||
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
|
||||
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
|
||||
|
||||
@param[in] cpu_grad_src A list of string,
|
||||
cpu source code string for gradient, represents the gradient
|
||||
for each input. Built-in values:
|
||||
* in{x}, in{x}shape{y}, in{x}stride{y}, Tin{x}, in{x}p, @in0(...)
|
||||
* out, outshape{y}, outstride{y}, Tout, outp, @out(...)
|
||||
* pout, poutshape{y}, poutstride{y}, Tpout, poutp, @pout(...)
|
||||
* dout, doutshape{y}, doutstride{y}, Tdout, doutp, @dout(...)
|
||||
* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
|
||||
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
|
||||
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
|
||||
* pout{x}, pout{x}_shape{y}, pout{x}_stride{y}, pout{x}_type, pout{x}_p, @pout{x}(...)
|
||||
* pout, pout_shape{y}, pout_stride{y}, pout_type, pout_p, @pout(...)
|
||||
* dout, dout_shape{y}, dout_stride{y}, dout_type, dout_p, @dout(...)
|
||||
|
||||
@param[in] cpu_header cpu header code string.
|
||||
|
||||
|
@ -47,25 +50,96 @@ struct CodeOp : Op {
|
|||
@param[in] cuda_grad_src A list of string.
|
||||
|
||||
@param[in] cuda_header cuda header code string.
|
||||
|
||||
|
||||
----------------
|
||||
|
||||
Example
|
||||
Example-1:
|
||||
|
||||
```
|
||||
a = jt.random([10])
|
||||
b = jt.code(a.shape, a.dtype, [a],
|
||||
cpu_src='''
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
for (int i=0; i<in0_shape0; i++)
|
||||
@out(i) = @in0(i)*@in0(i)*2;
|
||||
''',
|
||||
cpu_grad_src = ['''
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
for (int i=0; i<in0_shape0; i++)
|
||||
@out(i) = @dout(i)*@in0(i)*4;
|
||||
'''])
|
||||
```
|
||||
|
||||
Example2(CUDA):
|
||||
Example-2:
|
||||
```
|
||||
a = jt.array([3,2,1])
|
||||
b = jt.code(a.shape, a.dtype, [a],
|
||||
cpu_header="""
|
||||
#include <algorithm>
|
||||
@alias(a, in0)
|
||||
@alias(b, out)
|
||||
""",
|
||||
cpu_src="""
|
||||
for (int i=0; i<a_shape0; i++)
|
||||
@b(i) = @a(i);
|
||||
std::sort(&@b(0), &@b(in0_shape0));
|
||||
"""
|
||||
)
|
||||
assert (b.data==[1,2,3]).all()
|
||||
```
|
||||
|
||||
Example-3:
|
||||
This example shows how to set multiple outputs in code op.
|
||||
```
|
||||
a = jt.array([3,2,1])
|
||||
b,c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a],
|
||||
cpu_header="""
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
""",
|
||||
cpu_src="""
|
||||
@alias(a, in0)
|
||||
@alias(b, out0)
|
||||
@alias(c, out1)
|
||||
@b(0) = @c(0) = @a(0);
|
||||
for (int i=0; i<a_shape0; i++) {
|
||||
@b(0) = std::min(@b(0), @a(i));
|
||||
@c(0) = std::max(@c(0), @a(i));
|
||||
}
|
||||
cout << "min:" << @b(0) << " max:" << @c(0) << endl;
|
||||
"""
|
||||
)
|
||||
assert b.data == 1, b
|
||||
assert c.data == 3, c
|
||||
```
|
||||
|
||||
Example-4:
|
||||
This example shows how to use dynamic shape of jittor variables.
|
||||
```
|
||||
a = jt.array([5,-4,3,-2,1])
|
||||
|
||||
# a negative shape gives the maximum size of a varying (dynamic) dimension
|
||||
b,c = jt.code([(-5,), (-5,)], [a.dtype, a.dtype], [a],
|
||||
cpu_src="""
|
||||
@alias(a, in0)
|
||||
@alias(b, out0)
|
||||
@alias(c, out1)
|
||||
int num_b=0, num_c=0;
|
||||
for (int i=0; i<a_shape0; i++) {
|
||||
if (@a(i)>0)
|
||||
@b(num_b++) = @a(i);
|
||||
else
|
||||
@c(num_c++) = @a(i);
|
||||
}
|
||||
b->set_shape({num_b});
|
||||
c->set_shape({num_c});
|
||||
"""
|
||||
)
|
||||
assert (b.data == [5,3,1]).all()
|
||||
assert (c.data == [-4,-2]).all()
|
||||
```
|
||||
|
||||
|
||||
CUDA Example-1:
|
||||
This example shows how to use CUDA in code op.
|
||||
```
|
||||
a = jt.random([100000])
|
||||
b = jt.random([100000])
|
||||
|
@ -75,33 +149,34 @@ struct CodeOp : Op {
|
|||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
for (; i<in0_shape0; i+=stride)
|
||||
@out(i) = @in0(i)*@in1(i);
|
||||
}
|
||||
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
for (; i<in0_shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in1(i);
|
||||
}
|
||||
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''', '''
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
for (; i<in0_shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in0(i);
|
||||
}
|
||||
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
'''])
|
||||
```
|
||||
|
||||
Example3(CUDA):
|
||||
CUDA Example-2:
|
||||
This example shows how to use multi dimension data with CUDA.
|
||||
```
|
||||
a = jt.random((100,100))
|
||||
b = jt.random((100,100))
|
||||
|
@ -109,8 +184,8 @@ struct CodeOp : Op {
|
|||
cuda_src='''
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
|
||||
@out(i,j) = @in0(i,j)*@in1(i,j);
|
||||
}
|
||||
kernel1<<<32, 32>>>(@ARGS);
|
||||
|
@ -118,16 +193,16 @@ struct CodeOp : Op {
|
|||
cuda_grad_src = ['''
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in1(i,j);
|
||||
}
|
||||
kernel2<<<32, 32>>>(@ARGS);
|
||||
''', '''
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in0(i,j);
|
||||
}
|
||||
kernel3<<<32, 32>>>(@ARGS);
|
||||
|
@ -136,6 +211,9 @@ struct CodeOp : Op {
|
|||
*/
|
||||
CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
|
||||
|
||||
// @attrs(multiple_outputs)
|
||||
CodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
|
||||
|
||||
const char* name() const override { return "code"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
DECLARE_jit_run;
|
||||
|
|
|
@ -57,6 +57,14 @@ vector<string> split(const string& s, const string& sep, int max_split) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
// Return a copy of `s` with leading and trailing whitespace removed.
// Only ' ', '\t' and '\n' are treated as whitespace; interior whitespace
// is kept. An empty or all-whitespace input yields an empty string.
string strip(const string& s) {
    auto is_blank = [](char c) { return c==' ' || c=='\t' || c=='\n'; };
    // [i, j) is the half-open range of characters that will be kept.
    size_t i = 0, j = s.size();
    while (i<j && is_blank(s[i])) i++;
    // j is one past the last kept character, so the candidate to drop is
    // s[j-1]; testing s[j] would never match and leave the tail unstripped.
    while (j>i && is_blank(s[j-1])) j--;
    return s.substr(i, j-i);
}
|
||||
|
||||
void KernelIR::del_scope() {
|
||||
if (father && (type=="define" || type=="func")) {
|
||||
father->scope[attrs["lvalue"]].remove(this);
|
||||
|
|
Loading…
Reference in New Issue