mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor into move_cutt
This commit is contained in:
commit
4109abde9e
|
@ -117,13 +117,13 @@ jittor会自动在路径中寻找合适的编译器, 如果您希望手动指定
|
|||
|
||||
```bash
|
||||
# install with clang and cuda
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_clang=1 with_cuda=1 bash
|
||||
# install with clang
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_clang=1 bash
|
||||
# install with g++ and cuda
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_gcc=1 with_cuda=1 bash
|
||||
# install with g++
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_gcc=1 bash
|
||||
```
|
||||
|
||||
执行后,脚本将显示一些需要导出的环境变量。
|
||||
|
|
|
@ -112,13 +112,13 @@ We provide single line command for quick installation the latest version of Jitt
|
|||
|
||||
```bash
|
||||
# install with clang and cuda
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_clang=1 with_cuda=1 bash
|
||||
# install with clang
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_clang=1 bash
|
||||
# install with g++ and cuda
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_gcc=1 with_cuda=1 bash
|
||||
# install with g++
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_gcc=1 bash
|
||||
```
|
||||
After execution, the script will show some environment variables you need to export.
|
||||
|
||||
|
|
|
@ -149,13 +149,13 @@ We provide single line command for quick installation the latest version of Jitt
|
|||
|
||||
```bash
|
||||
# install with clang and cuda
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_clang=1 with_cuda=1 bash
|
||||
# install with clang
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_clang=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_clang=1 bash
|
||||
# install with g++ and cuda
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_gcc=1 with_cuda=1 bash
|
||||
# install with g++
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz && mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor && with_gcc=1 bash ./jittor/script/install.sh
|
||||
wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_gcc=1 bash
|
||||
```
|
||||
After execution, the script will show some environment variables you need to export.
|
||||
|
||||
|
|
|
@ -45,14 +45,14 @@ class TestCodeOp(unittest.TestCase):
|
|||
}
|
||||
'''])
|
||||
da, db = jt.grad(c, [a, b])
|
||||
assert np.allclose(c.data, a.data*b.data)
|
||||
assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
|
||||
assert np.allclose(da.data, b.data)
|
||||
assert np.allclose(db.data, a.data)
|
||||
|
||||
def test_header(self):
|
||||
a = jt.array([3,2,1])
|
||||
b = jt.code(a.shape, a.dtype, [a],
|
||||
header='#include <algorithm>',
|
||||
cpu_header='#include <algorithm>',
|
||||
cpu_src="""
|
||||
for (int i=0; i<in0shape0; i++)
|
||||
@out(i) = @in0(i);
|
||||
|
@ -61,5 +61,97 @@ class TestCodeOp(unittest.TestCase):
|
|||
)
|
||||
assert (b.data==[1,2,3]).all()
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cuda(self):
    """Check a 1-D custom CUDA code op: forward product and both gradients.

    kernel1 writes out = in0 * in1. kernel2 and kernel3 are launched by the
    first and second entries of cuda_grad_src and write dout * in1 and
    dout * in0 respectively.

    Each kernel starts at the thread's own global index and advances by the
    total number of launched threads, so every element is written by exactly
    one thread instead of every thread looping over the whole array.
    """
    a = jt.random([100000])
    b = jt.random([100000])
    c = jt.code(a.shape, a.dtype, [a,b],
        cuda_header='''
        namespace jittor {
        __global__ static void kernel1(@ARGS_DEF) {
            @PRECALC
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            int stride = blockDim.x * gridDim.x;
            for (; i<in0shape0; i+=stride)
                @out(i) = @in0(i)*@in1(i);
        }

        __global__ static void kernel2(@ARGS_DEF) {
            @PRECALC
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            int stride = blockDim.x * gridDim.x;
            for (; i<in0shape0; i+=stride)
                @out(i) = @dout(i)*@in1(i);
        }

        __global__ static void kernel3(@ARGS_DEF) {
            @PRECALC
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            int stride = blockDim.x * gridDim.x;
            for (; i<in0shape0; i+=stride)
                @out(i) = @dout(i)*@in0(i);
        }

        }
        ''',
        cuda_src='''
            kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
        ''',
        cuda_grad_src = ['''
            kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
        ''', '''
            kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
        '''])
    da, db = jt.grad(c, [a, b])
    assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
    assert np.allclose(da.data, b.data)
    assert np.allclose(db.data, a.data)
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cuda2(self):
    """Check a 2-D custom CUDA code op: forward product and both gradients.

    In all three kernels the row index starts at blockIdx.x and advances by
    gridDim.x, and the column index starts at threadIdx.x and advances by
    blockDim.x. Every kernel is launched as <<<32, 32>>>.
    """
    # Device code shared by the forward op and both gradient ops.
    kernels = '''
    namespace jittor {
    __global__ static void kernel1(@ARGS_DEF) {
        @PRECALC
        for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
        for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
            @out(i,j) = @in0(i,j)*@in1(i,j);
    }

    __global__ static void kernel2(@ARGS_DEF) {
        @PRECALC
        for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
        for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
            @out(i,j) = @dout(i,j)*@in1(i,j);
    }

    __global__ static void kernel3(@ARGS_DEF) {
        @PRECALC
        for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
        for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
            @out(i,j) = @dout(i,j)*@in0(i,j);
    }

    }
    '''
    # Host-side launch statement for the forward pass.
    forward_launch = '''
        kernel1<<<32, 32>>>(@ARGS);
    '''
    # One launch statement per input, in input order.
    grad_launches = ['''
        kernel2<<<32, 32>>>(@ARGS);
    ''', '''
        kernel3<<<32, 32>>>(@ARGS);
    ''']

    a = jt.random((100,100))
    b = jt.random((100,100))
    c = jt.code(a.shape, a.dtype, [a,b],
        cuda_header=kernels,
        cuda_src=forward_launch,
        cuda_grad_src=grad_launches)
    da, db = jt.grad(c, [a, b])
    assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
    assert np.allclose(da.data, b.data)
    assert np.allclose(db.data, a.data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -94,7 +94,7 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
|
|||
a = jt.code([4], "int", cpu_src="""
|
||||
#pragma omp parallel num_threads(4)
|
||||
@out(omp_get_thread_num()) = 456;
|
||||
""", header='#include <omp.h>').data
|
||||
""", cpu_header='#include <omp.h>').data
|
||||
assert (a==[456]*4).all(), a
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#!/bin/bash
|
||||
# Single line install script
|
||||
# git clone https://github.com/Jittor/jittor.git && with_clang=1 with_cuda=1 bash ./jittor/script/install.sh
|
||||
|
||||
# wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_clang=1 with_cuda=1 bash
|
||||
set -ex
|
||||
|
||||
if [ "$is_docker" = "1" ]; then
|
||||
|
@ -50,12 +49,7 @@ wget -O - https://bootstrap.pypa.io/get-pip.py | sudo -H python$py_version
|
|||
|
||||
# Step 3: Run jittor
|
||||
|
||||
if [ ! -d jittor ]; then
|
||||
wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/jittor.tgz
|
||||
mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor
|
||||
fi
|
||||
|
||||
sudo python$py_version -m pip install ./jittor
|
||||
sudo python$py_version -m pip install git+https://github.com/Jittor/jittor.git
|
||||
|
||||
if [ "$with_cuda" = "1" ]; then
|
||||
export nvcc_path="/usr/local/cuda/bin/nvcc"
|
||||
|
|
7
setup.py
7
setup.py
|
@ -1,3 +1,10 @@
|
|||
import platform

error_msg = "Jittor only supports Ubuntu>=16.04 currently."


def _detect_distribution():
    """Return ``(name, release)`` strings for the running Linux distribution.

    ``platform.linux_distribution()`` is used where the interpreter still
    provides it (it was removed in Python 3.8). Otherwise the ``NAME`` and
    ``VERSION_ID`` fields of ``/etc/os-release`` are read. Empty strings are
    returned when neither source is available, so the caller's check fails
    with ``error_msg`` instead of an unrelated exception.
    """
    if hasattr(platform, "linux_distribution"):
        name, release = platform.linux_distribution()[:2]
        return name, release
    try:
        with open("/etc/os-release", encoding="utf8") as f:
            lines = f.read().splitlines()
    except OSError:
        return "", ""
    fields = {}
    for line in lines:
        key, sep, value = line.partition("=")
        if sep:
            # Values in os-release may be wrapped in single or double quotes.
            fields[key.strip()] = value.strip().strip("\"'")
    return fields.get("NAME", ""), fields.get("VERSION_ID", "")


def _version_number(release):
    """Convert a release string such as ``"18.04.1"`` to ``18.04``.

    Only the leading ``major.minor`` part is parsed; ``-1.0`` is returned
    when it is not numeric so the version check below rejects it.
    """
    try:
        return float(".".join(release.strip().split(".")[:2]))
    except ValueError:
        return -1.0


dis_name, _release = _detect_distribution()
dis_name = dis_name.lower()
version = _version_number(_release)
assert "ubuntu" in dis_name and version >= 16, error_msg
|
||||
|
||||
import setuptools
|
||||
from setuptools import setup, find_packages
|
||||
|
|
|
@ -176,7 +176,7 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
|
|||
while (k<expr.size() && isvar(expr[k])) k++;
|
||||
string var = expr.substr(j, k-j);
|
||||
auto iter = vars.find(var);
|
||||
ASSERT(iter!=vars.end()) << "Jit var " << var << " not found.";
|
||||
ASSERT(iter!=vars.end()) << "Jit var " << var << " not found." << vars;
|
||||
new_expr += iter->second;
|
||||
i = k-1;
|
||||
}
|
||||
|
@ -451,6 +451,22 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
new_src += S(OpCompiler::eval(src.substr(j+1, k-j-2), defs));
|
||||
i = k-1;
|
||||
continue;
|
||||
} else if (src[j] == '(') {
|
||||
// syntax @(...)
|
||||
// ij k
|
||||
size_t k=j+1;
|
||||
int presum = 1;
|
||||
while (k<src.size() && presum) {
|
||||
if (src[k] == ')')
|
||||
presum--;
|
||||
else if (src[k] == '(')
|
||||
presum++;
|
||||
k++;
|
||||
}
|
||||
ASSERT(presum==0) << "Jit error: braces are not matched.";
|
||||
new_src += precompile(defs, src.substr(j+1, k-j-2), macros);
|
||||
i = k-1;
|
||||
continue;
|
||||
} else if (isvar(src[j])) {
|
||||
size_t k=j+1;
|
||||
while (k<src.size() && isvar(src[k])) k++;
|
||||
|
@ -541,6 +557,36 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (expr == "define") {
|
||||
// syntax: @define(macro, value)
|
||||
// ij k l
|
||||
ASSERT(args.size()>=1u)
|
||||
<< "Jit error: define wrong arguments.";
|
||||
new_src += "#define ";
|
||||
auto key = precompile(defs, args[0], macros);
|
||||
string value, src;
|
||||
new_src += key;
|
||||
if (args.size()>=2) {
|
||||
new_src += " ";
|
||||
string all_args = args[1];
|
||||
for (int i=2; i<args.size(); i++) {
|
||||
all_args += ',';
|
||||
all_args += args[i];
|
||||
}
|
||||
src = precompile(defs, all_args, macros);
|
||||
for (auto c : src) {
|
||||
if (c == '\n')
|
||||
value += " \\";
|
||||
value += c;
|
||||
}
|
||||
new_src += value;
|
||||
}
|
||||
ASSERT(macros.count(key)==0) << "Macro" << key << "redefined.";
|
||||
defs[key] = src;
|
||||
macros[key] = value;
|
||||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (args.size()) {
|
||||
// syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
|
||||
int nid=(int)expr.size();
|
||||
|
|
|
@ -7,27 +7,32 @@
|
|||
#include "var.h"
|
||||
#include "ops/code_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
namespace jittor {
|
||||
|
||||
static auto make_code = get_op_info("code")
|
||||
.get_constructor<VarPtr, NanoVector, NanoString, vector<Var*>&&, string&&, vector<string>&&, string&&>();
|
||||
.get_constructor<VarPtr, NanoVector, NanoString, vector<Var*>&&, string&&, vector<string>&&, string&&, string&&, vector<string>&&, string&&>();
|
||||
|
||||
CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs, string&& cpu_src, vector<string>&& cpu_grad_src, string&& header)
|
||||
: in(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), header(move(header)) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
|
||||
string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
|
||||
string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
|
||||
: in(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
|
||||
cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
|
||||
{
|
||||
flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
|
||||
flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
|
||||
out = create_output(shape, dtype);
|
||||
ASSERT(this->cpu_src.size());
|
||||
ASSERTop(inputs.size(),<=,10);
|
||||
}
|
||||
|
||||
VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
// Inputs with no grad source (CPU or CUDA) at their index have no gradient
|
||||
if (cpu_grad_src.size() <= v_index) return nullptr;
|
||||
auto src = cpu_grad_src[v_index];
|
||||
if (!src.size()) return nullptr;
|
||||
string cpu_src = v_index < cpu_grad_src.size() ? cpu_grad_src[v_index] : "";
|
||||
string cuda_src = v_index < cuda_grad_src.size() ? cuda_grad_src[v_index] : "";
|
||||
if (!cuda_src.size() && !cpu_src.size()) return nullptr;
|
||||
auto inputs = clone(in);
|
||||
inputs.push_back(out);
|
||||
inputs.push_back(dout);
|
||||
|
@ -35,7 +40,8 @@ VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
in[v_index]->shape,
|
||||
in[v_index]->dtype(),
|
||||
move(inputs),
|
||||
move(src), {}, clone(header)
|
||||
move(cpu_src), {}, clone(cpu_header),
|
||||
move(cuda_src), {}, clone(cuda_header)
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -55,47 +61,84 @@ void CodeOp::jit_prepare() {
|
|||
add_jit_define("INDIM", JK::hex1(i), JK::hex1(in[i]->shape.size()));
|
||||
add_jit_define("Tin", JK::hex1(i), in[i]->dtype());
|
||||
}
|
||||
add_jit_define("HEADER", header);
|
||||
add_jit_define("CODE", cpu_src);
|
||||
if (use_cuda) {
|
||||
add_jit_define("HEADER", cuda_header);
|
||||
add_jit_define("CODE", cuda_src);
|
||||
ASSERT(cuda_src.size());
|
||||
} else {
|
||||
add_jit_define("HEADER", cpu_header);
|
||||
add_jit_define("CODE", cpu_src);
|
||||
ASSERT(cpu_src.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
||||
#else // JIT
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
|
||||
@for(i, 0, INSIZE,
|
||||
@define(in@i@@stride@{INDIM@i-1},1)
|
||||
)
|
||||
@define(outstride@{OUTDIM-1},1)
|
||||
@if(INSIZE>=2,
|
||||
@define(poutstride@{POUTDIM-1},1)
|
||||
@define(doutstride@{DOUTDIM-1},1)
|
||||
,)
|
||||
|
||||
@define(ARGS_DEF,
|
||||
@for(i, 0, INSIZE, @(
|
||||
Tin@i* __restrict__ in@i@@p,
|
||||
@for(j, 0, INDIM@i, @(index_t in@i@@shape@j,))
|
||||
))
|
||||
@for(i, 0, OUTDIM, @(index_t outshape@i,))
|
||||
Tout* __restrict__ outp
|
||||
)
|
||||
|
||||
@define(ARGS,
|
||||
@for(i, 0, INSIZE, @(
|
||||
in@i@@p,
|
||||
@for(j, 0, INDIM@i, @(in@i@@shape@j,))
|
||||
))
|
||||
@for(i, 0, OUTDIM, @(outshape@i,))
|
||||
outp
|
||||
)
|
||||
|
||||
@define(PRECALC,
|
||||
@for(i, 0, INSIZE,
|
||||
@for(j, INDIM@i-2, -1, -1, auto in@i@@stride@j = in@i@@stride@{j+1} * in@i@@shape@{j+1};)
|
||||
)
|
||||
@for(i, OUTDIM-2, -1, -1, auto outstride@i = outstride@{i+1} * outshape@{i+1};)
|
||||
@if(INSIZE>=2,
|
||||
auto* __restrict__ poutp = in@{INSIZE-2}@@p;
|
||||
@for(i, 0, POUTDIM, index_t poutshape@i = in@{INSIZE-2}@@shape@i;)
|
||||
@for(i, POUTDIM-2, -1, -1, auto poutstride@i = in@{INSIZE-2}@@stride@i;)
|
||||
|
||||
auto* __restrict__ doutp = in@{INSIZE-1}@@p;
|
||||
@for(i, 0, DOUTDIM, index_t doutshape@i = in@{INSIZE-1}@@shape@i;)
|
||||
@for(i, DOUTDIM-2, -1, -1, auto doutstride@i = in@{INSIZE-1}@@stride@i;)
|
||||
,)
|
||||
)
|
||||
|
||||
|
||||
@HEADER
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
void CodeOp::jit_run() {
|
||||
// define inputs
|
||||
@for(i, 0, INSIZE,
|
||||
auto in@i = in[@i];
|
||||
auto* __restrict__ in@i@@p = in[@i]->ptr<Tin@i>();
|
||||
@for(j, 0, INDIM@i, index_t in@i@@shape@j = in[@i]->shape[@j];)
|
||||
index_t in@i@@stride@{INDIM@i-1} = 1;
|
||||
@for(j, INDIM@i-2, -1, -1, auto in@i@@stride@j = in@i@@stride@{j+1} * in@i@@shape@{j+1};)
|
||||
)
|
||||
// define out
|
||||
auto* __restrict__ outp = out->ptr<Tout>();
|
||||
@for(i, 0, OUTDIM, index_t outshape@i = out->shape[@i];)
|
||||
index_t outstride@{OUTDIM-1} = 1;
|
||||
@for(i, OUTDIM-2, -1, -1, auto outstride@i = outstride@{i+1} * outshape@{i+1};)
|
||||
|
||||
@if(INSIZE>=2,
|
||||
auto pout = in[@{INSIZE-2}];
|
||||
auto* __restrict__ poutp = pout->ptr<Tpout>();
|
||||
@for(i, 0, POUTDIM, index_t poutshape@i = pout->shape[@i];)
|
||||
index_t poutstride@{POUTDIM-1} = 1;
|
||||
@for(i, POUTDIM-2, -1, -1, auto poutstride@i = poutstride@{i+1} * poutshape@{i+1};)
|
||||
@PRECALC
|
||||
|
||||
auto dout = in[@{INSIZE-1}];
|
||||
auto* __restrict__ doutp = dout->ptr<Tdout>();
|
||||
@for(i, 0, DOUTDIM, index_t doutshape@i = dout->shape[@i];)
|
||||
index_t doutstride@{DOUTDIM-1} = 1;
|
||||
@for(i, DOUTDIM-2, -1, -1, auto doutstride@i = doutstride@{i+1} * doutshape@{i+1};)
|
||||
,)
|
||||
@CODE
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,10 @@ struct CodeOp : Op {
|
|||
Var* out;
|
||||
string cpu_src;
|
||||
vector<string> cpu_grad_src;
|
||||
string header;
|
||||
string cpu_header;
|
||||
string cuda_src;
|
||||
vector<string> cuda_grad_src;
|
||||
string cuda_header;
|
||||
/**
|
||||
Code Operator for easily customized op.
|
||||
|
||||
|
@ -37,6 +40,14 @@ struct CodeOp : Op {
|
|||
* pout, poutshape{y}, poutstride{y}, Tpout, poutp, @pout(...)
|
||||
* dout, doutshape{y}, doutstride{y}, Tdout, doutp, @dout(...)
|
||||
|
||||
@param[in] cpu_header cpu header code string.
|
||||
|
||||
@param[in] cuda_src cuda source code string.
|
||||
|
||||
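@param[in] cuda_grad_src A list of strings; entry i is the cuda source code for the gradient of input i. An empty or missing entry means no cuda gradient source for that input.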
|
||||
|
||||
@param[in] cuda_header cuda header code string.
|
||||
|
||||
----------------
|
||||
|
||||
Example
|
||||
|
@ -53,8 +64,91 @@ struct CodeOp : Op {
|
|||
@out(i) = @dout(i)*@in0(i)*4;
|
||||
'''])
|
||||
```
|
||||
|
||||
Example2(CUDA):
|
||||
```
|
||||
a = jt.random([100000])
|
||||
b = jt.random([100000])
|
||||
c = jt.code(a.shape, a.dtype, [a,b],
|
||||
cuda_header='''
|
||||
namespace jittor {
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @in0(i)*@in1(i);
|
||||
}
|
||||
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in1(i);
|
||||
}
|
||||
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (; i<in0shape0; i+=stride)
|
||||
@out(i) = @dout(i)*@in0(i);
|
||||
}
|
||||
|
||||
}
|
||||
''',
|
||||
cuda_src='''
|
||||
kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
''', '''
|
||||
kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
|
||||
'''])
|
||||
```
|
||||
|
||||
Example3(CUDA):
|
||||
```
|
||||
a = jt.random((100,100))
|
||||
b = jt.random((100,100))
|
||||
c = jt.code(a.shape, a.dtype, [a,b],
|
||||
cuda_header='''
|
||||
namespace jittor {
|
||||
__global__ static void kernel1(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @in0(i,j)*@in1(i,j);
|
||||
}
|
||||
|
||||
__global__ static void kernel2(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in1(i,j);
|
||||
}
|
||||
|
||||
__global__ static void kernel3(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
for (int i=blockIdx.x; i<in0shape0; i+=gridDim.x)
|
||||
for (int j=threadIdx.x; j<in0shape1; j+=blockDim.x)
|
||||
@out(i,j) = @dout(i,j)*@in0(i,j);
|
||||
}
|
||||
|
||||
}
|
||||
''',
|
||||
cuda_src='''
|
||||
kernel1<<<32, 32>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src = ['''
|
||||
kernel2<<<32, 32>>>(@ARGS);
|
||||
''', '''
|
||||
kernel3<<<32, 32>>>(@ARGS);
|
||||
'''])
|
||||
```
|
||||
*/
|
||||
CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& header="");
|
||||
CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
|
||||
|
||||
const char* name() const override { return "code"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
|
|
Loading…
Reference in New Issue