mirror of https://github.com/Jittor/Jittor
code op header tuning
This commit is contained in:
parent
aa8c50dbec
commit
0627774ec6
|
@ -67,40 +67,33 @@ class TestCodeOp(unittest.TestCase):
|
|||
a = jt.random([100000])
|
||||
b = jt.random([100000])
|
||||
c = jt.code(a.shape, a.dtype, [a,b],
|
||||
cuda_header='''
|
||||
namespace jittor {
|
||||
cuda_src='''
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @in0(i)*@in1(i);
|
||||
}
|
||||
|
||||
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in1(i);
|
||||
}
|
||||
|
||||
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''', '''
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in0(i);
|
||||
}
|
||||
|
||||
}
|
||||
''',
|
||||
cuda_src='''
|
||||
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''', '''
|
||||
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
'''])
|
||||
da, db = jt.grad(c, [a, b])
|
||||
|
@ -114,38 +107,32 @@ class TestCodeOp(unittest.TestCase):
|
|||
a = jt.random((100,100))
|
||||
b = jt.random((100,100))
|
||||
c = jt.code(a.shape, a.dtype, [a,b],
|
||||
cuda_header='''
|
||||
namespace jittor {
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @in0(i,j)*@in1(i,j);
|
||||
}
|
||||
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in1(i,j);
|
||||
}
|
||||
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in0(i,j);
|
||||
}
|
||||
|
||||
}
|
||||
''',
|
||||
cuda_src='''
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @in0(i,j)*@in1(i,j);
|
||||
}
|
||||
kernel1<<<32, 32>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
kernel2<<<32, 32>>>(@ARGS);
|
||||
__global__ static void kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in1(i,j);
|
||||
}
|
||||
kernel<<<32, 32>>>(@ARGS);
|
||||
''', '''
|
||||
kernel3<<<32, 32>>>(@ARGS);
|
||||
__global__ static void kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@pout(0,0);
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in0(i,j);
|
||||
}
|
||||
kernel<<<32, 32>>>(@ARGS);
|
||||
'''])
|
||||
da, db = jt.grad(c, [a, b])
|
||||
assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
|
||||
|
|
|
@ -62,9 +62,32 @@ void CodeOp::jit_prepare() {
|
|||
add_jit_define("Tin", JK::hex1(i), in[i]->dtype());
|
||||
}
|
||||
if (use_cuda) {
|
||||
add_jit_define("HEADER", cuda_header);
|
||||
add_jit_define("CODE", cuda_src);
|
||||
jk << JK::key << "HEADER" << JK::val << cuda_header;
|
||||
ASSERT(cuda_src.size());
|
||||
jk << "\nnamespace jittor {\n";
|
||||
int i=0;
|
||||
// move cuda kernel function into header
|
||||
for (; i<cuda_src.size(); i++) {
|
||||
if (cuda_src[i] == ' ' || cuda_src[i] == '\t' || cuda_src[i] == '\n') {
|
||||
jk << cuda_src[i];
|
||||
} else
|
||||
if (cuda_src[i] == '_') {
|
||||
int presum = 0;
|
||||
while (i < cuda_src.size()) {
|
||||
jk << cuda_src[i];
|
||||
if (cuda_src[i] == '{') presum ++;
|
||||
else if (cuda_src[i] == '}') {
|
||||
presum--;
|
||||
if (presum==0)
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
} else break;
|
||||
}
|
||||
jk << "}" << JK::end << JK::key << "CODE" << JK::val;
|
||||
for (; i<cuda_src.size(); i++) jk << cuda_src[i];
|
||||
jk << JK::end;
|
||||
} else {
|
||||
add_jit_define("HEADER", cpu_header);
|
||||
add_jit_define("CODE", cpu_src);
|
||||
|
|
|
@ -70,40 +70,33 @@ struct CodeOp : Op {
|
|||
a = jt.random([100000])
|
||||
b = jt.random([100000])
|
||||
c = jt.code(a.shape, a.dtype, [a,b],
|
||||
cuda_header='''
|
||||
namespace jittor {
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
@out(i) = @in0(i)*@in1(i);
|
||||
}
|
||||
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
@out(i) = @dout(i)*@in1(i);
|
||||
}
|
||||
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
@out(i) = @dout(i)*@in0(i);
|
||||
}
|
||||
|
||||
}
|
||||
''',
|
||||
cuda_src='''
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @in0(i)*@in1(i);
|
||||
}
|
||||
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in1(i);
|
||||
}
|
||||
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''', '''
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in0(i);
|
||||
}
|
||||
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
'''])
|
||||
```
|
||||
|
@ -113,37 +106,30 @@ struct CodeOp : Op {
|
|||
a = jt.random((100,100))
|
||||
b = jt.random((100,100))
|
||||
c = jt.code(a.shape, a.dtype, [a,b],
|
||||
cuda_header='''
|
||||
namespace jittor {
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @in0(i,j)*@in1(i,j);
|
||||
}
|
||||
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in1(i,j);
|
||||
}
|
||||
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in0(i,j);
|
||||
}
|
||||
|
||||
}
|
||||
''',
|
||||
cuda_src='''
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @in0(i,j)*@in1(i,j);
|
||||
}
|
||||
kernel1<<<32, 32>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in1(i,j);
|
||||
}
|
||||
kernel2<<<32, 32>>>(@ARGS);
|
||||
''', '''
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in0(i,j);
|
||||
}
|
||||
kernel3<<<32, 32>>>(@ARGS);
|
||||
'''])
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue