JittorMirror/python/jittor/src/ops/code_op.cc

// ***************************************************************
// Copyright (c) 2022 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cmath>
#include "var.h"
#include "ops/code_op.h"
#include "ops/op_register.h"
#include "misc/cuda_flags.h"
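// Note: likely defined so that helper functions in user-supplied source marked
// __inline_static__ start with '_' (and are therefore hoisted into the header
// by jit_prepare below) while still expanding to plain `inline static`.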
#define __inline_static__ inline static
#ifndef JIT
namespace jittor {
static auto make_code = get_op_info("code")
    .get_constructor<VarPtr, NanoVector, NanoString, vector<Var*>&&, string&&, vector<string>&&, string&&, string&&, vector<string>&&, string&&>();
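// A "vary" shape contains a negative (unknown) dimension that is resolved at
// run time; it is only accepted when the unknown dimension is the first one.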
static inline void check_vary_shape(NanoVector v) {
    ASSERT(v.size()) << "Vary shape should not be zero dimension";
    for (int i=0; i<v.size(); i++)
        ASSERT((i == 0) ^ (v[i] >= 0))
            << "Vary shape should only occur in the first dimension:" << v;
}
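// Single-output constructor: the output shape and dtype are given explicitly.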
CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
    string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
    : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
{
    flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
    flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
    _outputs.push_back(create_output(shape, dtype));
    if (_outputs[0]->num < 0) {
        check_vary_shape(_outputs[0]->shape);
    }
    if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0) {
        flags.set(NodeFlags::_manual_set_vnbb);
    }
}
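// Multi-output constructor: one shape/dtype pair per output.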
CodeOp::CodeOp(
    vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs,
    string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
    : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
{
    flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
    flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
    CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same";
    _outputs.resize(shapes.size());
    CHECKop(_outputs.size(),>,0);
    for (int i=0; i<shapes.size(); i++) {
        _outputs[i] = create_output(shapes[i], dtypes[i]);
        if (_outputs[i]->num < 0) {
            check_vary_shape(_outputs[i]->shape);
        }
    }
    if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0)
        flags.set(NodeFlags::_manual_set_vnbb);
}
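// Direct-output constructor: each created output shares its storage with the
// corresponding pre-existing Var passed in `outputs` (in-place style outputs).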
CodeOp::CodeOp(
    vector<Var*>&& inputs, vector<Var*>&& outputs,
    string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
    : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
{
    flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
    flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
    _outputs.resize(outputs.size());
    CHECKop(_outputs.size(),>,0);
    for (int i=0; i<outputs.size(); i++) {
        auto o = outputs[i];
        _outputs[i] = create_output(o->shape, o->dtype());
        _outputs[i]->share_with(o);
        // TODO: vary shape not allowed in direct output
    }
    if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0)
        flags.set(NodeFlags::_manual_set_vnbb);
}
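// Gradient of input v_index is computed by a second code op built from the
// user-supplied grad source. Its inputs are the forward inputs followed by
// dout and every forward output; the aliases prepended to the headers let the
// grad source refer to them by name, e.g. with two forward inputs and one
// output: dout -> in2, pout0 (and pout) -> in3, where pout is the particular
// forward output that dout corresponds to.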
VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    // Inputs without a corresponding grad source (extra inputs) receive no gradient
    string cpu_src = v_index < cpu_grad_src.size() ? cpu_grad_src[v_index] : "";
    string cuda_src = v_index < cuda_grad_src.size() ? cuda_grad_src[v_index] : "";
    if (!cuda_src.size() && !cpu_src.size()) return nullptr;
    auto inputs = clone(_inputs);
    // TODO: remove unused deps
    // dout -> dout
    std::stringstream new_alias;
    new_alias << "\n@alias(dout,in" << inputs.size() << ")\n";
    inputs.push_back(dout);
    // _outputs[i] -> pout{i}
    for (int i=0; i<_outputs.size(); i++) {
        new_alias << "\n@alias(pout" << i << ",in" << inputs.size() << ")\n";
        if (_outputs[i] == out)
            new_alias << "\n@alias(pout,in" << inputs.size() << ")\n";
        inputs.push_back(_outputs[i]);
    }
    auto alias = new_alias.str();
    return make_code(
        _inputs[v_index]->shape,
        _inputs[v_index]->dtype(),
        move(inputs),
        move(cpu_src), {}, alias+cpu_header,
        move(cuda_src), {}, alias+cuda_header
    );
}
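// Build the JIT key from the op signature: input/output counts, each tensor's
// rank and dtype (e.g. «in0_dim=2, «in0_type:float32), followed by the header
// and the user source for the selected backend (CUDA or CPU).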
void CodeOp::jit_prepare(JK& jk) {
    // forward:  in0 in1 in2 -> out0 out1
    // backward: in0 in1 in2 in3(dout) in4(pout0) in5(pout1)
    jk << "«IN_SIZE=" << JK::hex(_inputs.size());
    for (uint i=0; i<_inputs.size(); i++) {
        jk << "«in" << JK::hex(i) << "_dim="
            << JK::hex1(_inputs[i]->shape.size());
        jk << "«in" << JK::hex(i) << "_type:"
            << _inputs[i]->dtype();
    }
    jk << "«OUT_SIZE=" << JK::hex(_outputs.size());
    for (uint i=0; i<_outputs.size(); i++) {
        jk << "«out" << JK::hex(i) << "_dim="
            << JK::hex1(_outputs[i]->shape.size());
        jk << "«out" << JK::hex(i) << "_type:"
            << _outputs[i]->dtype();
    }
    string& header = flags.get(NodeFlags::_cuda) ?
        cuda_header : cpu_header;
    string& src = flags.get(NodeFlags::_cuda) ?
        cuda_src : cpu_src;
    jk << "«HEADER:" << header;
    CHECK(src.size());
    jk << "\nnamespace jittor {\n";
    int i=0;
    // Hoist leading definitions that start with '_' (e.g. __global__ CUDA
    // kernels) out of the CODE section and into the header, copying each
    // brace-balanced body, so they are emitted at namespace scope.
    for (; i<src.size(); i++) {
        if (src[i] == ' ' || src[i] == '\t' || src[i] == '\n') {
            jk << src[i];
        } else
        if (src[i] == '_') {
            int presum = 0;
            while (i < src.size()) {
                jk << src[i];
                if (src[i] == '{') presum ++;
                else if (src[i] == '}') {
                    presum--;
                    if (presum==0)
                        break;
                }
                i++;
            }
        } else break;
    }
    jk << "}«CODE:";
    for (; i<src.size(); i++) jk << src[i];
}
} // jittor
#else // JIT
#pragma GCC diagnostic ignored "-Wunused-variable"
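// Codegen scaffolding: the innermost stride of every input/output is defined
// as 1 here, and PRECALC later fills in the remaining row-major strides.
// ARGS_DEF / ARGS expand to a parameter list / argument list (data pointer
// plus every shape value) that user-written CUDA kernels can reuse.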
@for(i, 0, IN_SIZE,
    @define(in@i@@_stride@{in@i@@_dim-1},1)
)
@for(i, 0, OUT_SIZE,
    @define(out@i@@_stride@{out@i@@_dim-1},1)
)
@define(ARGS_DEF,
@for(i, 0, IN_SIZE, @(
    in@i@@_type* __restrict__ in@i@@_p,
    @for(j, 0, in@i@@_dim, @(index_t in@i@@_shape@j,))
))
@for(i, 0, OUT_SIZE, @(
    out@i@@_type* __restrict__ out@i@@_p,
    @for(j, 0, out@i@@_dim, @(index_t out@i@@_shape@j,))
))
int __tmp
)
@define(ARGS,
@for(i, 0, IN_SIZE, @(
    in@i@@_p,
    @for(j, 0, in@i@@_dim, @(in@i@@_shape@j,))
))
@for(i, 0, OUT_SIZE, @(
    out@i@@_p,
    @for(j, 0, out@i@@_dim, @(out@i@@_shape@j,))
))
0
)
@define(PRECALC,
@for(i, 0, IN_SIZE,
    @for(j, in@i@@_dim-2, -1, -1, auto in@i@@_stride@j = in@i@@_stride@{j+1} * in@i@@_shape@{j+1};)
)
@for(i, 0, OUT_SIZE,
    @for(j, out@i@@_dim-2, -1, -1, auto out@i@@_stride@j = out@i@@_stride@{j+1} * out@i@@_shape@{j+1};)
)
)
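// Convenience alias: user code may refer to the first (or only) output as "out".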
@alias(out, out0)
@HEADER
namespace jittor {
void CodeOp::jit_run() {
    // define inputs
    @for(i, 0, IN_SIZE,
        auto in@i = _inputs[@i];
        auto* __restrict__ in@i@@_p = _inputs[@i]->ptr<in@i@@_type>();
        @for(j, 0, in@i@@_dim, index_t in@i@@_shape@j = _inputs[@i]->shape[@j];)
    )
    // define outputs
    @for(i, 0, OUT_SIZE,
        auto out@i = _outputs[@i];
        auto* __restrict__ out@i@@_p = _outputs[@i]->ptr<out@i@@_type>();
        @for(j, 0, out@i@@_dim, index_t out@i@@_shape@j = _outputs[@i]->shape[@j];)
    )
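    // compute the remaining row-major strides, then paste the user code section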
    @PRECALC
    @CODE
}
} // jittor
#endif // JIT