polish jit key

This commit is contained in:
Dun Liang 2022-05-23 14:23:02 +08:00
parent 91fe1fac85
commit c5ccdaf330
66 changed files with 407 additions and 407 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.3.4.7' __version__ = '1.3.4.8'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -59,13 +59,13 @@ void CubArgReduceOp::infer_shape() {
} }
void CubArgReduceOp::jit_prepare(JK& jk) { void CubArgReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Toffsets:") << offsets->dtype(); jk << "«Toffsets:" << offsets->dtype();
jk << _CS("][FUNC:"); jk << "«FUNC:";
if (op==ns_minimum) if (op==ns_minimum)
jk << _CS("ArgMin]"); jk << "ArgMin";
else else
jk << _CS("ArgMax]"); jk << "ArgMax";
} }
#else // JIT #else // JIT

View File

@ -51,15 +51,15 @@ void CubArgsortOp::infer_shape() {
} }
void CubArgsortOp::jit_prepare(JK& jk) { void CubArgsortOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Tindexes:") << indexes->dtype(); jk << "«Tindexes:" << indexes->dtype();
jk << _CS("][Toffsets:") << offsets->dtype(); jk << "«Toffsets:" << offsets->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][FUNC:"); jk << "«FUNC:";
if (descending) if (descending)
jk << _CS("SortPairsDescending]"); jk << "SortPairsDescending";
else else
jk << _CS("SortPairs]"); jk << "SortPairs";
} }
#else // JIT #else // JIT

View File

@ -38,10 +38,9 @@ void CubCumsumOp::infer_shape() {
} }
void CubCumsumOp::jit_prepare(JK& jk) { void CubCumsumOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][reverse:") << reverse; jk << "«reverse:" << reverse;
jk << _CS("]");
} }
VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) { VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) {

View File

@ -24,7 +24,7 @@ CubTestOp::CubTestOp(string cmd) : cmd(cmd) {
} }
void CubTestOp::jit_prepare(JK& jk) { void CubTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]"); jk << "«T:float32";
} }
#else // JIT #else // JIT

View File

@ -40,10 +40,9 @@ void CubWhereOp::infer_shape() {
} }
void CubWhereOp::jit_prepare(JK& jk) { void CubWhereOp::jit_prepare(JK& jk) {
jk << _CS("[Ti:") << cond->dtype(); jk << "«Ti:" << cond->dtype();
jk << _CS("][To:") << outs[0]->dtype(); jk << "«To:" << outs[0]->dtype();
jk << _CS("][NDIM=") << JK::hex1(cond->shape.size()); jk << "«NDIM=" << JK::hex1(cond->shape.size());
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -54,11 +54,10 @@ void CublasAccMatmulOp::infer_shape() {
} }
void CublasAccMatmulOp::jit_prepare(JK& jk) { void CublasAccMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype(); jk << "«T:" << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -89,11 +89,10 @@ void CublasBatchedMatmulOp::infer_shape(){
} }
void CublasBatchedMatmulOp::jit_prepare(JK& jk) { void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype(); jk << "«T:" << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -50,11 +50,10 @@ void CublasMatmulOp::infer_shape() {
} }
void CublasMatmulOp::jit_prepare(JK& jk) { void CublasMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype(); jk << "«T:" << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -19,7 +19,7 @@ CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) {
} }
void CublasTestOp::jit_prepare(JK& jk) { void CublasTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]"); jk << "«T:float32";
} }
#else // JIT #else // JIT

View File

@ -55,10 +55,9 @@ void CudnnConv3dBackwardWOp::infer_shape() {
} }
void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) { void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << dy->dtype(); jk << "«Ty:" << dy->dtype();
jk << _CS("][Tw:") << dw->dtype(); jk << "«Tw:" << dw->dtype();
jk << ']';
} }
static auto make_conv3d = get_op_info("cudnn_conv3d") static auto make_conv3d = get_op_info("cudnn_conv3d")

View File

@ -52,10 +52,9 @@ void CudnnConv3dBackwardXOp::infer_shape() {
} }
void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) { void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << dx->dtype(); jk << "«Tx:" << dx->dtype();
jk << _CS("][Ty:") << dy->dtype(); jk << "«Ty:" << dy->dtype();
jk << _CS("][Tw:") << w->dtype(); jk << "«Tw:" << w->dtype();
jk << ']';
} }

View File

@ -53,10 +53,9 @@ void CudnnConv3dOp::infer_shape() {
} }
void CudnnConv3dOp::jit_prepare(JK& jk) { void CudnnConv3dOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][Tw:") << w->dtype(); jk << "«Tw:" << w->dtype();
jk << ']';
} }
static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x") static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")

View File

@ -71,13 +71,12 @@ void CudnnConvBackwardWOp::infer_shape() {
} }
void CudnnConvBackwardWOp::jit_prepare(JK& jk) { void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << dy->dtype(); jk << "«Ty:" << dy->dtype();
jk << _CS("][Tw:") << dw->dtype(); jk << "«Tw:" << dw->dtype();
jk << _CS("][XFORMAT:") << xformat; jk << "«XFORMAT:" << xformat;
jk << _CS("][WFORMAT:") << wformat; jk << "«WFORMAT:" << wformat;
jk << _CS("][YFORMAT:") << yformat; jk << "«YFORMAT:" << yformat;
jk << ']';
} }
static auto make_conv = get_op_info("cudnn_conv") static auto make_conv = get_op_info("cudnn_conv")

View File

@ -70,13 +70,12 @@ void CudnnConvBackwardXOp::infer_shape() {
} }
void CudnnConvBackwardXOp::jit_prepare(JK& jk) { void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << dx->dtype(); jk << "«Tx:" << dx->dtype();
jk << _CS("][Ty:") << dy->dtype(); jk << "«Ty:" << dy->dtype();
jk << _CS("][Tw:") << w->dtype(); jk << "«Tw:" << w->dtype();
jk << _CS("][XFORMAT:") << xformat; jk << "«XFORMAT:" << xformat;
jk << _CS("][WFORMAT:") << wformat; jk << "«WFORMAT:" << wformat;
jk << _CS("][YFORMAT:") << yformat; jk << "«YFORMAT:" << yformat;
jk << ']';
} }
static auto make_conv = get_op_info("cudnn_conv") static auto make_conv = get_op_info("cudnn_conv")

View File

@ -72,13 +72,12 @@ void CudnnConvOp::infer_shape() {
} }
void CudnnConvOp::jit_prepare(JK& jk) { void CudnnConvOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][Tw:") << w->dtype(); jk << "«Tw:" << w->dtype();
jk << _CS("][XFORMAT:") << xformat; jk << "«XFORMAT:" << xformat;
jk << _CS("][WFORMAT:") << wformat; jk << "«WFORMAT:" << wformat;
jk << _CS("][YFORMAT:") << yformat; jk << "«YFORMAT:" << yformat;
jk << ']';
} }
static auto make_backwardx = get_op_info("cudnn_conv_backward_x") static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>(); .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();

View File

@ -79,10 +79,9 @@ void CudnnRnnBackwardXOp::infer_shape() {
} }
void CudnnRnnBackwardXOp::jit_prepare(JK& jk) { void CudnnRnnBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << hx->dtype(); jk << "«Tx:" << hx->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][Tw:") << w->dtype(); jk << "«Tw:" << w->dtype();
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -98,10 +98,9 @@ void CudnnRnnOp::infer_shape() {
} }
void CudnnRnnOp::jit_prepare(JK& jk) { void CudnnRnnOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][Tw:") << w->dtype(); jk << "«Tw:" << w->dtype();
jk << ']';
} }
static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x") static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x")

View File

@ -20,7 +20,7 @@ CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) {
} }
void CudnnTestOp::jit_prepare(JK& jk) { void CudnnTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]"); jk << "«T:float32";
} }
#else // JIT #else // JIT

View File

@ -44,9 +44,9 @@ void CufftFftOp::jit_prepare(JK& jk) {
printf("not supported fft dtype: %s\n", y->dtype().to_cstring()); printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
ASSERT(false); ASSERT(false);
} }
jk << _CS("[T:") << y->dtype(); jk << "«T:" << y->dtype();
jk << _CS("][I:")<<inverse<<"]"; jk << "«I:"<<inverse<<"]";
jk << _CS("[TS:\"")<<y->dtype()<<"\"]"; jk << _CS("«TS:\"")<<y->dtype()<<"\"]";
} }
#else // JIT #else // JIT

View File

@ -25,8 +25,8 @@ CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype, NanoString ty
} }
void CurandRandomOp::jit_prepare(JK& jk) { void CurandRandomOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << output->dtype(); jk << "«T:" << output->dtype();
jk << _CS("][R:") << type << ']'; jk << "«R:" << type;
} }
#else // JIT #else // JIT

View File

@ -21,7 +21,7 @@ CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) {
} }
void CuttTestOp::jit_prepare(JK& jk) { void CuttTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]"); jk << "«T:float32";
} }
#else // JIT #else // JIT

View File

@ -59,7 +59,7 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
void CuttTransposeOp::jit_prepare(JK& jk) { void CuttTransposeOp::jit_prepare(JK& jk) {
// do nothing // do nothing
jk << _CS("[T:1]"); jk << "«T:1";
} }
unordered_map<string, unsigned int> cutt_plan_cache; unordered_map<string, unsigned int> cutt_plan_cache;

View File

@ -37,7 +37,7 @@ VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
} }
void NcclAllReduceOp::jit_prepare(JK& jk) { void NcclAllReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() << ']'; jk << "«Tx:" << x->dtype();
} }
#else // JIT #else // JIT

View File

@ -35,7 +35,7 @@ VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
} }
void NcclBroadcastOp::jit_prepare(JK& jk) { void NcclBroadcastOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() << ']'; jk << "«Tx:" << x->dtype();
} }
#else // JIT #else // JIT

View File

@ -35,7 +35,7 @@ VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
} }
void NcclReduceOp::jit_prepare(JK& jk) { void NcclReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() << ']'; jk << "«Tx:" << x->dtype();
} }
#else // JIT #else // JIT

View File

@ -20,7 +20,7 @@ NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) {
} }
void NcclTestOp::jit_prepare(JK& jk) { void NcclTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]"); jk << "«T:float32";
} }
#else // JIT #else // JIT

View File

@ -79,16 +79,15 @@ static const char* short_type(Var* x) {
} }
void MklConvBackwardWOp::jit_prepare(JK& jk) { void MklConvBackwardWOp::jit_prepare(JK& jk) {
jk << _CS("[Txd:") << x->dtype(); jk << "«Txd:" << x->dtype();
jk << _CS("][Tyd:") << dy->dtype(); jk << "«Tyd:" << dy->dtype();
jk << _CS("][Twd:") << dw->dtype(); jk << "«Twd:" << dw->dtype();
jk << _CS("][Tx:") << short_type(x); jk << "«Tx:" << short_type(x);
jk << _CS("][Tw:") << short_type(dw); jk << "«Tw:" << short_type(dw);
jk << _CS("][Ty:") << short_type(dy); jk << "«Ty:" << short_type(dy);
jk << _CS("][XFORMAT:") << xformat; jk << "«XFORMAT:" << xformat;
jk << _CS("][WFORMAT:") << wformat; jk << "«WFORMAT:" << wformat;
jk << _CS("][YFORMAT:") << yformat; jk << "«YFORMAT:" << yformat;
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -77,16 +77,15 @@ static const char* short_type(Var* x) {
} }
void MklConvBackwardXOp::jit_prepare(JK& jk) { void MklConvBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("[Tyd:") << dy->dtype(); jk << "«Tyd:" << dy->dtype();
jk << _CS("][Twd:") << w->dtype(); jk << "«Twd:" << w->dtype();
jk << _CS("][Txd:") << dx->dtype(); jk << "«Txd:" << dx->dtype();
jk << _CS("][Tx:") << short_type(dx); jk << "«Tx:" << short_type(dx);
jk << _CS("][Tw:") << short_type(w); jk << "«Tw:" << short_type(w);
jk << _CS("][Ty:") << short_type(dy); jk << "«Ty:" << short_type(dy);
jk << _CS("][XFORMAT:") << xformat; jk << "«XFORMAT:" << xformat;
jk << _CS("][WFORMAT:") << wformat; jk << "«WFORMAT:" << wformat;
jk << _CS("][YFORMAT:") << yformat; jk << "«YFORMAT:" << yformat;
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -81,16 +81,15 @@ static const char* short_type(Var* x) {
} }
void MklConvOp::jit_prepare(JK& jk) { void MklConvOp::jit_prepare(JK& jk) {
jk << _CS("[Txd:") << x->dtype(); jk << "«Txd:" << x->dtype();
jk << _CS("][Tyd:") << y->dtype(); jk << "«Tyd:" << y->dtype();
jk << _CS("][Twd:") << w->dtype(); jk << "«Twd:" << w->dtype();
jk << _CS("][Tx:") << short_type(x); jk << "«Tx:" << short_type(x);
jk << _CS("][Tw:") << short_type(w); jk << "«Tw:" << short_type(w);
jk << _CS("][Ty:") << short_type(y); jk << "«Ty:" << short_type(y);
jk << _CS("][XFORMAT:") << xformat; jk << "«XFORMAT:" << xformat;
jk << _CS("][WFORMAT:") << wformat; jk << "«WFORMAT:" << wformat;
jk << _CS("][YFORMAT:") << yformat; jk << "«YFORMAT:" << yformat;
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -44,9 +44,9 @@ void MklMatmulOp::infer_shape() {
} }
void MklMatmulOp::jit_prepare(JK& jk) { void MklMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype(); jk << "«T:" << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N') << ']'; jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
} }
#else // JIT #else // JIT

View File

@ -19,7 +19,7 @@ MklTestOp::MklTestOp() {
} }
void MklTestOp::jit_prepare(JK& jk) { void MklTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]"); jk << "«T:float32";
} }
#else // JIT #else // JIT

View File

@ -62,8 +62,8 @@ VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
} }
void MpiAllReduceOp::jit_prepare(JK& jk) { void MpiAllReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][OP:") << op << ']'; jk << "«OP:" << op;
} }
#else // JIT #else // JIT

View File

@ -49,7 +49,7 @@ VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
} }
void MpiBroadcastOp::jit_prepare(JK& jk) { void MpiBroadcastOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() << ']'; jk << "«Tx:" << x->dtype();
} }
#else // JIT #else // JIT

View File

@ -62,8 +62,8 @@ VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
} }
void MpiReduceOp::jit_prepare(JK& jk) { void MpiReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][OP:") << op << ']'; jk << "«OP:" << op;
} }
#else // JIT #else // JIT

View File

@ -17,7 +17,7 @@ MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
} }
void MpiTestOp::jit_prepare(JK& jk) { void MpiTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]"); jk << "«T:float32";
} }
#else // JIT #else // JIT

View File

@ -585,6 +585,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
if (!v->allocator->is_cuda()) if (!v->allocator->is_cuda())
migrate_to_gpu(v, allocator); migrate_to_gpu(v, allocator);
} }
for (Var* v : op->outputs()) {
if (!v->allocator->is_cuda())
migrate_to_gpu(v, allocator);
}
} }
#endif #endif
#ifdef NODE_MEMCHECK #ifdef NODE_MEMCHECK

View File

@ -159,55 +159,51 @@ void FusedOp::do_jit_prepare(JK& jk) {
jk.clear(); jk.clear();
for (uint i=0; i<ops.size(); i++) { for (uint i=0; i<ops.size(); i++) {
Op* op = ops[i]; Op* op = ops[i];
jk << "[opkey" << i << JK::val; jk << "«opkey" << i << JK::val;
jk << op->name(); jk << op->name();
op->jit_prepare(jk); op->jit_prepare(jk);
jk << JK::end;
} }
jk << _CS("[JIT:1]"); jk << "«JIT:1";
if (!use_cuda) { if (!use_cuda) {
// only cpu // only cpu
jk << _CS("[JIT_cpu:1]"); jk << "«JIT_cpu:1";
this->flags.set(NodeFlags::_cuda, 0); this->flags.set(NodeFlags::_cuda, 0);
this->flags.set(NodeFlags::_cpu, 1); this->flags.set(NodeFlags::_cpu, 1);
} else { } else {
jk << _CS("[JIT_cuda:1]"); jk << "«JIT_cuda:1";
this->flags.set(NodeFlags::_cpu, 0); this->flags.set(NodeFlags::_cpu, 0);
this->flags.set(NodeFlags::_cuda, 1); this->flags.set(NodeFlags::_cuda, 1);
} }
jk << _CS("[graph:"); jk << "«graph:";
for (auto& t : edges) { for (auto& t : edges) {
uint i,j,k,l; uint i,j,k,l;
std::tie(i,j,k,l) = t; std::tie(i,j,k,l) = t;
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ','; jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
} }
jk << _CS("][var_info:") << JK::val; jk << "«var_info:" << JK::val;
bool use_int64_t = false; bool use_int64_t = false;
for (auto& vi : vars) { for (auto& vi : vars) {
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size()); jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
if (vi.type != 1 && vi.var->num >= std::numeric_limits<int32_t>::max()) if (vi.type != 1 && vi.var->num >= std::numeric_limits<int32_t>::max())
use_int64_t = true; use_int64_t = true;
} }
jk << JK::end;
if (use_int64_t) if (use_int64_t)
jk << _CS("[index_t:int64]"); jk << "«index_t:int64";
else else
jk << _CS("[index_t:int32]"); jk << "«index_t:int32";
if (loop_options->size()) { if (loop_options->size()) {
if (get_loop_option("compile_shapes")) { if (get_loop_option("compile_shapes")) {
jk << _CS("[shapes:"); jk << "«shapes:";
for (auto& vi : vars) { for (auto& vi : vars) {
jk << '['; jk << '[';
for (auto a : vi.var->shape) for (auto a : vi.var->shape)
jk << a << ','; jk << a << ',';
jk << _CS("],"); jk << "],";
} }
jk << JK::end;
} }
jk << _CS("[choices:"); jk << "«choices:";
for (auto& kv : *loop_options) for (auto& kv : *loop_options)
jk << kv.first << ':' << kv.second << ','; jk << kv.first << ':' << kv.second << ',';
jk << JK::end;
} }
jk.finilize(); jk.finilize();
} }

View File

@ -194,7 +194,7 @@ void run_cmd(string cmd, string cwd="") {
static string get_symbol_name(const string& jit_key) { static string get_symbol_name(const string& jit_key) {
int i=0; int i=0;
while (i<jit_key.size() && jit_key[i]!='[') i++; while (i<jit_key.size() && jit_key[i]>=0) i++;
string op_name = i ? jit_key.substr(0, i) : "fused"; string op_name = i ? jit_key.substr(0, i) : "fused";
op_name = Op::file_name_to_class_name(op_name); op_name = Op::file_name_to_class_name(op_name);
// _ZN7jittorXyyyyyy7jit_runEv // _ZN7jittorXyyyyyy7jit_runEv

View File

@ -80,42 +80,26 @@ static void convert_itof(string& s) {
vector<pair<string,string>> parse_jit_keys(const string& s) { vector<pair<string,string>> parse_jit_keys(const string& s) {
vector<pair<string,string>> jit_keys; vector<pair<string,string>> jit_keys;
int presum = 0; auto sp = split(s, JitKey::key);
char state=0; for (auto& ss : sp) {
string key, val; if (!ss.size()) continue;
for (char c : s) { string key, val;
if (c==JK::key) { char state=0;
presum++; for (auto c : ss) {
if (presum==1) { if (state == 0 &&
(c==JK::val || c==JK::hex_val)) {
state = c; state = c;
continue; continue;
} }
} else if (state == 0) key += c;
if (c==JK::val || c==JK::hex_val) { else val += c;
if (presum==1 && state==JK::key) {
state = c;
continue;
}
} else
if (c==JK::end) {
presum--;
if (presum==0) {
if (state == JK::hex_val)
hex_to_dec(val);
if (startswith(val, "itof"))
convert_itof(val);
jit_keys.emplace_back(move(key), move(val));
continue;
}
}
if (presum) {
if (state==JK::key)
key += c;
if (state==JK::val || state==JK::hex_val)
val += c;
} }
if (state == JK::hex_val)
hex_to_dec(val);
if (startswith(val, "itof"))
convert_itof(val);
jit_keys.emplace_back(move(key), move(val));
} }
ASSERT(presum==0) << s;
return jit_keys; return jit_keys;
} }

View File

@ -5,6 +5,7 @@
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
#pragma once #pragma once
#include <cstring>
#include "common.h" #include "common.h"
#include "misc/nano_string.h" #include "misc/nano_string.h"
#include "misc/nano_vector.h" #include "misc/nano_vector.h"
@ -13,11 +14,10 @@ namespace jittor {
struct JitKey { struct JitKey {
static constexpr size_t buffer_size = 2*1024*1024; static constexpr size_t buffer_size = 2*1024*1024;
static constexpr char static constexpr const char
key = '[', *key = "«",
val = ':', val = ':',
hex_val = '=', hex_val = '=';
end = ']';
int64 size=0; int64 size=0;
uint64 flags=0; uint64 flags=0;
char buffer[buffer_size]; char buffer[buffer_size];
@ -27,7 +27,7 @@ struct JitKey {
inline void clear() {size = flags = 0;} inline void clear() {size = flags = 0;}
inline void finilize() { buffer[size] = 0; } inline void finilize() { buffer[size] = 0; }
inline bool empty() { return buffer[size-1] != end; } inline bool empty() { return !size; }
inline const char* to_cstring() { inline const char* to_cstring() {
return &buffer[0]; return &buffer[0];
} }
@ -81,11 +81,38 @@ struct __jk_int256 {
typedef JitKey JK; typedef JitKey JK;
EXTERN_LIB JK& get_jk(); EXTERN_LIB JK& get_jk();
inline void jk_put_str_with_len(JK& jk, const char* a, int n) {
char* xx = &jk.buffer[jk.size];
int i=0;
while (i+32<=n) {
((__jk_int256*)(xx+i))[0] = ((const __jk_int256*)(a+i))[0];
i+=32;
}
while (i+16<=n) {
((__jk_int128*)(xx+i))[0] = ((const __jk_int128*)(a+i))[0];
i+=16;
}
while (i+8<=n) {
((long long*)(xx+i))[0] = ((const long long*)(a+i))[0];
i+=8;
}
while (i+4<=n) {
((int*)(xx+i))[0] = ((const int*)(a+i))[0];
i+=4;
}
while (i+2<=n) {
((int16_t*)(xx+i))[0] = ((const int16_t*)(a+i))[0];
i+=2;
}
while (i+1<=n) {
((char*)(xx+i))[0] = ((const char*)(a+i))[0];
i+=1;
}
jk.size += n;
}
inline JK& operator<<(JK& jk, const char* s) { inline JK& operator<<(JK& jk, const char* s) {
int i; jk_put_str_with_len(jk, s, strlen(s));
for (i=0; s[i]; i++)
jk.buffer[jk.size+i] = s[i];
jk.size += i;
return jk; return jk;
} }
@ -199,129 +226,133 @@ vector<pair<string,string>> parse_jit_keys(const string& s);
template <class Ta, class Tb> template <class Ta, class Tb>
void add_jit_define(JK& jk, const Ta& key, const Tb& val) { void add_jit_define(JK& jk, const Ta& key, const Tb& val) {
jk << JK::key << key << JK::val << val << JK::end; jk << JK::key << key << JK::val << val;
} }
template <class Ta, class Tb, class Tc> template <class Ta, class Tb, class Tc>
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) { void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) {
jk << JK::key << key << i << JK::val << val << JK::end; jk << JK::key << key << i << JK::val << val;
} }
template <class Ta> template <class Ta>
void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) { void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) {
jk << JK::key << key << JK::hex_val << val << JK::end; jk << JK::key << key << JK::hex_val << val;
} }
template <class Ta, class Tb> template <class Ta, class Tb>
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) { void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) {
jk << JK::key << key << i << JK::hex_val << val << JK::end; jk << JK::key << key << i << JK::hex_val << val;
} }
template <class Ta> template <class Ta>
void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) { void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) {
jk << JK::key << key << JK::hex_val << val << JK::end; jk << JK::key << key << JK::hex_val << val;
} }
template <class Ta, class Tb> template <class Ta, class Tb>
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) { void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) {
jk << JK::key << key << i << JK::hex_val << val << JK::end; jk << JK::key << key << i << JK::hex_val << val;
} }
template <class Ta> template <class Ta>
void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) { void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) {
jk << JK::key << key << JK::hex_val << val << JK::end; jk << JK::key << key << JK::hex_val << val;
} }
template <class Ta, class Tb> template <class Ta, class Tb>
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) { void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) {
jk << JK::key << key << i << JK::hex_val << val << JK::end; jk << JK::key << key << i << JK::hex_val << val;
} }
#define _CS(x) x
// // begin of const string
// #define MAX_CONST_CHAR 32
// begin of const string // #define _CS_MIN(a,b) (a)<(b)?(a):(b)
#define MAX_CONST_CHAR 32
#define _CS_MIN(a,b) (a)<(b)?(a):(b) // #define _CS_T(s)\
// getChr(s,0),\
// getChr(s,1),\
// getChr(s,2),\
// getChr(s,3),\
// getChr(s,4),\
// getChr(s,5),\
// getChr(s,6),\
// getChr(s,7),\
// getChr(s,8),\
// getChr(s,9),\
// getChr(s,10),\
// getChr(s,11),\
// getChr(s,12),\
// getChr(s,13),\
// getChr(s,14),\
// getChr(s,15),\
// getChr(s,16),\
// getChr(s,17),\
// getChr(s,18),\
// getChr(s,19),\
// getChr(s,20),\
// getChr(s,21),\
// getChr(s,22),\
// getChr(s,23),\
// getChr(s,24),\
// getChr(s,25),\
// getChr(s,26),\
// getChr(s,27),\
// getChr(s,28),\
// getChr(s,29),\
// getChr(s,30),\
// getChr(s,31),\
// getChr(s,32),\
// getChr(s,33),\
// getChr(s,34),\
// getChr(s,35)
#define _CS_T(s)\ // #define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
getChr(s,0),\
getChr(s,1),\
getChr(s,2),\
getChr(s,3),\
getChr(s,4),\
getChr(s,5),\
getChr(s,6),\
getChr(s,7),\
getChr(s,8),\
getChr(s,9),\
getChr(s,10),\
getChr(s,11),\
getChr(s,12),\
getChr(s,13),\
getChr(s,14),\
getChr(s,15),\
getChr(s,16),\
getChr(s,17),\
getChr(s,18),\
getChr(s,19),\
getChr(s,20),\
getChr(s,21),\
getChr(s,22),\
getChr(s,23),\
getChr(s,24),\
getChr(s,25),\
getChr(s,26),\
getChr(s,27),\
getChr(s,28),\
getChr(s,29),\
getChr(s,30),\
getChr(s,31),\
getChr(s,32),\
getChr(s,33),\
getChr(s,34),\
getChr(s,35)
#define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0) // #ifdef _MSC_VER
// #define _CS(str) str
// #else
// #define _CS(str) _CS_G<_CS_T(str)>()
// #endif
#ifdef _MSC_VER // template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
#define _CS(str) str // };
#else
#define _CS(str) _CS_G<_CS_T(str)>()
#endif
template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G { // template<> struct _CS_G<0,0,0,0> {};
};
template<> struct _CS_G<0,0,0,0> {}; // template <char c1, char c2, char c3, char c4, char... Chars_>
// inline JK& operator<<(JK& jk, const _CS_G<c1,c2,c3,c4,Chars_...>& _) {
// ((uint32*)(jk.buffer+jk.size))[0] =
// (uint32((uint8)(c4))<<24)+
// (uint32((uint8)(c3))<<16)+
// (uint32((uint8)(c2))<<8)+
// uint32((uint8)(c1));
// if (c4) {
// jk.size += 4;
// jk << _CS_G<Chars_...>();
// } else
// if (c3) {
// jk.size += 3;
// } else
// if (c2) {
// jk.size += 2;
// } else
// if (c1) {
// jk.size += 1;
// }
// return jk;
// }
template <char c1, char c2, char c3, char c4, char... Chars_> // template <>
inline JK& operator<<(JK& jk, const _CS_G<c1,c2,c3,c4,Chars_...>& _) { // inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) {
((int*)(jk.buffer+jk.size))[0] = c4*(1<<24)+c3*(1<<16)+c2*(1<<8)+c1; // return jk;
if (c4) { // }
jk.size += 4;
jk << _CS_G<Chars_...>();
} else
if (c3) {
jk.size += 3;
} else
if (c2) {
jk.size += 2;
} else
if (c1) {
jk.size += 1;
}
return jk;
}
template <>
inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) {
return jk;
}
inline JK& operator<<(JK& jk, float64 f) { inline JK& operator<<(JK& jk, float64 f) {
return jk << _CS("itof(0x") << JK::hex(ftoi(f)) << ')'; return jk << "itof(0x" << JK::hex(ftoi(f)) << ')';
} }
} // jittor } // jittor

View File

@ -128,14 +128,16 @@ string Op::get_hash_name() {
void Op::do_jit_prepare(JK& jk) { void Op::do_jit_prepare(JK& jk) {
memcheck_all_exist(); memcheck_all_exist();
jk << name(); jk << name();
auto pre_size = jk.size;
jit_prepare(jk); jit_prepare(jk);
if (jk.empty()) { if (jk.size == pre_size) {
// not a jit op // not a jit op
bool has_cuda = flags.get(NodeFlags::_cuda); bool has_cuda = flags.get(NodeFlags::_cuda);
bool has_cpu = flags.get(NodeFlags::_cpu); bool has_cpu = flags.get(NodeFlags::_cpu);
CHECK(has_cuda || has_cpu); CHECK(has_cuda || has_cpu);
if (has_cuda && has_cpu && !use_cuda) if (has_cuda && has_cpu && !use_cuda)
flags.set(NodeFlags::_cuda, 0); flags.set(NodeFlags::_cuda, 0);
jk.clear();
} else { } else {
bool use_int64_t = false; bool use_int64_t = false;
// TODO: fused op do not have inputs, // TODO: fused op do not have inputs,
@ -149,9 +151,9 @@ void Op::do_jit_prepare(JK& jk) {
if (var->num >= std::numeric_limits<int32_t>::max()) if (var->num >= std::numeric_limits<int32_t>::max())
use_int64_t = true; use_int64_t = true;
} }
jk << _CS("[JIT:1]"); jk << "«JIT:1";
if (use_cuda_op && flags.get(NodeFlags::_cuda)) { if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
jk << _CS("[JIT_cuda:1]"); jk << "«JIT_cuda:1";
flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cpu, 0);
// TODO: 64bit index in CUDA // TODO: 64bit index in CUDA
// use_int64_t = false; // use_int64_t = false;
@ -164,14 +166,14 @@ void Op::do_jit_prepare(JK& jk) {
} }
ASSERT(flags.get(NodeFlags::_cpu)) ASSERT(flags.get(NodeFlags::_cpu))
<< "Op" << name() << "doesn't have cpu version"; << "Op" << name() << "doesn't have cpu version";
jk << _CS("[JIT_cpu:1]"); jk << "«JIT_cpu:1";
flags.set(NodeFlags::_cuda, 0); flags.set(NodeFlags::_cuda, 0);
} }
if (try_use_32bit_index) use_int64_t = false; if (try_use_32bit_index) use_int64_t = false;
if (use_int64_t) if (use_int64_t)
jk << _CS("[index_t:int64]"); jk << "«index_t:int64";
else else
jk << _CS("[index_t:int32]"); jk << "«index_t:int32";
} }
jk.finilize(); jk.finilize();
} }

View File

@ -158,14 +158,13 @@ void ArgReduceOp::infer_shape() {
} }
void ArgReduceOp::jit_prepare(JK& jk) { void ArgReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][XDIM=") << JK::hex1(x->shape.size()); jk << "«XDIM=" << JK::hex1(x->shape.size());
jk << _CS("][YDIM=") << JK::hex1(y->shape.size()); jk << "«YDIM=" << JK::hex1(y->shape.size());
jk << _CS("][KEEPDIMS:") << (keepdims ? '1' : '0'); jk << "«KEEPDIMS:" << (keepdims ? '1' : '0');
jk << _CS("][DIM=") << JK::hex1(dim); jk << "«DIM=" << JK::hex1(dim);
jk << _CS("][CMP:") << (op==ns_minimum ? "<" : ">"); jk << "«CMP:" << (op==ns_minimum ? "<" : ">");
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -130,12 +130,11 @@ void ArgsortOp::infer_shape() {
} }
// ArgsortOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each changed line shows the old "[NAME:value]" statement followed by the
// new "«NAME:value" statement.
// Fields written:
//   Tx, Ty - dtypes of x and y
//   XDIM   - x->shape.size(), via JK::hex1
//   DIM    - dim, via JK::hex1
//   CMP    - '>' when descending is set, otherwise '<'
void ArgsortOp::jit_prepare(JK& jk) { void ArgsortOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][XDIM=") << JK::hex1(x->shape.size()); jk << "«XDIM=" << JK::hex1(x->shape.size());
jk << _CS("][DIM=") << JK::hex1(dim); jk << "«DIM=" << JK::hex1(dim);
jk << _CS("][CMP:") << (descending ? '>' : '<'); jk << "«CMP:" << (descending ? '>' : '<');
// Closing bracket of the old format; it has no counterpart in the new form.
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -78,7 +78,7 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
void ArrayOp::jit_prepare(JK& jk) { void ArrayOp::jit_prepare(JK& jk) {
if (output->flags.get(NodeFlags::_force_fuse)) { if (output->flags.get(NodeFlags::_force_fuse)) {
jk << _CS("[T:") << output->dtype() << ']'; jk << "«T:" << output->dtype();
// fill or find cbuffer for const var pass // fill or find cbuffer for const var pass
if (output->dtype().dsize() == 4) { if (output->dtype().dsize() == 4) {
@ -86,7 +86,7 @@ void ArrayOp::jit_prepare(JK& jk) {
auto y = std::abs(ptr<float32>()[0]); auto y = std::abs(ptr<float32>()[0]);
auto z = ptr<uint32>()[0]; auto z = ptr<uint32>()[0];
if ((x<=2) || (y==1.0f || y==2.0f)) if ((x<=2) || (y==1.0f || y==2.0f))
jk << _CS("[o:") << z << ']'; jk << "«o:" << z;
} }
// end of fill cbuffer // end of fill cbuffer
} }

View File

@ -540,10 +540,10 @@ void BinaryOp::infer_shape() {
} }
// BinaryOp::jit_prepare: appends this op's fields to the JIT key `jk` in a
// single chained stream expression.
// Each line shows the old "[NAME:value]" fragment followed by the new
// "«NAME:value" fragment; the new chain ends without a closing ']'.
// Fields written: Tx, Ty, Tz (dtypes of x, y and z) and OP (the value of ns).
void BinaryOp::jit_prepare(JK& jk) { void BinaryOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() jk << "«Tx:" << x->dtype()
<< _CS("][Ty:") << y->dtype() << "«Ty:" << y->dtype()
<< _CS("][Tz:") << z->dtype() << "«Tz:" << z->dtype()
<< _CS("][OP:") << ns << ']'; << "«OP:" << ns;
} }
#else // JIT #else // JIT

View File

@ -167,9 +167,9 @@ void BroadcastToOp::infer_shape() {
} }
// BroadcastToOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each line shows the old "[NAME:value]" fragment followed by the new
// "«NAME:value" fragment.
// Fields written:
//   Tx    - dtype of x
//   DIM   - z->shape.size(), via JK::hex1
//   BCAST - bcast_mask, via JK::hex
void BroadcastToOp::jit_prepare(JK& jk) { void BroadcastToOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() jk << "«Tx:" << x->dtype()
<< _CS("][DIM=") << JK::hex1(z->shape.size()) << "«DIM=" << JK::hex1(z->shape.size())
<< _CS("][BCAST=") << JK::hex(bcast_mask) << ']'; << "«BCAST=" << JK::hex(bcast_mask);
} }
#else // JIT #else // JIT

View File

@ -25,10 +25,10 @@ void CandidateOp::infer_shape() {
} }
// CandidateOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each line shows the old "[NAME:value]" statement followed by the new
// "«NAME:value" statement.
// Fields written:
//   Tx, Ty - dtypes of x and y
//   FUNC   - fail_cond, streamed as-is
//   XDIM   - x->shape.size(), via JK::hex1
void CandidateOp::jit_prepare(JK& jk) { void CandidateOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][FUNC:") << fail_cond; jk << "«FUNC:" << fail_cond;
jk << _CS("][XDIM=") << JK::hex1(x->shape.size()) << ']'; jk << "«XDIM=" << JK::hex1(x->shape.size());
} }
#else // JIT #else // JIT

View File

@ -123,18 +123,18 @@ void CodeOp::jit_prepare(JK& jk) {
// forward: in0 in1 in2 -> out0 out1 // forward: in0 in1 in2 -> out0 out1
// backward: in0 in1 in2 in3(pout0) in4(pout1) // backward: in0 in1 in2 in3(pout0) in4(pout1)
jk << _CS("[IN_SIZE=") << JK::hex(_inputs.size()); jk << "«IN_SIZE=" << JK::hex(_inputs.size());
for (uint i=0; i<_inputs.size(); i++) { for (uint i=0; i<_inputs.size(); i++) {
jk << _CS("][in") << JK::hex(i) << _CS("_dim=") jk << "«in" << JK::hex(i) << "_dim="
<< JK::hex1(_inputs[i]->shape.size()); << JK::hex1(_inputs[i]->shape.size());
jk << _CS("][in") << JK::hex(i) << _CS("_type:") jk << "«in" << JK::hex(i) << "_type:"
<< _inputs[i]->dtype(); << _inputs[i]->dtype();
} }
jk << _CS("][OUT_SIZE=") << JK::hex(_outputs.size()); jk << "«OUT_SIZE=" << JK::hex(_outputs.size());
for (uint i=0; i<_outputs.size(); i++) { for (uint i=0; i<_outputs.size(); i++) {
jk << _CS("][out") << JK::hex(i) << _CS("_dim=") jk << "«out" << JK::hex(i) << "_dim="
<< JK::hex1(_outputs[i]->shape.size()); << JK::hex1(_outputs[i]->shape.size());
jk << _CS("][out") << JK::hex(i) << _CS("_type:") jk << "«out" << JK::hex(i) << "_type:"
<< _outputs[i]->dtype(); << _outputs[i]->dtype();
} }
string& header = flags.get(NodeFlags::_cuda) ? string& header = flags.get(NodeFlags::_cuda) ?
@ -142,9 +142,9 @@ void CodeOp::jit_prepare(JK& jk) {
string& src = flags.get(NodeFlags::_cuda) ? string& src = flags.get(NodeFlags::_cuda) ?
cuda_src : cpu_src; cuda_src : cpu_src;
jk << _CS("][HEADER:") << header; jk << "«HEADER:" << header;
CHECK(src.size()); CHECK(src.size());
jk << _CS("\nnamespace jittor {\n"); jk << "\nnamespace jittor {\n";
int i=0; int i=0;
// move cuda kernel function into header // move cuda kernel function into header
for (; i<src.size(); i++) { for (; i<src.size(); i++) {
@ -165,9 +165,8 @@ void CodeOp::jit_prepare(JK& jk) {
} }
} else break; } else break;
} }
jk << _CS("}][CODE:"); jk << "}«CODE:";
for (; i<src.size(); i++) jk << src[i]; for (; i<src.size(); i++) jk << src[i];
jk << ']';
} }
} // jittor } // jittor

View File

@ -74,12 +74,11 @@ VarPtr FuseTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
// FuseTransposeOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each changed line shows the old "[NAME:value]" statement followed by the
// new "«NAME:value" statement.
// Fields written:
//   Tx      - dtype of x
//   DIM     - axes.size(), via JK::hex1
//   BC      - whether type()==OpType::broadcast, via JK::hex1
//   AXES<a> - one entry per position i, where <a> is hex1(ax[i]) and the
//             value is hex1(i)
void FuseTransposeOp::jit_prepare(JK& jk) { void FuseTransposeOp::jit_prepare(JK& jk) {
auto bc = type()==OpType::broadcast; auto bc = type()==OpType::broadcast;
// Use axes directly for a broadcast-type op, otherwise get_reverse(axes).
auto ax = bc ? axes : get_reverse(axes); auto ax = bc ? axes : get_reverse(axes);
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][DIM=") << JK::hex1(axes.size()); jk << "«DIM=" << JK::hex1(axes.size());
jk << _CS("][BC:") << JK::hex1(bc); jk << "«BC:" << JK::hex1(bc);
for (uint i=0; i<ax.size(); i++) for (uint i=0; i<ax.size(); i++)
jk << _CS("][AXES") << JK::hex1(ax[i]) << '=' << JK::hex1(i); jk << "«AXES" << JK::hex1(ax[i]) << '=' << JK::hex1(i);
// Closing bracket of the old format; it has no counterpart in the new form.
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -444,26 +444,26 @@ void GetitemOp::grads(Var** dout, VarPtr* dins) {
void GetitemOp::jit_prepare(JK& jk) { void GetitemOp::jit_prepare(JK& jk) {
auto in = inputs().front(); auto in = inputs().front();
int idim = i_to_vs.size(); int idim = i_to_vs.size();
jk << _CS("[Ti:") << in->dtype(); jk << "«Ti:" << in->dtype();
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size()); jk << "«IDIM=" << JK::hex1(i_to_vs.size());
jk << _CS("][ODIM=") << JK::hex1(o_shape.size()); jk << "«ODIM=" << JK::hex1(o_shape.size());
if (first_oid_of_var>=0) { if (first_oid_of_var>=0) {
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var); jk << "«FOV=" << JK::hex1(first_oid_of_var);
jk << _CS("][VD=") << JK::hex1(var_dim); jk << "«VD=" << JK::hex1(var_dim);
} }
for (int i=0; i<idim; i++) { for (int i=0; i<idim; i++) {
auto iv = i_to_vs[i]; auto iv = i_to_vs[i];
auto io = i_to_o[i]; auto io = i_to_o[i];
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv); jk << "«IV" << JK::hex1(i) << ':' << JK::shex1(iv);
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io); jk << "«IO" << JK::hex1(i) << ':' << JK::shex1(io);
auto& v = vs.slices[iv]; auto& v = vs.slices[iv];
if (iv>=0 && io==-1) { if (iv>=0 && io==-1) {
if (v.is_int()) { if (v.is_int()) {
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1"); jk << "«VS" << JK::hex1(i) << ":-1";
} else } else
if (v.is_str()) { if (v.is_str()) {
jk << _CS("][VS") << JK::hex1(i) << _CS(":-5"); jk << "«VS" << JK::hex1(i) << ":-5";
jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str(); jk << "«VSS" << JK::hex1(i) << ":" << v.get_str();
} else { } else {
ASSERT(v.is_var()); ASSERT(v.is_var());
auto var = v.var; auto var = v.var;
@ -475,13 +475,13 @@ void GetitemOp::jit_prepare(JK& jk) {
if (vshape[j] == o_shape[k]) if (vshape[j] == o_shape[k])
vsmask |= 1<<(j+var_dim-vdim); vsmask |= 1<<(j+var_dim-vdim);
} }
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask); jk << "«VS" << JK::hex1(i) << '=' << JK::hex(vsmask);
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype(); jk << "«VST" << JK::hex1(i) << ':' << var->dtype();
} }
} else } else
if (iv>=0 && io>=0) { if (iv>=0 && io>=0) {
ASSERT(v.is_slice()); ASSERT(v.is_slice());
jk << _CS("][VS") << JK::hex1(i) << ':'; jk << "«VS" << JK::hex1(i) << ':';
if (std::abs(v.slice.step) <= 1) if (std::abs(v.slice.step) <= 1)
jk << JK::shex1(v.slice.step); jk << JK::shex1(v.slice.step);
else else
@ -495,11 +495,10 @@ void GetitemOp::jit_prepare(JK& jk) {
int tdims[6]; int tdims[6];
cuda_loop_schedule(o_shape, masks, tdims); cuda_loop_schedule(o_shape, masks, tdims);
for (int i=0; i<no; i++) { for (int i=0; i<no; i++) {
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]); jk << "«LO" << JK::hex1(i) << '=' << JK::hex(masks[i]);
} }
} }
#endif #endif
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -39,8 +39,8 @@ RandomOp::RandomOp(NanoVector shape, NanoString dtype, NanoString type) {
} }
// RandomOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each line shows the old "[NAME:value]" statement followed by the new
// "«NAME:value" statement.
// Fields written: T (dtype of output) and R (the value of `type`).
void RandomOp::jit_prepare(JK& jk) { void RandomOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << output->dtype(); jk << "«T:" << output->dtype();
jk << _CS("][R:") << type << ']'; jk << "«R:" << type;
} }
#else // JIT #else // JIT

View File

@ -338,12 +338,12 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
} }
// ReduceOp::jit_prepare: appends this op's fields to the JIT key `jk` in a
// single chained stream expression.
// Each line shows the old "[NAME:value]" fragment followed by the new
// "«NAME:value" fragment.
// Fields written:
//   Tx     - dtype of x
//   Ty, Tz - both taken from y->dtype()
//   OP     - the value of ns
//   DIM    - x->shape.size(), via JK::hex1
//   REDUCE - reduce_mask, via JK::hex
// NOTE(review): Tz is written from y->dtype() in both the old and new form,
// so it always equals Ty. Presumably intentional; confirm.
void ReduceOp::jit_prepare(JK& jk) { void ReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() jk << "«Tx:" << x->dtype()
<< _CS("][Ty:") << y->dtype() << "«Ty:" << y->dtype()
<< _CS("][Tz:") << y->dtype() << "«Tz:" << y->dtype()
<< _CS("][OP:") << ns << "«OP:" << ns
<< _CS("][DIM=") << JK::hex1(x->shape.size()) << "«DIM=" << JK::hex1(x->shape.size())
<< _CS("][REDUCE=") << JK::hex(reduce_mask) << ']'; << "«REDUCE=" << JK::hex(reduce_mask);
} }
#else // JIT #else // JIT

View File

@ -91,21 +91,20 @@ void ReindexOp::infer_shape() {
} }
// ReindexOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each changed line shows the old "[NAME:value]" statement followed by the
// new "«NAME:value" statement.
// Fields written, in order:
//   Tx               - dtype of x
//   XDIM, YDIM       - x->shape.size() and y->shape.size(), via JK::hex1
//   OVERFLOW         - overflow_value
//   INDEX<i>         - indexes[i], one entry per element of indexes
//   OSIZE            - overflow_conditions.size(), via JK::hex1
//   OFD<i>           - overflow_conditions[i]
//   ESIZE            - extras.size(), via JK::hex1
//   EDIM<i>, Te<i>   - shape.size() and dtype of extras[i]
void ReindexOp::jit_prepare(JK& jk) { void ReindexOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() jk << "«Tx:" << x->dtype()
<< _CS("][XDIM=") << JK::hex1(x->shape.size()) << "«XDIM=" << JK::hex1(x->shape.size())
<< _CS("][YDIM=") << JK::hex1(y->shape.size()) << "«YDIM=" << JK::hex1(y->shape.size())
<< _CS("][OVERFLOW:") << overflow_value; << "«OVERFLOW:" << overflow_value;
for (uint i=0; i<indexes.size(); i++) for (uint i=0; i<indexes.size(); i++)
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i]; jk << "«INDEX" << JK::hex1(i) << ':' << indexes[i];
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size()); jk << "«OSIZE=" << JK::hex1(overflow_conditions.size());
for (uint i=0; i<overflow_conditions.size(); i++) for (uint i=0; i<overflow_conditions.size(); i++)
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i]; jk << "«OFD" << JK::hex1(i) << ':' << overflow_conditions[i];
jk << _CS("][ESIZE=") << JK::hex1(extras.size()); jk << "«ESIZE=" << JK::hex1(extras.size());
for (uint i=0; i<extras.size(); i++) { for (uint i=0; i<extras.size(); i++) {
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size()); jk << "«EDIM" << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype(); jk << "«Te" << JK::hex1(i) << ':' << extras[i]->dtype();
} }
// Closing bracket of the old format; it has no counterpart in the new form.
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -73,21 +73,20 @@ void ReindexReduceOp::infer_shape() {
} }
// ReindexReduceOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each changed line shows the old "[NAME:value]" statement followed by the
// new "«NAME:value" statement.
// Fields written, in order:
//   Tx               - dtype of x
//   OP               - the value of ns
//   YDIM, XDIM       - y->shape.size() and x->shape.size(), via JK::hex1
//   INDEX<i>         - indexes[i], one entry per element of indexes
//   OSIZE            - overflow_conditions.size(), via JK::hex1
//   OFD<i>           - overflow_conditions[i]
//   ESIZE            - extras.size(), via JK::hex1
//   EDIM<i>, Te<i>   - shape.size() and dtype of extras[i]
void ReindexReduceOp::jit_prepare(JK& jk) { void ReindexReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() jk << "«Tx:" << x->dtype()
<< _CS("][OP:") << ns << "«OP:" << ns
<< _CS("][YDIM=") << JK::hex1(y->shape.size()) << "«YDIM=" << JK::hex1(y->shape.size())
<< _CS("][XDIM=") << JK::hex1(x->shape.size()); << "«XDIM=" << JK::hex1(x->shape.size());
for (uint i=0; i<indexes.size(); i++) for (uint i=0; i<indexes.size(); i++)
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i]; jk << "«INDEX" << JK::hex1(i) << ':' << indexes[i];
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size()); jk << "«OSIZE=" << JK::hex1(overflow_conditions.size());
for (uint i=0; i<overflow_conditions.size(); i++) for (uint i=0; i<overflow_conditions.size(); i++)
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i]; jk << "«OFD" << JK::hex1(i) << ':' << overflow_conditions[i];
jk << _CS("][ESIZE=") << JK::hex1(extras.size()); jk << "«ESIZE=" << JK::hex1(extras.size());
for (uint i=0; i<extras.size(); i++) { for (uint i=0; i<extras.size(); i++) {
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size()); jk << "«EDIM" << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype(); jk << "«Te" << JK::hex1(i) << ':' << extras[i]->dtype();
} }
// Closing bracket of the old format; it has no counterpart in the new form.
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -30,7 +30,7 @@ void SafeClipOp::infer_shape() {
} }
// SafeClipOp::jit_prepare: appends this op's single field, Tx (dtype of x),
// to the JIT key `jk`. The line shows the old "[Tx:...]" statement followed
// by the new "«Tx:..." statement.
// NOTE(review): the new form replaces the old closing ']' with a trailing
// "«", whereas the other jit_prepare functions in this change simply drop
// the closing delimiter. This leaves a dangling key marker after the dtype;
// confirm whether it is intentional or a search-and-replace leftover.
void SafeClipOp::jit_prepare(JK& jk) { void SafeClipOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() <<']'; jk << "«Tx:" << x->dtype() <<"«";
} }
#else // JIT #else // JIT

View File

@ -201,32 +201,32 @@ void SetitemOp::jit_prepare(JK& jk) {
break; break;
} }
auto data = input(1); auto data = input(1);
jk << _CS("[OP:") << op jk << "«OP:" << op
<< _CS("][Td:") << data->dtype() << "«Td:" << data->dtype()
<< _CS("][BMASK=") << JK::hex(bmask); << "«BMASK=" << JK::hex(bmask);
// TODO: merge code // TODO: merge code
auto in = inputs().front(); auto in = inputs().front();
int idim = i_to_vs.size(); int idim = i_to_vs.size();
jk << _CS("][Ti:") << in->dtype(); jk << "«Ti:" << in->dtype();
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size()); jk << "«IDIM=" << JK::hex1(i_to_vs.size());
jk << _CS("][ODIM=") << JK::hex1(o_shape.size()); jk << "«ODIM=" << JK::hex1(o_shape.size());
if (first_oid_of_var>=0) { if (first_oid_of_var>=0) {
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var); jk << "«FOV=" << JK::hex1(first_oid_of_var);
jk << _CS("][VD=") << JK::hex1(var_dim); jk << "«VD=" << JK::hex1(var_dim);
} }
for (int i=0; i<idim; i++) { for (int i=0; i<idim; i++) {
auto iv = i_to_vs[i]; auto iv = i_to_vs[i];
auto io = i_to_o[i]; auto io = i_to_o[i];
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv); jk << "«IV" << JK::hex1(i) << ':' << JK::shex1(iv);
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io); jk << "«IO" << JK::hex1(i) << ':' << JK::shex1(io);
auto& v = vs.slices[iv]; auto& v = vs.slices[iv];
if (iv>=0 && io==-1) { if (iv>=0 && io==-1) {
if (v.is_int()) { if (v.is_int()) {
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1"); jk << "«VS" << JK::hex1(i) << ":-1";
} else } else
if (v.is_str()) { if (v.is_str()) {
jk << _CS("][VS") << JK::hex1(i) << _CS(":-5"); jk << "«VS" << JK::hex1(i) << ":-5";
jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str(); jk << "«VSS" << JK::hex1(i) << ":" << v.get_str();
} else { } else {
ASSERT(v.is_var()); ASSERT(v.is_var());
auto var = v.var; auto var = v.var;
@ -238,13 +238,13 @@ void SetitemOp::jit_prepare(JK& jk) {
if (vshape[j] == o_shape[k]) if (vshape[j] == o_shape[k])
vsmask |= 1<<(j+var_dim-vdim); vsmask |= 1<<(j+var_dim-vdim);
} }
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask); jk << "«VS" << JK::hex1(i) << '=' << JK::hex(vsmask);
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype(); jk << "«VST" << JK::hex1(i) << ':' << var->dtype();
} }
} else } else
if (iv>=0 && io>=0) { if (iv>=0 && io>=0) {
ASSERT(v.is_slice()); ASSERT(v.is_slice());
jk << _CS("][VS") << JK::hex1(i) << ':'; jk << "«VS" << JK::hex1(i) << ':';
if (std::abs(v.slice.step) <= 1) if (std::abs(v.slice.step) <= 1)
jk << JK::shex1(v.slice.step); jk << JK::shex1(v.slice.step);
else else
@ -258,11 +258,10 @@ void SetitemOp::jit_prepare(JK& jk) {
int tdims[6]; int tdims[6];
cuda_loop_schedule(o_shape, masks, tdims); cuda_loop_schedule(o_shape, masks, tdims);
for (int i=0; i<no; i++) { for (int i=0; i<no; i++) {
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]); jk << "«LO" << JK::hex1(i) << '=' << JK::hex(masks[i]);
} }
} }
#endif #endif
jk << ']';
} }
void SetitemOp::compile_optimize(string& src) { void SetitemOp::compile_optimize(string& src) {

View File

@ -64,10 +64,10 @@ void TernaryOp::infer_shape() {
} }
// TernaryOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each line shows the old "[NAME:value]" statement followed by the new
// "«NAME:value" statement.
// Fields written: Tc (dtype of cond), then Tx, Ty, Tz (dtypes of x, y, z).
void TernaryOp::jit_prepare(JK& jk) { void TernaryOp::jit_prepare(JK& jk) {
jk << _CS("[Tc:") << cond->dtype(); jk << "«Tc:" << cond->dtype();
jk << _CS("][Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][Ty:") << y->dtype(); jk << "«Ty:" << y->dtype();
jk << _CS("][Tz:") << z->dtype() << ']'; jk << "«Tz:" << z->dtype();
} }
#else // JIT #else // JIT

View File

@ -79,11 +79,10 @@ VarPtr TransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
} }
// TransposeOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each changed line shows the old "[NAME:value]" statement followed by the
// new "«NAME:value" statement.
// Fields written:
//   Tx      - dtype of x
//   DIM     - axes.size(), via JK::hex1
//   AXES<a> - one entry per position i, where <a> is hex1(axes[i]) and the
//             value is hex1(i)
void TransposeOp::jit_prepare(JK& jk) { void TransposeOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype(); jk << "«Tx:" << x->dtype();
jk << _CS("][DIM=") << JK::hex1(axes.size()); jk << "«DIM=" << JK::hex1(axes.size());
for (uint i=0; i<axes.size(); i++) for (uint i=0; i<axes.size(); i++)
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i); jk << "«AXES" << JK::hex1(axes[i]) << '=' << JK::hex1(i);
// Closing bracket of the old format; it has no counterpart in the new form.
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -681,9 +681,9 @@ void UnaryOp::infer_shape() {
} }
// UnaryOp::jit_prepare: appends this op's fields to the JIT key `jk` in a
// single chained stream expression.
// Each line shows the old "[NAME:value]" fragment followed by the new
// "«NAME:value" fragment.
// Fields written: Tx and Ty (dtypes of x and y) and OP (the value of ns).
void UnaryOp::jit_prepare(JK& jk) { void UnaryOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() jk << "«Tx:" << x->dtype()
<< _CS("][Ty:") << y->dtype() << "«Ty:" << y->dtype()
<< _CS("][OP:") << ns << ']'; << "«OP:" << ns;
} }
#else // JIT #else // JIT

View File

@ -54,10 +54,9 @@ void WhereOp::infer_shape() {
} }
// WhereOp::jit_prepare: appends this op's fields to the JIT key `jk`.
// Each changed line shows the old "[NAME:value]" statement followed by the
// new "«NAME:value" statement.
// Fields written:
//   Ti   - dtype of cond
//   To   - dtype of outs[0]
//   NDIM - cond->shape.size(), via JK::hex1
void WhereOp::jit_prepare(JK& jk) { void WhereOp::jit_prepare(JK& jk) {
jk << _CS("[Ti:") << cond->dtype(); jk << "«Ti:" << cond->dtype();
jk << _CS("][To:") << outs[0]->dtype(); jk << "«To:" << outs[0]->dtype();
jk << _CS("][NDIM=") << JK::hex1(cond->shape.size()); jk << "«NDIM=" << JK::hex1(cond->shape.size());
// Closing bracket of the old format; it has no counterpart in the new form.
jk << ']';
} }
#else // JIT #else // JIT

View File

@ -19,16 +19,16 @@ JIT_TEST(jit_key) {
}); });
std::cerr << "get segfault, ok" << std::endl; std::cerr << "get segfault, ok" << std::endl;
jk << JK::key << "key" << JK::val << "value" << JK::end; jk << JK::key << "key" << JK::val << "value";
jk << JK::key << "key" << JK::val << JK::hex(0x123123) << JK::end; jk << JK::key << "key" << JK::val << JK::hex(0x123123);
jk << JK::key << "key" << JK::val << JK::hex1(0x123123) << JK::end; jk << JK::key << "key" << JK::val << JK::hex1(0x123123);
jk << JK::key << "key" << JK::val << JK::hex2(0x123123) << JK::end; jk << JK::key << "key" << JK::val << JK::hex2(0x123123);
jk << JK::key << "key" << JK::val << JK::Oxhex(0x123123) << JK::end; jk << JK::key << "key" << JK::val << JK::Oxhex(0x123123);
jk << JK::key << "key" << JK::val << JK::Oxhex1(0x123123) << JK::end; jk << JK::key << "key" << JK::val << JK::Oxhex1(0x123123);
jk << JK::key << "key" << JK::val << JK::Oxhex2(0x123123) << JK::end; jk << JK::key << "key" << JK::val << JK::Oxhex2(0x123123);
string key = "[key:value][key:123123][key:3][key:23][key:0x123123][key:0x3][key:0x23]"; string key = "«key:value«key:123123«key:3«key:23«key:0x123123«key:0x3«key:0x23";
ASSERTop(jk.to_string(),==,key); ASSERTop(jk.to_string(),==,key);
auto keys = parse_jit_keys("[a:11][b:22][a[3]:b::[x]][x=11][f=itof(0x0)]"); auto keys = parse_jit_keys("«a:11«b:22«a[3]:b::[x]«x=11«f=itof(0x0)");
vector<pair<string,string>> k2 = vector<pair<string,string>> k2 =
{{"a","11"},{"b","22"},{"a[3]","b::[x]"},{"x","17"},{"f","0"}}; {{"a","11"},{"b","22"},{"a[3]","b::[x]"},{"x","17"},{"f","0"}};
ASSERTop(keys,==,k2); ASSERTop(keys,==,k2);

View File

@ -94,7 +94,7 @@ JIT_TEST(fused_op_relay_matmul) {
auto allocator = get_allocator(); auto allocator = get_allocator();
for (auto& v : fop.vars) for (auto& v : fop.vars)
if (v.type!=1) v.var->alloc(allocator); if (v.type!=1) v.var->alloc(allocator);
auto entry = oc.compile("[OP:_fused_op_relay_matmul]", oc.src); auto entry = oc.compile("«OP:_fused_op_relay_matmul", oc.src);
for (uint i=0; i<a->num; i++) for (uint i=0; i<a->num; i++)
a->ptr<float>()[i] = b->ptr<float>()[i] = 1; a->ptr<float>()[i] = b->ptr<float>()[i] = 1;
entry(&fop); entry(&fop);

View File

@ -46,7 +46,7 @@ def test(shape, op1, op2):
with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs: with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
d__ = d.data d__ = d.data
logs = find_log_with_re(logs, logs = find_log_with_re(logs,
"Jit (fused )?op key (not )?found: \[opkey0:array\[T:float32") "Jit (fused )?op key (not )?found: «opkey0:array«T:float32")
assert(len(logs)==1), logs assert(len(logs)==1), logs
a_ = a.data a_ = a.data

View File

@ -81,7 +81,7 @@ def test_case(box_num, out_size, time_limit):
for i in range(1, len(rep)): for i in range(1, len(rep)):
t += float(rep[i][3]) / 1e9 t += float(rep[i][3]) / 1e9
name = rep[i][0] name = rep[i][0]
if name.startswith('[') and (not '[graph:]' in name): if name.startswith('«') and (not '«graph:«' in name):
fused_op_num += 1 fused_op_num += 1
assert fused_op_num == 1, fused_op_num assert fused_op_num == 1, fused_op_num
assert t <= time_limit, t assert t <= time_limit, t

View File

@ -428,6 +428,15 @@ class TestSetitem(unittest.TestCase):
np.arange(4)[None,::-1]] np.arange(4)[None,::-1]]
np.testing.assert_allclose(nb, b.data) np.testing.assert_allclose(nb, b.data)
# Test added by this change (leading indentation is lost in this rendering).
# Builds a 5-element array and calls jt.sync_all() before any CUDA flag is
# set, then reads element 0 inside a use_cuda=1 scope and checks it is 1.
# Presumably a regression test for slicing an array that was materialised
# without CUDA and is then accessed with CUDA enabled; confirm against the
# original bug report.
def test_cuda_slice_migrate_bug(self):
a = jt.array([1,2,3,4,5])
jt.sync_all()
# Without CUDA the test returns early; the array creation and sync above
# have already run by then.
if not jt.has_cuda: return
with jt.flag_scope(use_cuda=1):
b = a[0]
b.sync(True)
assert b.item() == 1
if __name__ == "__main__": if __name__ == "__main__":