code op header tuning

This commit is contained in:
Dun Liang 2020-03-26 21:34:40 +08:00
parent aa8c50dbec
commit 0627774ec6
3 changed files with 94 additions and 98 deletions

View File

@ -67,40 +67,33 @@ class TestCodeOp(unittest.TestCase):
a = jt.random([100000])
b = jt.random([100000])
c = jt.code(a.shape, a.dtype, [a,b],
cuda_header='''
namespace jittor {
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i=0; i<in0shape0; i++)
for (; i<in0shape0; i+=stride)
@out(i) = @in0(i)*@in1(i);
}
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i=0; i<in0shape0; i++)
for (; i<in0shape0; i+=stride)
@out(i) = @dout(i)*@in1(i);
}
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i=0; i<in0shape0; i++)
for (; i<in0shape0; i+=stride)
@out(i) = @dout(i)*@in0(i);
}
}
''',
cuda_src='''
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
''',
cuda_grad_src = ['''
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
''', '''
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
'''])
da, db = jt.grad(c, [a, b])
@ -114,38 +107,32 @@ class TestCodeOp(unittest.TestCase):
a = jt.random((100,100))
b = jt.random((100,100))
c = jt.code(a.shape, a.dtype, [a,b],
cuda_header='''
namespace jittor {
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
}
''',
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
kernel1<<<32, 32>>>(@ARGS);
''',
cuda_grad_src = ['''
kernel2<<<32, 32>>>(@ARGS);
__global__ static void kernel(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
kernel<<<32, 32>>>(@ARGS);
''', '''
kernel3<<<32, 32>>>(@ARGS);
__global__ static void kernel(@ARGS_DEF) {
@PRECALC
@pout(0,0);
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
kernel<<<32, 32>>>(@ARGS);
'''])
da, db = jt.grad(c, [a, b])
assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)

View File

@ -62,9 +62,32 @@ void CodeOp::jit_prepare() {
add_jit_define("Tin", JK::hex1(i), in[i]->dtype());
}
if (use_cuda) {
add_jit_define("HEADER", cuda_header);
add_jit_define("CODE", cuda_src);
jk << JK::key << "HEADER" << JK::val << cuda_header;
ASSERT(cuda_src.size());
jk << "\nnamespace jittor {\n";
int i=0;
// move cuda kernel function into header
for (; i<cuda_src.size(); i++) {
if (cuda_src[i] == ' ' || cuda_src[i] == '\t' || cuda_src[i] == '\n') {
jk << cuda_src[i];
} else
if (cuda_src[i] == '_') {
int presum = 0;
while (i < cuda_src.size()) {
jk << cuda_src[i];
if (cuda_src[i] == '{') presum ++;
else if (cuda_src[i] == '}') {
presum--;
if (presum==0)
break;
}
i++;
}
} else break;
}
jk << "}" << JK::end << JK::key << "CODE" << JK::val;
for (; i<cuda_src.size(); i++) jk << cuda_src[i];
jk << JK::end;
} else {
add_jit_define("HEADER", cpu_header);
add_jit_define("CODE", cpu_src);

View File

@ -70,40 +70,33 @@ struct CodeOp : Op {
a = jt.random([100000])
b = jt.random([100000])
c = jt.code(a.shape, a.dtype, [a,b],
cuda_header='''
namespace jittor {
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i=0; i<in0shape0; i++)
@out(i) = @in0(i)*@in1(i);
}
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i=0; i<in0shape0; i++)
@out(i) = @dout(i)*@in1(i);
}
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i=0; i<in0shape0; i++)
@out(i) = @dout(i)*@in0(i);
}
}
''',
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
@out(i) = @in0(i)*@in1(i);
}
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
@out(i) = @dout(i)*@in1(i);
}
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0shape0; i+=stride)
@out(i) = @dout(i)*@in0(i);
}
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
'''])
```
@ -113,37 +106,30 @@ struct CodeOp : Op {
a = jt.random((100,100))
b = jt.random((100,100))
c = jt.code(a.shape, a.dtype, [a,b],
cuda_header='''
namespace jittor {
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
}
''',
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
kernel1<<<32, 32>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
kernel2<<<32, 32>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
kernel3<<<32, 32>>>(@ARGS);
'''])
```