JittorMirror/python/jittor/src/ops/getitem_op.cc

580 lines
19 KiB
C++

// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cmath>
#include "var.h"
#include "executor.h"
#include "ops/getitem_op.h"
#include "ops/op_register.h"
#ifdef JIT_cuda
#include <cuda_runtime.h>
#include "helper_cuda.h"
#endif
#ifndef JIT
#include "misc/stack_vector.h"
#include "opt/kernel_ir.h"
#ifdef HAS_CUDA
#include "misc/cuda_flags.h"
#endif
#endif
namespace jittor {
#ifndef JIT
// Constructors of the other ops this file builds, resolved once when this
// translation unit is initialized.
//   make_number(value, var)         - used by grad()/grads() to build zeros;
//                                     presumably shaped like `var` (confirm).
//   make_empty(shape, dtype)        - used by grads() for an uninitialized base.
//   make_setitem(x, slices, y, op)  - scatters y into x[slices] (the backward).
static auto make_number =
    get_op_info("number").get_constructor<VarPtr, float, Var*>();
static auto make_empty =
    get_op_info("empty").get_constructor<VarPtr, NanoVector, NanoString>();
static auto make_setitem =
    get_op_info("setitem").get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
// Single-output getitem: out = x[slices].
// The op runs on CPU and CUDA and supports compile-time source rewriting
// (_has_gopt, see compile_optimize). Every var used as an index is marked
// _needed_by_backward because grad() re-uses the same slices when it
// scatters dout back into x.
GetitemOp::GetitemOp(Var* x, VarSlices&& slices)
    : vs(move(slices)) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    flags.set(NodeFlags::_has_gopt);
    flags.set(NodeFlags::_manual_set_vnbb);
    for (int k = 0; k < vs.n; ++k) {
        auto& slice = vs.slices[k];
        if (!slice.is_var())
            continue;
        slice.var->flags.set(NodeFlags::_needed_by_backward);
    }
    create_output(nullptr, x->dtype());
}
// Two-output getitem variant that provides its own multi-output backward
// (_grads, implemented by grads()).
//   output(0): x[slices]
//   output(1): a var of x's dtype tied to x via share_with(x); its shape is
//              set to x's shape in infer_shape.
//              NOTE(review): exact aliasing semantics of share_with are not
//              visible in this file.
// `_` is stored in ns.data; grads() reads it as the "last split var" marker.
GetitemOp::GetitemOp(Var* x, VarSlices&& slices, int _)
    : vs(move(slices)) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    flags.set(NodeFlags::_has_gopt);
    flags.set(NodeFlags::_custom_flag);
    flags.set(NodeFlags::_grads);
    flags.set(NodeFlags::_manual_set_vnbb);
    for (int k = 0; k < vs.n; ++k) {
        auto& slice = vs.slices[k];
        if (!slice.is_var())
            continue;
        slice.var->flags.set(NodeFlags::_needed_by_backward);
    }
    create_output(nullptr, x->dtype());
    create_output(nullptr, x->dtype())->share_with(x);
    ns.data = _;
}
// Match the user's slices (vs) against the dims of the input.
//
// For every input dim i this fills:
//   i_to_vs[i] - index of the slice consumed by dim i, or -1 when the dim is
//                copied through unchanged (no slices left, covered by an
//                ellipsis, or a full `[::]` slice);
//   i_to_o[i]  - output dim produced by dim i, or -1 when the dim is indexed
//                away (int, str and var slices);
// and builds out_shape, the user-visible output shape (None slices add
// size-1 dims that no input dim maps to).
//
// Side effects:
//   - first_oid_of_var / var_dim record where the block of broadcast
//     var-index dims starts in out_shape and how many dims it has;
//   - negative int indices are wrapped and partial slices are filled
//     against the input size *in place* inside vs.
void GetitemOp::infer_slices(
    StackVector<>& __restrict__ i_to_vs,
    StackVector<>& __restrict__ i_to_o,
    StackVector<>& __restrict__ out_shape
) {
    auto in = inputs().front();
    auto in_shape = in->shape;
    auto nin = in_shape.size();
    i_to_vs.n = i_to_o.n = nin;
    out_shape.n = 0;
    int vid = 0;
    first_oid_of_var = -1;
    var_dim = 0;
    for (int i=0; i<nin; i++) {
        if (vid >= vs.n) {
            // i i i
            // | | |
            // v v v --> overflow
            // s s
            // More input dims than slices: keep the dim as-is.
            i_to_vs[i] = -1;
            i_to_o[i] = out_shape.size();
            out_shape.push_back(in_shape[i]);
            continue;
        }
        // Bind the slice only after the bounds check above: vid may equal
        // vs.n, which is one past the end of vs.slices.
        auto& s = vs.slices[vid];
        if (s.is_var()) {
            // i --> s ---> o
            //       + ---> o
            // var maybe multiple dims
            if (first_oid_of_var == -1) {
                // First var slice: reserve var_dim output dims (the largest
                // rank among all var slices), initialized to 1 and
                // broadcast against each var's shape below.
                for (int k=0; k<vs.n; k++)
                    if (vs.slices[k].is_var())
                        var_dim = std::max(var_dim, vs.slices[k].var->shape.size());
                first_oid_of_var = out_shape.size();
                for (int j=0; j<var_dim; j++) {
                    out_shape.push_back(1);
                }
            }
            i_to_vs[i] = vid++;
            i_to_o[i] = -1;
            auto iv = s.var;
            auto iv_shape = iv->shape;
            auto niv = iv_shape.size();
            // Right-align this var's shape with the reserved dims: a 1 takes
            // the var's size, otherwise sizes must agree (negative sizes are
            // accepted without comparison).
            for (int j=0; j<niv; j++) {
                auto iv_shape_j = iv_shape[niv-j-1];
                auto& out_shape_j = out_shape[first_oid_of_var+var_dim-j-1];
                if (out_shape_j == 1)
                    out_shape_j = iv_shape_j;
                else
                    ASSERT(out_shape_j == iv_shape_j || out_shape_j < 0 || iv_shape_j < 0)
                        << out_shape_j << iv_shape_j << out_shape;
            }
        } else
        if (s.is_ellipsis()) {
            // Count the slices after the ellipsis that consume an input dim
            // (None slices do not); the ellipsis covers whatever is left.
            auto remain_slice = vs.n-vid-1;
            for (int k=vid+1; k<vs.n; k++)
                if (vs.slices[k].is_none())
                    remain_slice--;
            auto remain_idims = nin-i;
            auto ellipsis_size = remain_idims - remain_slice;
            ASSERT(ellipsis_size>=0) << "NDims not match";
            for (int j=0; j<ellipsis_size; j++) {
                i_to_vs[i+j] = -1;
                i_to_o[i+j] = out_shape.size();
                out_shape.push_back(in_shape[i+j]);
            }
            vid ++;
            // Together with the loop's i++ this lands on the first dim after
            // the ellipsis (an empty ellipsis revisits the same dim).
            i += ellipsis_size-1;
        } else
        if (s.is_none()) {
            // None inserts a size-1 output dim without consuming an input dim.
            i--;
            out_shape.push_back(1);
            vid++;
            continue;
        } else
        if (s.is_int()) {
            // Integer index: removes the dim; a negative index counts from
            // the end and is normalized in place.
            i_to_vs[i] = vid++;
            i_to_o[i] = -1;
            auto in_shape_i = in_shape[i];
            auto& v = s.slice.start;
            if (v<0) v += in_shape_i;
            CHECK(v>=0 && v<in_shape_i) << "slice overflow, " << v << "not in [0,">>in_shape_i>>")";
        } else
        if (s.is_str()) {
            // String index: removes the dim.
            i_to_vs[i] = vid++;
            i_to_o[i] = -1;
        } else {
            // slice
            auto& slice = s.slice;
            auto in_shape_i = in_shape[i];
            auto out_shape_j = in_shape_i;
            if (slice.mask == 7) {
                // slice is a[::]
                // start, stop, step is not filled
                vid++;
                i_to_vs[i] = -1;
                i_to_o[i] = out_shape.size();
                out_shape.push_back(out_shape_j);
            } else {
                i_to_vs[i] = vid++;
                i_to_o[i] = out_shape.size();
                if (in_shape_i > 0) {
                    slice.fill(in_shape_i);
                    // Number of elements selected by start:stop:step.
                    if (std::abs(slice.step) <= 1)
                        out_shape_j = (slice.stop - slice.start) * slice.step;
                    else if (slice.step>0)
                        out_shape_j = (slice.stop - slice.start - 1) / slice.step + 1;
                    else
                        out_shape_j = (slice.start - slice.stop - 1) / -slice.step + 1;
                    out_shape_j = std::max((int64)0, out_shape_j);
                }
                out_shape.push_back(out_shape_j);
            }
        }
    }
    // Slices left after all input dims are used: only None (adds a size-1
    // dim) or an ellipsis are allowed.
    while (vid < vs.n) {
        auto& s = vs.slices[vid++];
        if (s.is_none()) {
            out_shape.push_back(1);
        } else
            CHECK(s.is_ellipsis()) << "Too many slices" << vs << "shape:" << in->shape;
    }
}
// Decide how the nested loops over the optimized output shape o_shape are
// mapped onto CUDA threads and blocks. It is called from jit_prepare, where
// the masks become the LO<i> jit-key entries. It is also called from the host
// code emitted by GetitemOp::_compile_optimize, which launches with tdims.
//
// Outputs:
//   masks[i] - one int per loop i (dim i of o_shape); bit layout below.
//   tdims    - six extents: tdims[0..2] are the block dims x/y/z and
//              tdims[3..5] are the grid dims x/y/z. Unused axes stay 1.
//
// Loops are assigned from the innermost one (last dim of o_shape) outwards.
// Phase 1 fills threadIdx.x/y/z within a budget of 1024 threads per block,
// with at most 64 threads on z. Phase 2 fills blockIdx.x/y/z while fewer
// than 256*1024 threads have been scheduled. Remaining outer loops get
// mask 0 and are left as ordinary loops.
void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims) {
// bz by bx tz ty tx
// 5 4 3 2 1 0
// LOi: bitmask of used dims of loop i
// LOi bit 6: need for
// if need for, keep for range: for (int i@i=tid; tid<range; tid+=tnum)
// if not need for, replace range -> tnum, for -> int i@i = tid
// LOi bit 7: loop i has size 0 (_compile_optimize then launches no kernel)
// rtnum: threads still available in the block for the next thread axis
int rtnum = 1024;
// int max_tnum = {1024, 1024, 64, (1u<<31)-1, 65535, 65535};
int loop_id = (int)o_shape.size()-1;
int tid = 0;
int64 block_size = 1;
int thread_size = 1;
for (int i=0; i<6; i++) tdims[i] = 1;
// Phase 1: thread axes x, y, z.
for (; tid<3 && loop_id>=0 && rtnum>1; tid++) {
int64 si = o_shape[loop_id];
int mask = 1<<tid;
if (tid==2) rtnum = std::min(64, rtnum);
if (si>rtnum*4) {
// Dim is much larger than the remaining thread budget: spread it over
// this thread axis plus blockIdx.x and keep the loop (bit 6).
// Setting tid = 3 here (the loop then increments it) makes phase 2
// continue at blockIdx.y, since blockIdx.x is now taken.
// need for, use tid(1<<i) and bx(8)
mask |= 8|(1<<6);
block_size = (si-1)/rtnum+1;
tdims[tid] = rtnum;
tdims[3] = block_size;
tid = 3;
thread_size *= rtnum;
rtnum = 0;
} else
if (si>rtnum) {
// Slightly larger than the budget: use all remaining threads on this
// axis and keep the loop (bit 6); the thread budget is exhausted.
mask |= (1<<6);
thread_size *= rtnum;
tdims[tid] = rtnum;
rtnum = 0;
} else {
// The whole dim fits: one thread per element, shrink the budget.
rtnum = rtnum / std::max(si, (int64)1);
thread_size *= si;
tdims[tid] = si;
if (si == 0) mask |= 1<<7;
}
masks[loop_id] = mask;
loop_id --;
}
int64 total_size = (int64)block_size*thread_size;
if (tid<3) tid=3;
// Phase 2: grid axes, up to about 256*1024 scheduled threads. A dim larger
// than the axis limit (65535 for y/z, 2^31-1 for x) is capped and keeps
// its loop (bit 6).
for (; tid<6 && loop_id>=0 && total_size<(256*1024); tid++) {
int64 si = o_shape[loop_id];
int mask = 1<<tid;
if (si == 0) mask |= 1<<7;
int64 max_thread = tid>=4 ? 65535 : (1u<<31)-1;
if (si > max_thread) {
si = max_thread;
mask |= 1<<6;
}
total_size *= si;
tdims[tid] = si;
masks[loop_id] = mask;
loop_id --;
}
// Outer loops not mapped to any hardware axis.
while (loop_id>=0) {
masks[loop_id--] = 0;
}
}
// Forwards the generated source to _compile_optimize. That function
// rewrites src in place for CUDA and returns immediately when the op does
// not have the _cuda flag.
void GetitemOp::compile_optimize(string& src) {
    this->_compile_optimize(src);
}
// CUDA-only rewrite of the generated source. The function body (presumably
// jit_run) is split in two: a `static __global__ void func(...)` kernel that
// holds the copy loop, and the remaining host code, which launches it. Each
// of the o_shape.size() nested loops is remapped onto threadIdx/blockIdx
// according to its LO<i> mask. The bit layout is described at
// cuda_loop_schedule.
void GetitemOp::_compile_optimize(string& src) {
if (!flags.get(NodeFlags::_cuda))
return;
// NOTE(review): jd / jd_map are not used below.
auto jd = get_jit_define();
map<string,string> jd_map(jd.begin(), jd.end());
KernelIR main(src);
// Last child of the last top-level node. Presumably jit_run inside
// namespace jittor; confirm.
auto& func = main.children.back()->children.back();
// auto& loop = func->children.back();
// Insert an empty kernel `func` just before the host function.
func->push_back("void func() {}", &func->before);
auto& new_func = func->before.back();
// auto new_func = func->before.back()->move_out();
new_func->attrs["dtype"] = "static __global__ void";
// LOGir << main.to_string();
// NOTE(review): this intermediate string is overwritten at the end.
src = main.to_string();
// Comma-separated kernel arguments used in the launch statement.
string arg_call = "";
// Index/extent expressions for hardware axes, ordered like mask bits 0..5.
const char* tname[] = {"threadIdx.x", "threadIdx.y", "threadIdx.z", "blockIdx.x", "blockIdx.y", "blockIdx.z"};
const char* tname2[] = {"blockDim.x", "blockDim.y", "blockDim.z", "gridDim.x", "gridDim.y", "gridDim.z"};
// Split the host function's local definitions:
//  - values read from inputs/outputs/vs or produced by a call or subscript
//    stay on the host. If their type is spelled out (not auto), they are
//    also passed to the kernel as parameters.
//  - all other definitions are copied into the kernel body.
for (auto& ir : func->children) {
if (ir->type == "define") {
string& rvalue = ir->attrs.at("rvalue");
string& lvalue = ir->attrs.at("lvalue");
string& dtype = ir->attrs.at("dtype");
if (startswith(rvalue, "input")
|| startswith(rvalue, "output")
|| startswith(rvalue, "vs.")
|| rvalue.back() == ')'
|| rvalue.back() == ']')
{
if (dtype == "auto")
LOGvvvv << "keep" << rvalue;
else {
LOGvvvv << "args" << rvalue;
if (arg_call.size()) arg_call += ", ";
arg_call += lvalue;
LOGvvvv << dtype+" "+lvalue;
new_func->push_back(dtype+" "+lvalue+";", &new_func->inner);
}
} else {
LOGvvvv << "move" <<rvalue;
new_func->push_back(ir->clone());
}
}
}
// Move the host function's last statement (the outermost copy loop) into
// the kernel.
new_func->push_back(func->children.back()->move_out());
auto& loop = new_func->children.back();
int no = o_shape.size();
STACK_ALLOC(KernelIR*, loops, no);
if (!no) {
// No loops: a single thread does the whole copy.
func->push_back("func<<<1,1>>>("+arg_call+");");
} else {
bool has_zero = 0;
// loops[i] is the loop for dim i of o_shape; each is nested as the
// last child of the previous one.
loops[0] = loop.get();
for (int i=1; i<no; i++)
loops[i] = loops[i-1]->children.back().get();
for (int i=0; i<no; i++) {
auto l = loops[i];
// Loop header: init, condition, step.
ASSERT(l->inner.size() == 3);
auto lo = l->find_define("LO"+S(i));
ASSERT(lo);
auto loi = std::stoi(lo->attrs.at("rvalue"));
if (loi>>7) has_zero = 1;
// Build the flattened index (tid) and total extent (tnum) over all
// hardware axes whose bit is set in the mask.
string tid = "";
string tnum = "";
for (int j=0; j<6; j++) {
if ((loi>>j)&1) {
if (tid.size()) {
tid += string("+")+tnum+"*"+tname[j];
tnum += string("*")+tname2[j];
} else {
tid = tname[j];
tnum = tname2[j];
}
}
}
// Loop not mapped to any axis: leave it as a plain loop.
if (!tid.size()) {
continue;
}
if (loi&(1<<6)) {
// Keep the loop: start at this thread's index and advance by
// the total number of threads on the mapped axes.
l->inner.at(0)->attrs.at("rvalue") = tid;
l->inner.at(2)->attrs.at("code") = "i"+S(i)+"+="+tnum+";";
} else {
// no need for
// Drop the loop header; the loop variable becomes the thread index.
while (l->inner.size())
l->inner.at(0)->erase();
l->push_front("index_t i"+S(i)+" = "+tid+";");
}
}
// If some loop has size 0, no launch is emitted. Otherwise the host
// recomputes the schedule from the runtime o_shape and launches.
if (!has_zero) {
func->push_back("int no = o_shape.size();");
func->push_back("STACK_ALLOC(int,masks,no);");
func->push_back("int tdims[6];");
func->push_back("cuda_loop_schedule(o_shape, masks, tdims);");
func->push_back("dim3 grid_dim(tdims[3],tdims[4],tdims[5]);");
func->push_back("dim3 block_dim(tdims[0],tdims[1],tdims[2]);");
func->push_back("func<<<grid_dim, block_dim>>>("+arg_call+");");
}
}
src = main.to_string();
}
// Compute the output shape and the loop layout used by the generated kernel.
//
// out_shape (from infer_slices) is the user-visible shape. o_shape is the
// "optimized" iteration shape, where each dim is one loop of the kernel:
//   - runs of consecutive pass-through input dims are fused into one loop
//     (later dims of a run get i_to_vs = -2);
//   - the broadcast var-index dims are emitted once, at position fov, which
//     replaces first_oid_of_var (now an index into o_shape).
// After this call i_to_o indexes o_shape instead of out_shape. An output
// with no dims is given shape [1]. For the two-output form, output(1) gets
// the input's shape.
void GetitemOp::infer_shape() {
    auto in = inputs().front();
    auto out = outputs().front();
    auto in_shape = in->shape;
    auto nin = in_shape.size();
    StackVector<> i_to_vs(nin);
    StackVector<> i_to_o(nin);
    // shape return to use
    StackVector<> out_shape;
    infer_slices(i_to_vs, i_to_o, out_shape);
    // optimized shape (each dim is a loop var)
    StackVector<> o_shape;
    int fov = -1;
    for (int i=0; i<nin; i++) {
        auto& vid = i_to_vs[i];
        auto& oid = i_to_o[i];
        if (oid>=0) {
            // Read the output extent only for dims that produce one: oid is
            // -1 for indexed-away dims and out_shape[-1] is out of bounds.
            auto os = out_shape[oid];
            if (vid==-1 && i && i_to_vs[i-1]<0) {
                // Fuse this pass-through dim into the previous one's loop.
                vid = -2;
                o_shape.back() *= os;
            } else
                o_shape.push_back(os);
            oid = o_shape.size()-1;
        } else {
            auto& s = vs.slices[vid];
            if (s.is_var() && fov == -1) {
                // Emit the shared var-index dims once, at the first var.
                fov = o_shape.size();
                for (int k=0; k<var_dim; k++)
                    o_shape.push_back(out_shape[first_oid_of_var+k]);
            }
        }
    }
    first_oid_of_var = fov;
    if (!out_shape.size()) out_shape.push_back(1);
    out->set_shape(out_shape.to_nano_vector());
    this->i_to_vs = i_to_vs.to_nano_vector();
    this->i_to_o = i_to_o.to_nano_vector();
    this->o_shape = o_shape.to_nano_vector();
    if (outputs().size() > 1) {
        auto out2 = output(1);
        out2->set_shape(in->shape);
    }
    LOGV(999) << "\ni_to_vs:" << i_to_vs
        << "\ni_to_o:" << i_to_o
        << "\no_shape:" << o_shape;
}
// Backward of the single-output form.
// Only input 0 (x) gets a gradient; other inputs (index vars) get nullptr.
// The gradient is built as setitem(zeros like v, same slices, dout). The
// zeros come from make_number(0, v). When any slice is a var, the setitem
// uses ns_add instead of ns_void. Presumably this is so that elements
// selected more than once accumulate their gradients.
VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    if (v_index)
        return nullptr;
    auto zeros = make_number(0, v);
    // TODO: maybe add here?
    // need analysis the overlap attr os var slices
    bool indexed_by_var = false;
    for (int k = 0; k < vs.n && !indexed_by_var; ++k)
        indexed_by_var = vs.slices[k].is_var();
    if (indexed_by_var)
        return make_setitem(zeros, VarSlices(vs, true), dout, ns_add);
    return make_setitem(zeros, VarSlices(vs, true), dout, ns_void);
}
// Backward of the two-output form: dins[0] = setitem(base, slices, y).
//   base - gradient of output(1) (dout[1]). If absent, it is an empty var
//          when ns.data is set ("last split var") and zeros like the input
//          otherwise.
//   y    - gradient of output(0) (dout[0]). If absent, zeros like output(0).
void GetitemOp::grads(Var** dout, VarPtr* dins) {
    VarPtr x = dout[1];
    VarPtr y = dout[0];
    if (!x) {
        auto in = inputs().front();
        // ns.data represents this is the last split var
        x = ns.data ? make_empty(in->shape, in->dtype())
                    : make_number(0, in);
    }
    if (!y)
        y = make_number(0, outputs().front());
    dins[0] = make_setitem(x, VarSlices(vs, true), y, ns_void);
}
// Write everything the JIT kernel is specialized on into the jit key:
//   Ti         input dtype;   IDIM / ODIM  ranks of input / o_shape
//   FOV / VD   position and rank of the var-index dims in o_shape (if any)
//   IV<i>/IO<i> slice index / o_shape loop for input dim i (-1, -2: none)
//   VS<i>      slice kind: -1 int, -5 str (value in VSS<i>), a broadcast
//              bitmask for var slices (index dtype in VST<i>), or for range
//              slices the step when |step| <= 1, else 0 (read at runtime)
//   LO<i>      CUDA only: per-loop schedule masks from cuda_loop_schedule
void GetitemOp::jit_prepare(JK& jk) {
    auto in = inputs().front();
    int idim = i_to_vs.size();
    jk << _CS("[Ti:") << in->dtype();
    jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
    jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
    if (first_oid_of_var>=0) {
        jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
        jk << _CS("][VD=") << JK::hex1(var_dim);
    }
    for (int i=0; i<idim; i++) {
        auto iv = i_to_vs[i];
        auto io = i_to_o[i];
        jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
        jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
        // iv is -1 (pass-through) or -2 (fused into the previous loop) for
        // dims without a slice; never use those to index vs.slices.
        if (iv < 0) continue;
        auto& v = vs.slices[iv];
        if (io==-1) {
            // Dim removed from the output by an int, str or var index.
            if (v.is_int()) {
                jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
            } else
            if (v.is_str()) {
                jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
                jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
            } else {
                ASSERT(v.is_var());
                auto var = v.var;
                auto vshape = var->shape;
                auto vdim = vshape.size();
                // Bit (j + var_dim - vdim) is set when the var's dim j equals
                // the matching (right-aligned) var-index dim of o_shape;
                // jit_run adds a stride only for set bits, others broadcast.
                int vsmask = 0;
                for (int j=0; j<vdim; j++) {
                    int k = first_oid_of_var+j+var_dim-vdim;
                    if (vshape[j] == o_shape[k])
                        vsmask |= 1<<(j+var_dim-vdim);
                }
                jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
                jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
            }
        } else
        if (io>=0) {
            // Range slice that keeps the dim.
            ASSERT(v.is_slice());
            jk << _CS("][VS") << JK::hex1(i) << ':';
            if (std::abs(v.slice.step) <= 1)
                jk << JK::shex1(v.slice.step);
            else
                jk << '0';
        }
    }
#ifdef HAS_CUDA
    if (use_cuda) {
        int no = o_shape.size();
        STACK_ALLOC(int, masks, no);
        int tdims[6];
        cuda_loop_schedule(o_shape, masks, tdims);
        for (int i=0; i<no; i++) {
            jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
        }
    }
#endif
    jk << ']';
}
#else // JIT
#pragma GCC diagnostic ignored "-Wunused-variable"
// JIT kernel template, compiled only when JIT is defined. It copies the
// selected input elements into the output. The jit-key values written by
// jit_prepare are substituted by the template preprocessor: Ti, IDIM, ODIM,
// FOV, VD and the per-dim IV, IO, VS, VSS and VST entries. On CUDA the
// outermost loop is further rewritten by _compile_optimize.
void GetitemOp::jit_run() {
auto in = inputs().front();
auto out = outputs().front();
// Nothing to copy for an empty output.
if (out->num == 0) return;
// In-place variant whose output already uses the input's allocation:
// the data is already in place.
if (ns.get(GetitemOp::_inplace) &&
in->allocator == out->allocator &&
in->allocation == out->allocation)
return;
// Extents and row-major strides of the optimized output loops.
@for(i, 0, ODIM, index_t oshape@i = o_shape[@i];)
@if(ODIM>0,
index_t ostride@{ODIM-1} = 1;
@for(i, ODIM-2, -1, -1, index_t ostride@i = ostride@{i+1} * oshape@{i+1};)
)
Ti* op = out->ptr<Ti>();
Ti* ip = in->ptr<Ti>();
// Input extents used for the stride math. A pass-through dim takes the
// extent of the fused loop it maps to, a fused-away dim becomes 1, and all
// other dims keep the real input extent. The strides are row-major over
// these extents.
@for(i, 0, IDIM, index_t ishape@i =
@if(IV@i==-1,oshape@{IO@i},
@if(IV@i==-2,1,in->shape[@i]));
)
index_t istride@{IDIM-1} = 1;
@for(i, IDIM-2, -1, -1, index_t istride@i = istride@{i+1} * ishape@{i+1};)
// Range slices: start and step. The step is a compile-time constant unless
// VS is 0, in which case it is read at runtime.
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i>=0,
index_t vstart@i = vs.slices[@{IV@i}].slice.start;
index_t vstep@i = @if(VS@i==0,vs.slices[@{IV@i}].slice.step;,@{VS@i});
)
)
// Int slices: the fixed index.
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i<0,
@if(VS@i==-1,index_t vi@i = vs.slices[@{IV@i}].slice.start;);
)
)
// Var slices: data pointer of the index var and its strides. Only dims
// flagged in the VS bitmask contribute; the others broadcast.
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i<0,
@if(VS@i>=0,
index_t vs@i@@s@{VD-1} = 1;
VST@i* vp@i = vs.slices[IV@i].var->ptr<VST@i>();
@for(j,VD-2,-1,-1,index_t vs@i@@s@j = vs@i@@s@{j+1} *
@if((VS@i>>(j+1))&1,oshape@{j+1+FOV},1);
)
);
)
)
// Walk every output element, compute each input dim's index from the loop
// indices, and copy one element.
@for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
index_t oid = 0 @for(d, 0, ODIM, + i@d * ostride@d);
// Per input dim: loop index for pass-through, 0 for fused-away, an affine
// index for range slices, a constant for int or str slices, or a value
// loaded from the index var. An unknown VS kind expands to invalid code.
@for(d, 0, IDIM, index_t iid@d =
@if(IV@d==-1, i@{IO@d},
@if(IV@d==-2, 0,
@if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
@if(VS@d==-1, vi@d,
@if(VS@d==-5, VSS@d,
@if(VS@d>=0,
index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
, ??? ))))));
)
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
op[oid] = ip[iid];
}
}
#endif // JIT
} // jittor