mirror of https://github.com/Jittor/Jittor
update code op help doc & test file
This commit is contained in:
parent 0ac6d53c02
commit f434f64cd2
@@ -6,6 +6,7 @@
import unittest
import jittor as jt
import numpy as np
from jittor import Function

class TestCodeOp(unittest.TestCase):
    def test(self):
@@ -27,6 +28,35 @@ class TestCodeOp(unittest.TestCase):
        da = jt.grad(c*b, a)
        assert np.allclose(c.data*na*4, da.data), (c.data*na*4, da.data)

    def test_use_func(self):
        class Func(Function):
            def execute(self, x):
                self.save_vars = x
                return jt.code(x.shape, x.dtype, [x],
                    cpu_src='''
                        for (int i=0; i<in0_shape0; i++)
                            @out(i) = @in0(i)*@in0(i)*2;
                    ''')

            def grad(self, grad_x):
                x = self.save_vars
                return jt.code(x.shape, x.dtype, [x, grad_x],
                    cpu_src='''
                        for (int i=0; i<in0_shape0; i++)
                            @out(i) = @in1(i)*@in0(i)*4;
                    ''')

        a = jt.random([10])
        func = Func()
        b = func(a)

        na, nb = jt.fetch_sync([a,b])
        assert np.allclose(na*na*2, nb)

        c = jt.random([10])
        da = jt.grad(c*b, a)
        assert np.allclose(c.data*na*4, da.data), (c.data*na*4, da.data)

    def test_multi_input(self):
        a = jt.random([10])
        b = jt.random([10])
@@ -230,6 +260,48 @@ class TestCodeOp(unittest.TestCase):
        assert np.allclose(da.data, b.data)
        assert np.allclose(db.data, a.data)

    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
    @jt.flag_scope(use_cuda=1)
    def test_cuda2_use_func(self):
        class Func(Function):
            def execute(self, a, b):
                self.save_vars = a, b
                return jt.code(a.shape, a.dtype, [a,b],
                    cuda_src='''
                        __global__ static void kernel1(@ARGS_DEF) {
                            @PRECALC
                            for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
                                for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
                                    @out(i,j) = @in0(i,j)*@in1(i,j);
                        }
                        kernel1<<<32, 32>>>(@ARGS);
                    ''')

            def grad(self, grad):
                a, b = self.save_vars
                return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad],
                    cuda_src='''
                        __global__ static void kernel2(@ARGS_DEF) {
                            @PRECALC
                            for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
                                for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x) {
                                    @out0(i,j) = @in2(i,j)*@in1(i,j);
                                    @out1(i,j) = @in2(i,j)*@in0(i,j);
                                }
                        }
                        kernel2<<<32, 32>>>(@ARGS);
                    ''')

        a = jt.random((100,100))
        b = jt.random((100,100))

        func = Func()
        c = func(a,b)
        da, db = jt.grad(c, [a, b])
        assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
        assert np.allclose(da.data, b.data)
        assert np.allclose(db.data, a.data)


if __name__ == "__main__":
    unittest.main()
@@ -34,37 +34,41 @@ struct CodeOp : Op {
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)

* [in] cpu_grad_src: A list of strings; CPU source code for the gradient, one entry per input. Built-in values:

* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
* pout{x}, pout{x}_shape{y}, pout{x}_stride{y}, pout{x}_type, pout{x}_p, @pout{x}(...)
* pout, pout_shape{y}, pout_stride{y}, pout_type, pout_p, @pout(...)
* dout, dout_shape{y}, dout_stride{y}, dout_type, dout_p, @dout(...)

* [in] cpu_header: CPU header code string.

* [in] cuda_src: CUDA source code string.

* [in] cuda_grad_src: A list of strings; CUDA source code for the gradient, one entry per input.

* [in] cuda_header: CUDA header code string.
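
A minimal sketch (illustrative only, not one of the documented examples) of how cpu_header can supply a helper for cpu_src while cpu_grad_src provides the gradient; the helper name ``sqr`` is hypothetical, and it is assumed that code in cpu_header is visible to cpu_src. The math mirrors Example-1 below (forward 2*x*x, gradient 4*x)::

    a = jt.random([10])
    b = jt.code(a.shape, a.dtype, [a],
        cpu_header='''
            // hypothetical helper made available to cpu_src
            inline float sqr(float x) { return x*x; }
        ''',
        cpu_src='''
            for (int i=0; i<in0_shape0; i++)
                @out(i) = sqr(@in0(i))*2;
        ''',
        cpu_grad_src=['''
            for (int i=0; i<in0_shape0; i++)
                @out(i) = @dout(i)*@in0(i)*4;
        '''])
    print(b)
    print(jt.grad(b, a))

Here cpu_grad_src has a single entry because the op has a single input; an op with two inputs would pass a list of two gradient strings.
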
----------------

Example-1::

    from jittor import Function
    import jittor as jt

    class Func(Function):
        def execute(self, x):
            self.save_vars = x
            return jt.code(x.shape, x.dtype, [x],
                cpu_src='''
                    for (int i=0; i<in0_shape0; i++)
                        @out(i) = @in0(i)*@in0(i)*2;
                ''')

        def grad(self, grad_x):
            x = self.save_vars
            return jt.code(x.shape, x.dtype, [x, grad_x],
                cpu_src='''
                    for (int i=0; i<in0_shape0; i++)
                        @out(i) = @in1(i)*@in0(i)*4;
                ''')

    a = jt.random([10])
    b = jt.code(a.shape, "float32", [a],
        cpu_src='''
            for (int i=0; i<in0_shape0; i++)
                @out(i) = @in0(i)*@in0(i)*2;
        ''',
        cpu_grad_src = ['''
            for (int i=0; i<in0_shape0; i++)
                @out(i) = @dout(i)*@in0(i)*4;
        '''])
    func = Func()
    b = func(a)
    print(b)
    print(jt.grad(b,a))

Example-2::
@@ -136,71 +140,90 @@ struct CodeOp : Op {
CUDA Example-1::

    # This example shows how to use CUDA in the code op.
    import jittor as jt
    from jittor import Function
    jt.flags.use_cuda = 1

    class Func(Function):
        def execute(self, a, b):
            self.save_vars = a, b
            return jt.code(a.shape, a.dtype, [a,b],
                cuda_src='''
                    __global__ static void kernel1(@ARGS_DEF) {
                        @PRECALC
                        int i = threadIdx.x + blockIdx.x * blockDim.x;
                        int stride = blockDim.x * gridDim.x;
                        for (; i<in0_shape0; i+=stride)
                            @out(i) = @in0(i)*@in1(i);
                    }
                    kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
                ''')

        def grad(self, grad):
            a, b = self.save_vars
            return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad],
                cuda_src='''
                    __global__ static void kernel2(@ARGS_DEF) {
                        @PRECALC
                        int i = threadIdx.x + blockIdx.x * blockDim.x;
                        int stride = blockDim.x * gridDim.x;
                        for (; i<in0_shape0; i+=stride) {
                            @out0(i) = @in2(i)*@in1(i);
                            @out1(i) = @in2(i)*@in0(i);
                        }
                    }
                    kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
                ''')

    a = jt.random([100000])
    b = jt.random([100000])
    c = jt.code(a.shape, a.dtype, [a,b],
        cuda_src='''
            __global__ static void kernel1(@ARGS_DEF) {
                @PRECALC
                int i = threadIdx.x + blockIdx.x * blockDim.x;
                int stride = blockDim.x * gridDim.x;
                for (; i<in0_shape0; i+=stride)
                    @out(i) = @in0(i)*@in1(i);
            }
            kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
        ''',
        cuda_grad_src = ['''
            __global__ static void kernel2(@ARGS_DEF) {
                @PRECALC
                int i = threadIdx.x + blockIdx.x * blockDim.x;
                int stride = blockDim.x * gridDim.x;
                for (; i<in0_shape0; i+=stride)
                    @out(i) = @dout(i)*@in1(i);
            }
            kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
        ''', '''
            __global__ static void kernel3(@ARGS_DEF) {
                @PRECALC
                int i = threadIdx.x + blockIdx.x * blockDim.x;
                int stride = blockDim.x * gridDim.x;
                for (; i<in0_shape0; i+=stride)
                    @out(i) = @dout(i)*@in0(i);
            }
            kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
        '''])
    func = Func()
    c = func(a,b)
    print(c)
    print(jt.grad(c, [a, b]))

CUDA Example-2::

    # This example shows how to use multi-dimensional data with CUDA.
    import jittor as jt
    from jittor import Function
    jt.flags.use_cuda = 1

    class Func(Function):
        def execute(self, a, b):
            self.save_vars = a, b
            return jt.code(a.shape, a.dtype, [a,b],
                cuda_src='''
                    __global__ static void kernel1(@ARGS_DEF) {
                        @PRECALC
                        for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
                            for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
                                @out(i,j) = @in0(i,j)*@in1(i,j);
                    }
                    kernel1<<<32, 32>>>(@ARGS);
                ''')

        def grad(self, grad):
            a, b = self.save_vars
            return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad],
                cuda_src='''
                    __global__ static void kernel2(@ARGS_DEF) {
                        @PRECALC
                        for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
                            for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x) {
                                @out0(i,j) = @in2(i,j)*@in1(i,j);
                                @out1(i,j) = @in2(i,j)*@in0(i,j);
                            }
                    }
                    kernel2<<<32, 32>>>(@ARGS);
                ''')

    a = jt.random((100,100))
    b = jt.random((100,100))
    c = jt.code(a.shape, a.dtype, [a,b],
        cuda_src='''
            __global__ static void kernel1(@ARGS_DEF) {
                @PRECALC
                for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
                    for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
                        @out(i,j) = @in0(i,j)*@in1(i,j);
            }
            kernel1<<<32, 32>>>(@ARGS);
        ''',
        cuda_grad_src = ['''
            __global__ static void kernel2(@ARGS_DEF) {
                @PRECALC
                for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
                    for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
                        @out(i,j) = @dout(i,j)*@in1(i,j);
            }
            kernel2<<<32, 32>>>(@ARGS);
        ''', '''
            __global__ static void kernel3(@ARGS_DEF) {
                @PRECALC
                for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
                    for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
                        @out(i,j) = @dout(i,j)*@in0(i,j);
            }
            kernel3<<<32, 32>>>(@ARGS);
        '''])
    func = Func()
    c = func(a,b)
    print(c)
    print(jt.grad(c, [a, b]))
*/
CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");