mirror of https://github.com/Jittor/Jittor
polish jit key
This commit is contained in:
parent
91fe1fac85
commit
c5ccdaf330
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.4.7'
|
||||
__version__ = '1.3.4.8'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -59,13 +59,13 @@ void CubArgReduceOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CubArgReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Toffsets:") << offsets->dtype();
|
||||
jk << _CS("][FUNC:");
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Toffsets:" << offsets->dtype();
|
||||
jk << "«FUNC:";
|
||||
if (op==ns_minimum)
|
||||
jk << _CS("ArgMin]");
|
||||
jk << "ArgMin";
|
||||
else
|
||||
jk << _CS("ArgMax]");
|
||||
jk << "ArgMax";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -51,15 +51,15 @@ void CubArgsortOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CubArgsortOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Tindexes:") << indexes->dtype();
|
||||
jk << _CS("][Toffsets:") << offsets->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][FUNC:");
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Tindexes:" << indexes->dtype();
|
||||
jk << "«Toffsets:" << offsets->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«FUNC:";
|
||||
if (descending)
|
||||
jk << _CS("SortPairsDescending]");
|
||||
jk << "SortPairsDescending";
|
||||
else
|
||||
jk << _CS("SortPairs]");
|
||||
jk << "SortPairs";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -38,10 +38,9 @@ void CubCumsumOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CubCumsumOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][reverse:") << reverse;
|
||||
jk << _CS("]");
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«reverse:" << reverse;
|
||||
}
|
||||
|
||||
VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
|
|
|
@ -24,7 +24,7 @@ CubTestOp::CubTestOp(string cmd) : cmd(cmd) {
|
|||
}
|
||||
|
||||
void CubTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
jk << "«T:float32";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -40,10 +40,9 @@ void CubWhereOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CubWhereOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Ti:") << cond->dtype();
|
||||
jk << _CS("][To:") << outs[0]->dtype();
|
||||
jk << _CS("][NDIM=") << JK::hex1(cond->shape.size());
|
||||
jk << ']';
|
||||
jk << "«Ti:" << cond->dtype();
|
||||
jk << "«To:" << outs[0]->dtype();
|
||||
jk << "«NDIM=" << JK::hex1(cond->shape.size());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -54,11 +54,10 @@ void CublasAccMatmulOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CublasAccMatmulOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||
jk << ']';
|
||||
jk << "«T:" << a->dtype();
|
||||
jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
|
||||
jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
|
||||
jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -89,11 +89,10 @@ void CublasBatchedMatmulOp::infer_shape(){
|
|||
}
|
||||
|
||||
void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||
jk << ']';
|
||||
jk << "«T:" << a->dtype();
|
||||
jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
|
||||
jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
|
||||
jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -50,11 +50,10 @@ void CublasMatmulOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CublasMatmulOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||
jk << ']';
|
||||
jk << "«T:" << a->dtype();
|
||||
jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
|
||||
jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
|
||||
jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -19,7 +19,7 @@ CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) {
|
|||
}
|
||||
|
||||
void CublasTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
jk << "«T:float32";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -55,10 +55,9 @@ void CudnnConv3dBackwardWOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << dy->dtype();
|
||||
jk << _CS("][Tw:") << dw->dtype();
|
||||
jk << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << dy->dtype();
|
||||
jk << "«Tw:" << dw->dtype();
|
||||
}
|
||||
|
||||
static auto make_conv3d = get_op_info("cudnn_conv3d")
|
||||
|
|
|
@ -52,10 +52,9 @@ void CudnnConv3dBackwardXOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << dx->dtype();
|
||||
jk << _CS("][Ty:") << dy->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << ']';
|
||||
jk << "«Tx:" << dx->dtype();
|
||||
jk << "«Ty:" << dy->dtype();
|
||||
jk << "«Tw:" << w->dtype();
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -53,10 +53,9 @@ void CudnnConv3dOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CudnnConv3dOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«Tw:" << w->dtype();
|
||||
}
|
||||
|
||||
static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")
|
||||
|
|
|
@ -71,13 +71,12 @@ void CudnnConvBackwardWOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << dy->dtype();
|
||||
jk << _CS("][Tw:") << dw->dtype();
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << dy->dtype();
|
||||
jk << "«Tw:" << dw->dtype();
|
||||
jk << "«XFORMAT:" << xformat;
|
||||
jk << "«WFORMAT:" << wformat;
|
||||
jk << "«YFORMAT:" << yformat;
|
||||
}
|
||||
|
||||
static auto make_conv = get_op_info("cudnn_conv")
|
||||
|
|
|
@ -70,13 +70,12 @@ void CudnnConvBackwardXOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << dx->dtype();
|
||||
jk << _CS("][Ty:") << dy->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
jk << "«Tx:" << dx->dtype();
|
||||
jk << "«Ty:" << dy->dtype();
|
||||
jk << "«Tw:" << w->dtype();
|
||||
jk << "«XFORMAT:" << xformat;
|
||||
jk << "«WFORMAT:" << wformat;
|
||||
jk << "«YFORMAT:" << yformat;
|
||||
}
|
||||
|
||||
static auto make_conv = get_op_info("cudnn_conv")
|
||||
|
|
|
@ -72,13 +72,12 @@ void CudnnConvOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CudnnConvOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«Tw:" << w->dtype();
|
||||
jk << "«XFORMAT:" << xformat;
|
||||
jk << "«WFORMAT:" << wformat;
|
||||
jk << "«YFORMAT:" << yformat;
|
||||
}
|
||||
static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
|
||||
|
|
|
@ -79,10 +79,9 @@ void CudnnRnnBackwardXOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CudnnRnnBackwardXOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << hx->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << ']';
|
||||
jk << "«Tx:" << hx->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«Tw:" << w->dtype();
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -98,10 +98,9 @@ void CudnnRnnOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CudnnRnnOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«Tw:" << w->dtype();
|
||||
}
|
||||
|
||||
static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x")
|
||||
|
|
|
@ -20,7 +20,7 @@ CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) {
|
|||
}
|
||||
|
||||
void CudnnTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
jk << "«T:float32";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -44,9 +44,9 @@ void CufftFftOp::jit_prepare(JK& jk) {
|
|||
printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
|
||||
ASSERT(false);
|
||||
}
|
||||
jk << _CS("[T:") << y->dtype();
|
||||
jk << _CS("][I:")<<inverse<<"]";
|
||||
jk << _CS("[TS:\"")<<y->dtype()<<"\"]";
|
||||
jk << "«T:" << y->dtype();
|
||||
jk << "«I:"<<inverse<<"]";
|
||||
jk << _CS("«TS:\"")<<y->dtype()<<"\"]";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -25,8 +25,8 @@ CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype, NanoString ty
|
|||
}
|
||||
|
||||
void CurandRandomOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << output->dtype();
|
||||
jk << _CS("][R:") << type << ']';
|
||||
jk << "«T:" << output->dtype();
|
||||
jk << "«R:" << type;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -21,7 +21,7 @@ CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) {
|
|||
}
|
||||
|
||||
void CuttTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
jk << "«T:float32";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -59,7 +59,7 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
|
||||
void CuttTransposeOp::jit_prepare(JK& jk) {
|
||||
// do nothing
|
||||
jk << _CS("[T:1]");
|
||||
jk << "«T:1";
|
||||
}
|
||||
|
||||
unordered_map<string, unsigned int> cutt_plan_cache;
|
||||
|
|
|
@ -37,7 +37,7 @@ VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void NcclAllReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -35,7 +35,7 @@ VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void NcclBroadcastOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -35,7 +35,7 @@ VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void NcclReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -20,7 +20,7 @@ NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) {
|
|||
}
|
||||
|
||||
void NcclTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
jk << "«T:float32";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -79,16 +79,15 @@ static const char* short_type(Var* x) {
|
|||
}
|
||||
|
||||
void MklConvBackwardWOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Txd:") << x->dtype();
|
||||
jk << _CS("][Tyd:") << dy->dtype();
|
||||
jk << _CS("][Twd:") << dw->dtype();
|
||||
jk << _CS("][Tx:") << short_type(x);
|
||||
jk << _CS("][Tw:") << short_type(dw);
|
||||
jk << _CS("][Ty:") << short_type(dy);
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
jk << "«Txd:" << x->dtype();
|
||||
jk << "«Tyd:" << dy->dtype();
|
||||
jk << "«Twd:" << dw->dtype();
|
||||
jk << "«Tx:" << short_type(x);
|
||||
jk << "«Tw:" << short_type(dw);
|
||||
jk << "«Ty:" << short_type(dy);
|
||||
jk << "«XFORMAT:" << xformat;
|
||||
jk << "«WFORMAT:" << wformat;
|
||||
jk << "«YFORMAT:" << yformat;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -77,16 +77,15 @@ static const char* short_type(Var* x) {
|
|||
}
|
||||
|
||||
void MklConvBackwardXOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tyd:") << dy->dtype();
|
||||
jk << _CS("][Twd:") << w->dtype();
|
||||
jk << _CS("][Txd:") << dx->dtype();
|
||||
jk << _CS("][Tx:") << short_type(dx);
|
||||
jk << _CS("][Tw:") << short_type(w);
|
||||
jk << _CS("][Ty:") << short_type(dy);
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
jk << "«Tyd:" << dy->dtype();
|
||||
jk << "«Twd:" << w->dtype();
|
||||
jk << "«Txd:" << dx->dtype();
|
||||
jk << "«Tx:" << short_type(dx);
|
||||
jk << "«Tw:" << short_type(w);
|
||||
jk << "«Ty:" << short_type(dy);
|
||||
jk << "«XFORMAT:" << xformat;
|
||||
jk << "«WFORMAT:" << wformat;
|
||||
jk << "«YFORMAT:" << yformat;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -81,16 +81,15 @@ static const char* short_type(Var* x) {
|
|||
}
|
||||
|
||||
void MklConvOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Txd:") << x->dtype();
|
||||
jk << _CS("][Tyd:") << y->dtype();
|
||||
jk << _CS("][Twd:") << w->dtype();
|
||||
jk << _CS("][Tx:") << short_type(x);
|
||||
jk << _CS("][Tw:") << short_type(w);
|
||||
jk << _CS("][Ty:") << short_type(y);
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
jk << "«Txd:" << x->dtype();
|
||||
jk << "«Tyd:" << y->dtype();
|
||||
jk << "«Twd:" << w->dtype();
|
||||
jk << "«Tx:" << short_type(x);
|
||||
jk << "«Tw:" << short_type(w);
|
||||
jk << "«Ty:" << short_type(y);
|
||||
jk << "«XFORMAT:" << xformat;
|
||||
jk << "«WFORMAT:" << wformat;
|
||||
jk << "«YFORMAT:" << yformat;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -44,9 +44,9 @@ void MklMatmulOp::infer_shape() {
|
|||
}
|
||||
|
||||
void MklMatmulOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N') << ']';
|
||||
jk << "«T:" << a->dtype();
|
||||
jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
|
||||
jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -19,7 +19,7 @@ MklTestOp::MklTestOp() {
|
|||
}
|
||||
|
||||
void MklTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
jk << "«T:float32";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -62,8 +62,8 @@ VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void MpiAllReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][OP:") << op << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«OP:" << op;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -49,7 +49,7 @@ VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void MpiBroadcastOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -62,8 +62,8 @@ VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void MpiReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][OP:") << op << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«OP:" << op;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -17,7 +17,7 @@ MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
|
|||
}
|
||||
|
||||
void MpiTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
jk << "«T:float32";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -585,6 +585,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
|
|||
if (!v->allocator->is_cuda())
|
||||
migrate_to_gpu(v, allocator);
|
||||
}
|
||||
for (Var* v : op->outputs()) {
|
||||
if (!v->allocator->is_cuda())
|
||||
migrate_to_gpu(v, allocator);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef NODE_MEMCHECK
|
||||
|
|
|
@ -159,55 +159,51 @@ void FusedOp::do_jit_prepare(JK& jk) {
|
|||
jk.clear();
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
Op* op = ops[i];
|
||||
jk << "[opkey" << i << JK::val;
|
||||
jk << "«opkey" << i << JK::val;
|
||||
jk << op->name();
|
||||
op->jit_prepare(jk);
|
||||
jk << JK::end;
|
||||
}
|
||||
jk << _CS("[JIT:1]");
|
||||
jk << "«JIT:1";
|
||||
if (!use_cuda) {
|
||||
// only cpu
|
||||
jk << _CS("[JIT_cpu:1]");
|
||||
jk << "«JIT_cpu:1";
|
||||
this->flags.set(NodeFlags::_cuda, 0);
|
||||
this->flags.set(NodeFlags::_cpu, 1);
|
||||
} else {
|
||||
jk << _CS("[JIT_cuda:1]");
|
||||
jk << "«JIT_cuda:1";
|
||||
this->flags.set(NodeFlags::_cpu, 0);
|
||||
this->flags.set(NodeFlags::_cuda, 1);
|
||||
}
|
||||
jk << _CS("[graph:");
|
||||
jk << "«graph:";
|
||||
for (auto& t : edges) {
|
||||
uint i,j,k,l;
|
||||
std::tie(i,j,k,l) = t;
|
||||
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
|
||||
}
|
||||
jk << _CS("][var_info:") << JK::val;
|
||||
jk << "«var_info:" << JK::val;
|
||||
bool use_int64_t = false;
|
||||
for (auto& vi : vars) {
|
||||
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
|
||||
if (vi.type != 1 && vi.var->num >= std::numeric_limits<int32_t>::max())
|
||||
use_int64_t = true;
|
||||
}
|
||||
jk << JK::end;
|
||||
if (use_int64_t)
|
||||
jk << _CS("[index_t:int64]");
|
||||
jk << "«index_t:int64";
|
||||
else
|
||||
jk << _CS("[index_t:int32]");
|
||||
jk << "«index_t:int32";
|
||||
if (loop_options->size()) {
|
||||
if (get_loop_option("compile_shapes")) {
|
||||
jk << _CS("[shapes:");
|
||||
jk << "«shapes:";
|
||||
for (auto& vi : vars) {
|
||||
jk << '[';
|
||||
for (auto a : vi.var->shape)
|
||||
jk << a << ',';
|
||||
jk << _CS("],");
|
||||
jk << "],";
|
||||
}
|
||||
jk << JK::end;
|
||||
}
|
||||
jk << _CS("[choices:");
|
||||
jk << "«choices:";
|
||||
for (auto& kv : *loop_options)
|
||||
jk << kv.first << ':' << kv.second << ',';
|
||||
jk << JK::end;
|
||||
}
|
||||
jk.finilize();
|
||||
}
|
||||
|
|
|
@ -194,7 +194,7 @@ void run_cmd(string cmd, string cwd="") {
|
|||
|
||||
static string get_symbol_name(const string& jit_key) {
|
||||
int i=0;
|
||||
while (i<jit_key.size() && jit_key[i]!='[') i++;
|
||||
while (i<jit_key.size() && jit_key[i]>=0) i++;
|
||||
string op_name = i ? jit_key.substr(0, i) : "fused";
|
||||
op_name = Op::file_name_to_class_name(op_name);
|
||||
// _ZN7jittorXyyyyyy7jit_runEv
|
||||
|
|
|
@ -80,42 +80,26 @@ static void convert_itof(string& s) {
|
|||
|
||||
vector<pair<string,string>> parse_jit_keys(const string& s) {
|
||||
vector<pair<string,string>> jit_keys;
|
||||
int presum = 0;
|
||||
char state=0;
|
||||
auto sp = split(s, JitKey::key);
|
||||
for (auto& ss : sp) {
|
||||
if (!ss.size()) continue;
|
||||
string key, val;
|
||||
for (char c : s) {
|
||||
if (c==JK::key) {
|
||||
presum++;
|
||||
if (presum==1) {
|
||||
char state=0;
|
||||
for (auto c : ss) {
|
||||
if (state == 0 &&
|
||||
(c==JK::val || c==JK::hex_val)) {
|
||||
state = c;
|
||||
continue;
|
||||
}
|
||||
} else
|
||||
if (c==JK::val || c==JK::hex_val) {
|
||||
if (presum==1 && state==JK::key) {
|
||||
state = c;
|
||||
continue;
|
||||
if (state == 0) key += c;
|
||||
else val += c;
|
||||
}
|
||||
} else
|
||||
if (c==JK::end) {
|
||||
presum--;
|
||||
if (presum==0) {
|
||||
if (state == JK::hex_val)
|
||||
hex_to_dec(val);
|
||||
if (startswith(val, "itof"))
|
||||
convert_itof(val);
|
||||
jit_keys.emplace_back(move(key), move(val));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (presum) {
|
||||
if (state==JK::key)
|
||||
key += c;
|
||||
if (state==JK::val || state==JK::hex_val)
|
||||
val += c;
|
||||
}
|
||||
}
|
||||
ASSERT(presum==0) << s;
|
||||
return jit_keys;
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <cstring>
|
||||
#include "common.h"
|
||||
#include "misc/nano_string.h"
|
||||
#include "misc/nano_vector.h"
|
||||
|
@ -13,11 +14,10 @@ namespace jittor {
|
|||
|
||||
struct JitKey {
|
||||
static constexpr size_t buffer_size = 2*1024*1024;
|
||||
static constexpr char
|
||||
key = '[',
|
||||
static constexpr const char
|
||||
*key = "«",
|
||||
val = ':',
|
||||
hex_val = '=',
|
||||
end = ']';
|
||||
hex_val = '=';
|
||||
int64 size=0;
|
||||
uint64 flags=0;
|
||||
char buffer[buffer_size];
|
||||
|
@ -27,7 +27,7 @@ struct JitKey {
|
|||
|
||||
inline void clear() {size = flags = 0;}
|
||||
inline void finilize() { buffer[size] = 0; }
|
||||
inline bool empty() { return buffer[size-1] != end; }
|
||||
inline bool empty() { return !size; }
|
||||
inline const char* to_cstring() {
|
||||
return &buffer[0];
|
||||
}
|
||||
|
@ -81,11 +81,38 @@ struct __jk_int256 {
|
|||
typedef JitKey JK;
|
||||
EXTERN_LIB JK& get_jk();
|
||||
|
||||
inline void jk_put_str_with_len(JK& jk, const char* a, int n) {
|
||||
char* xx = &jk.buffer[jk.size];
|
||||
int i=0;
|
||||
while (i+32<=n) {
|
||||
((__jk_int256*)(xx+i))[0] = ((const __jk_int256*)(a+i))[0];
|
||||
i+=32;
|
||||
}
|
||||
while (i+16<=n) {
|
||||
((__jk_int128*)(xx+i))[0] = ((const __jk_int128*)(a+i))[0];
|
||||
i+=16;
|
||||
}
|
||||
while (i+8<=n) {
|
||||
((long long*)(xx+i))[0] = ((const long long*)(a+i))[0];
|
||||
i+=8;
|
||||
}
|
||||
while (i+4<=n) {
|
||||
((int*)(xx+i))[0] = ((const int*)(a+i))[0];
|
||||
i+=4;
|
||||
}
|
||||
while (i+2<=n) {
|
||||
((int16_t*)(xx+i))[0] = ((const int16_t*)(a+i))[0];
|
||||
i+=2;
|
||||
}
|
||||
while (i+1<=n) {
|
||||
((char*)(xx+i))[0] = ((const char*)(a+i))[0];
|
||||
i+=1;
|
||||
}
|
||||
jk.size += n;
|
||||
}
|
||||
|
||||
inline JK& operator<<(JK& jk, const char* s) {
|
||||
int i;
|
||||
for (i=0; s[i]; i++)
|
||||
jk.buffer[jk.size+i] = s[i];
|
||||
jk.size += i;
|
||||
jk_put_str_with_len(jk, s, strlen(s));
|
||||
return jk;
|
||||
}
|
||||
|
||||
|
@ -199,129 +226,133 @@ vector<pair<string,string>> parse_jit_keys(const string& s);
|
|||
|
||||
template <class Ta, class Tb>
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& val) {
|
||||
jk << JK::key << key << JK::val << val << JK::end;
|
||||
jk << JK::key << key << JK::val << val;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb, class Tc>
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) {
|
||||
jk << JK::key << key << i << JK::val << val << JK::end;
|
||||
jk << JK::key << key << i << JK::val << val;
|
||||
}
|
||||
|
||||
|
||||
template <class Ta>
|
||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) {
|
||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
||||
jk << JK::key << key << JK::hex_val << val;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb>
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) {
|
||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
||||
jk << JK::key << key << i << JK::hex_val << val;
|
||||
}
|
||||
|
||||
template <class Ta>
|
||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) {
|
||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
||||
jk << JK::key << key << JK::hex_val << val;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb>
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) {
|
||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
||||
jk << JK::key << key << i << JK::hex_val << val;
|
||||
}
|
||||
|
||||
template <class Ta>
|
||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) {
|
||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
||||
jk << JK::key << key << JK::hex_val << val;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb>
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) {
|
||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
||||
jk << JK::key << key << i << JK::hex_val << val;
|
||||
}
|
||||
|
||||
#define _CS(x) x
|
||||
// // begin of const string
|
||||
// #define MAX_CONST_CHAR 32
|
||||
|
||||
// begin of const string
|
||||
#define MAX_CONST_CHAR 32
|
||||
// #define _CS_MIN(a,b) (a)<(b)?(a):(b)
|
||||
|
||||
#define _CS_MIN(a,b) (a)<(b)?(a):(b)
|
||||
// #define _CS_T(s)\
|
||||
// getChr(s,0),\
|
||||
// getChr(s,1),\
|
||||
// getChr(s,2),\
|
||||
// getChr(s,3),\
|
||||
// getChr(s,4),\
|
||||
// getChr(s,5),\
|
||||
// getChr(s,6),\
|
||||
// getChr(s,7),\
|
||||
// getChr(s,8),\
|
||||
// getChr(s,9),\
|
||||
// getChr(s,10),\
|
||||
// getChr(s,11),\
|
||||
// getChr(s,12),\
|
||||
// getChr(s,13),\
|
||||
// getChr(s,14),\
|
||||
// getChr(s,15),\
|
||||
// getChr(s,16),\
|
||||
// getChr(s,17),\
|
||||
// getChr(s,18),\
|
||||
// getChr(s,19),\
|
||||
// getChr(s,20),\
|
||||
// getChr(s,21),\
|
||||
// getChr(s,22),\
|
||||
// getChr(s,23),\
|
||||
// getChr(s,24),\
|
||||
// getChr(s,25),\
|
||||
// getChr(s,26),\
|
||||
// getChr(s,27),\
|
||||
// getChr(s,28),\
|
||||
// getChr(s,29),\
|
||||
// getChr(s,30),\
|
||||
// getChr(s,31),\
|
||||
// getChr(s,32),\
|
||||
// getChr(s,33),\
|
||||
// getChr(s,34),\
|
||||
// getChr(s,35)
|
||||
|
||||
#define _CS_T(s)\
|
||||
getChr(s,0),\
|
||||
getChr(s,1),\
|
||||
getChr(s,2),\
|
||||
getChr(s,3),\
|
||||
getChr(s,4),\
|
||||
getChr(s,5),\
|
||||
getChr(s,6),\
|
||||
getChr(s,7),\
|
||||
getChr(s,8),\
|
||||
getChr(s,9),\
|
||||
getChr(s,10),\
|
||||
getChr(s,11),\
|
||||
getChr(s,12),\
|
||||
getChr(s,13),\
|
||||
getChr(s,14),\
|
||||
getChr(s,15),\
|
||||
getChr(s,16),\
|
||||
getChr(s,17),\
|
||||
getChr(s,18),\
|
||||
getChr(s,19),\
|
||||
getChr(s,20),\
|
||||
getChr(s,21),\
|
||||
getChr(s,22),\
|
||||
getChr(s,23),\
|
||||
getChr(s,24),\
|
||||
getChr(s,25),\
|
||||
getChr(s,26),\
|
||||
getChr(s,27),\
|
||||
getChr(s,28),\
|
||||
getChr(s,29),\
|
||||
getChr(s,30),\
|
||||
getChr(s,31),\
|
||||
getChr(s,32),\
|
||||
getChr(s,33),\
|
||||
getChr(s,34),\
|
||||
getChr(s,35)
|
||||
// #define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
|
||||
|
||||
#define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
|
||||
// #ifdef _MSC_VER
|
||||
// #define _CS(str) str
|
||||
// #else
|
||||
// #define _CS(str) _CS_G<_CS_T(str)>()
|
||||
// #endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define _CS(str) str
|
||||
#else
|
||||
#define _CS(str) _CS_G<_CS_T(str)>()
|
||||
#endif
|
||||
// template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
|
||||
// };
|
||||
|
||||
template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
|
||||
};
|
||||
// template<> struct _CS_G<0,0,0,0> {};
|
||||
|
||||
template<> struct _CS_G<0,0,0,0> {};
|
||||
// template <char c1, char c2, char c3, char c4, char... Chars_>
|
||||
// inline JK& operator<<(JK& jk, const _CS_G<c1,c2,c3,c4,Chars_...>& _) {
|
||||
// ((uint32*)(jk.buffer+jk.size))[0] =
|
||||
// (uint32((uint8)(c4))<<24)+
|
||||
// (uint32((uint8)(c3))<<16)+
|
||||
// (uint32((uint8)(c2))<<8)+
|
||||
// uint32((uint8)(c1));
|
||||
// if (c4) {
|
||||
// jk.size += 4;
|
||||
// jk << _CS_G<Chars_...>();
|
||||
// } else
|
||||
// if (c3) {
|
||||
// jk.size += 3;
|
||||
// } else
|
||||
// if (c2) {
|
||||
// jk.size += 2;
|
||||
// } else
|
||||
// if (c1) {
|
||||
// jk.size += 1;
|
||||
// }
|
||||
// return jk;
|
||||
// }
|
||||
|
||||
template <char c1, char c2, char c3, char c4, char... Chars_>
|
||||
inline JK& operator<<(JK& jk, const _CS_G<c1,c2,c3,c4,Chars_...>& _) {
|
||||
((int*)(jk.buffer+jk.size))[0] = c4*(1<<24)+c3*(1<<16)+c2*(1<<8)+c1;
|
||||
if (c4) {
|
||||
jk.size += 4;
|
||||
jk << _CS_G<Chars_...>();
|
||||
} else
|
||||
if (c3) {
|
||||
jk.size += 3;
|
||||
} else
|
||||
if (c2) {
|
||||
jk.size += 2;
|
||||
} else
|
||||
if (c1) {
|
||||
jk.size += 1;
|
||||
}
|
||||
return jk;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) {
|
||||
return jk;
|
||||
}
|
||||
// template <>
|
||||
// inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) {
|
||||
// return jk;
|
||||
// }
|
||||
|
||||
|
||||
inline JK& operator<<(JK& jk, float64 f) {
|
||||
return jk << _CS("itof(0x") << JK::hex(ftoi(f)) << ')';
|
||||
return jk << "itof(0x" << JK::hex(ftoi(f)) << ')';
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -128,14 +128,16 @@ string Op::get_hash_name() {
|
|||
void Op::do_jit_prepare(JK& jk) {
|
||||
memcheck_all_exist();
|
||||
jk << name();
|
||||
auto pre_size = jk.size;
|
||||
jit_prepare(jk);
|
||||
if (jk.empty()) {
|
||||
if (jk.size == pre_size) {
|
||||
// not a jit op
|
||||
bool has_cuda = flags.get(NodeFlags::_cuda);
|
||||
bool has_cpu = flags.get(NodeFlags::_cpu);
|
||||
CHECK(has_cuda || has_cpu);
|
||||
if (has_cuda && has_cpu && !use_cuda)
|
||||
flags.set(NodeFlags::_cuda, 0);
|
||||
jk.clear();
|
||||
} else {
|
||||
bool use_int64_t = false;
|
||||
// TODO: fused op do not have inputs,
|
||||
|
@ -149,9 +151,9 @@ void Op::do_jit_prepare(JK& jk) {
|
|||
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||
use_int64_t = true;
|
||||
}
|
||||
jk << _CS("[JIT:1]");
|
||||
jk << "«JIT:1";
|
||||
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
|
||||
jk << _CS("[JIT_cuda:1]");
|
||||
jk << "«JIT_cuda:1";
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
// TODO: 64bit index in CUDA
|
||||
// use_int64_t = false;
|
||||
|
@ -164,14 +166,14 @@ void Op::do_jit_prepare(JK& jk) {
|
|||
}
|
||||
ASSERT(flags.get(NodeFlags::_cpu))
|
||||
<< "Op" << name() << "doesn't have cpu version";
|
||||
jk << _CS("[JIT_cpu:1]");
|
||||
jk << "«JIT_cpu:1";
|
||||
flags.set(NodeFlags::_cuda, 0);
|
||||
}
|
||||
if (try_use_32bit_index) use_int64_t = false;
|
||||
if (use_int64_t)
|
||||
jk << _CS("[index_t:int64]");
|
||||
jk << "«index_t:int64";
|
||||
else
|
||||
jk << _CS("[index_t:int32]");
|
||||
jk << "«index_t:int32";
|
||||
}
|
||||
jk.finilize();
|
||||
}
|
||||
|
|
|
@ -158,14 +158,13 @@ void ArgReduceOp::infer_shape() {
|
|||
}
|
||||
|
||||
void ArgReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size());
|
||||
jk << _CS("][YDIM=") << JK::hex1(y->shape.size());
|
||||
jk << _CS("][KEEPDIMS:") << (keepdims ? '1' : '0');
|
||||
jk << _CS("][DIM=") << JK::hex1(dim);
|
||||
jk << _CS("][CMP:") << (op==ns_minimum ? "<" : ">");
|
||||
jk << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«XDIM=" << JK::hex1(x->shape.size());
|
||||
jk << "«YDIM=" << JK::hex1(y->shape.size());
|
||||
jk << "«KEEPDIMS:" << (keepdims ? '1' : '0');
|
||||
jk << "«DIM=" << JK::hex1(dim);
|
||||
jk << "«CMP:" << (op==ns_minimum ? "<" : ">");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -130,12 +130,11 @@ void ArgsortOp::infer_shape() {
|
|||
}
|
||||
|
||||
void ArgsortOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size());
|
||||
jk << _CS("][DIM=") << JK::hex1(dim);
|
||||
jk << _CS("][CMP:") << (descending ? '>' : '<');
|
||||
jk << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«XDIM=" << JK::hex1(x->shape.size());
|
||||
jk << "«DIM=" << JK::hex1(dim);
|
||||
jk << "«CMP:" << (descending ? '>' : '<');
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -78,7 +78,7 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
|
|||
|
||||
void ArrayOp::jit_prepare(JK& jk) {
|
||||
if (output->flags.get(NodeFlags::_force_fuse)) {
|
||||
jk << _CS("[T:") << output->dtype() << ']';
|
||||
jk << "«T:" << output->dtype();
|
||||
|
||||
// fill or find cbuffer for const var pass
|
||||
if (output->dtype().dsize() == 4) {
|
||||
|
@ -86,7 +86,7 @@ void ArrayOp::jit_prepare(JK& jk) {
|
|||
auto y = std::abs(ptr<float32>()[0]);
|
||||
auto z = ptr<uint32>()[0];
|
||||
if ((x<=2) || (y==1.0f || y==2.0f))
|
||||
jk << _CS("[o:") << z << ']';
|
||||
jk << "«o:" << z;
|
||||
}
|
||||
// end of fill cbuffer
|
||||
}
|
||||
|
|
|
@ -540,10 +540,10 @@ void BinaryOp::infer_shape() {
|
|||
}
|
||||
|
||||
void BinaryOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][Ty:") << y->dtype()
|
||||
<< _CS("][Tz:") << z->dtype()
|
||||
<< _CS("][OP:") << ns << ']';
|
||||
jk << "«Tx:" << x->dtype()
|
||||
<< "«Ty:" << y->dtype()
|
||||
<< "«Tz:" << z->dtype()
|
||||
<< "«OP:" << ns;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -167,9 +167,9 @@ void BroadcastToOp::infer_shape() {
|
|||
}
|
||||
|
||||
void BroadcastToOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][DIM=") << JK::hex1(z->shape.size())
|
||||
<< _CS("][BCAST=") << JK::hex(bcast_mask) << ']';
|
||||
jk << "«Tx:" << x->dtype()
|
||||
<< "«DIM=" << JK::hex1(z->shape.size())
|
||||
<< "«BCAST=" << JK::hex(bcast_mask);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -25,10 +25,10 @@ void CandidateOp::infer_shape() {
|
|||
}
|
||||
|
||||
void CandidateOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][FUNC:") << fail_cond;
|
||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size()) << ']';
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«FUNC:" << fail_cond;
|
||||
jk << "«XDIM=" << JK::hex1(x->shape.size());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -123,18 +123,18 @@ void CodeOp::jit_prepare(JK& jk) {
|
|||
|
||||
// forward: in0 in1 in2 -> out0 out1
|
||||
// backward: in0 in1 in2 in3(pout0) in4(pout1)
|
||||
jk << _CS("[IN_SIZE=") << JK::hex(_inputs.size());
|
||||
jk << "«IN_SIZE=" << JK::hex(_inputs.size());
|
||||
for (uint i=0; i<_inputs.size(); i++) {
|
||||
jk << _CS("][in") << JK::hex(i) << _CS("_dim=")
|
||||
jk << "«in" << JK::hex(i) << "_dim="
|
||||
<< JK::hex1(_inputs[i]->shape.size());
|
||||
jk << _CS("][in") << JK::hex(i) << _CS("_type:")
|
||||
jk << "«in" << JK::hex(i) << "_type:"
|
||||
<< _inputs[i]->dtype();
|
||||
}
|
||||
jk << _CS("][OUT_SIZE=") << JK::hex(_outputs.size());
|
||||
jk << "«OUT_SIZE=" << JK::hex(_outputs.size());
|
||||
for (uint i=0; i<_outputs.size(); i++) {
|
||||
jk << _CS("][out") << JK::hex(i) << _CS("_dim=")
|
||||
jk << "«out" << JK::hex(i) << "_dim="
|
||||
<< JK::hex1(_outputs[i]->shape.size());
|
||||
jk << _CS("][out") << JK::hex(i) << _CS("_type:")
|
||||
jk << "«out" << JK::hex(i) << "_type:"
|
||||
<< _outputs[i]->dtype();
|
||||
}
|
||||
string& header = flags.get(NodeFlags::_cuda) ?
|
||||
|
@ -142,9 +142,9 @@ void CodeOp::jit_prepare(JK& jk) {
|
|||
string& src = flags.get(NodeFlags::_cuda) ?
|
||||
cuda_src : cpu_src;
|
||||
|
||||
jk << _CS("][HEADER:") << header;
|
||||
jk << "«HEADER:" << header;
|
||||
CHECK(src.size());
|
||||
jk << _CS("\nnamespace jittor {\n");
|
||||
jk << "\nnamespace jittor {\n";
|
||||
int i=0;
|
||||
// move cuda kernel function into header
|
||||
for (; i<src.size(); i++) {
|
||||
|
@ -165,9 +165,8 @@ void CodeOp::jit_prepare(JK& jk) {
|
|||
}
|
||||
} else break;
|
||||
}
|
||||
jk << _CS("}][CODE:");
|
||||
jk << "}«CODE:";
|
||||
for (; i<src.size(); i++) jk << src[i];
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -74,12 +74,11 @@ VarPtr FuseTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
void FuseTransposeOp::jit_prepare(JK& jk) {
|
||||
auto bc = type()==OpType::broadcast;
|
||||
auto ax = bc ? axes : get_reverse(axes);
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][DIM=") << JK::hex1(axes.size());
|
||||
jk << _CS("][BC:") << JK::hex1(bc);
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«DIM=" << JK::hex1(axes.size());
|
||||
jk << "«BC:" << JK::hex1(bc);
|
||||
for (uint i=0; i<ax.size(); i++)
|
||||
jk << _CS("][AXES") << JK::hex1(ax[i]) << '=' << JK::hex1(i);
|
||||
jk << ']';
|
||||
jk << "«AXES" << JK::hex1(ax[i]) << '=' << JK::hex1(i);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -444,26 +444,26 @@ void GetitemOp::grads(Var** dout, VarPtr* dins) {
|
|||
void GetitemOp::jit_prepare(JK& jk) {
|
||||
auto in = inputs().front();
|
||||
int idim = i_to_vs.size();
|
||||
jk << _CS("[Ti:") << in->dtype();
|
||||
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
|
||||
jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
|
||||
jk << "«Ti:" << in->dtype();
|
||||
jk << "«IDIM=" << JK::hex1(i_to_vs.size());
|
||||
jk << "«ODIM=" << JK::hex1(o_shape.size());
|
||||
if (first_oid_of_var>=0) {
|
||||
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
|
||||
jk << _CS("][VD=") << JK::hex1(var_dim);
|
||||
jk << "«FOV=" << JK::hex1(first_oid_of_var);
|
||||
jk << "«VD=" << JK::hex1(var_dim);
|
||||
}
|
||||
for (int i=0; i<idim; i++) {
|
||||
auto iv = i_to_vs[i];
|
||||
auto io = i_to_o[i];
|
||||
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
|
||||
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
|
||||
jk << "«IV" << JK::hex1(i) << ':' << JK::shex1(iv);
|
||||
jk << "«IO" << JK::hex1(i) << ':' << JK::shex1(io);
|
||||
auto& v = vs.slices[iv];
|
||||
if (iv>=0 && io==-1) {
|
||||
if (v.is_int()) {
|
||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
|
||||
jk << "«VS" << JK::hex1(i) << ":-1";
|
||||
} else
|
||||
if (v.is_str()) {
|
||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
|
||||
jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
|
||||
jk << "«VS" << JK::hex1(i) << ":-5";
|
||||
jk << "«VSS" << JK::hex1(i) << ":" << v.get_str();
|
||||
} else {
|
||||
ASSERT(v.is_var());
|
||||
auto var = v.var;
|
||||
|
@ -475,13 +475,13 @@ void GetitemOp::jit_prepare(JK& jk) {
|
|||
if (vshape[j] == o_shape[k])
|
||||
vsmask |= 1<<(j+var_dim-vdim);
|
||||
}
|
||||
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
|
||||
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
|
||||
jk << "«VS" << JK::hex1(i) << '=' << JK::hex(vsmask);
|
||||
jk << "«VST" << JK::hex1(i) << ':' << var->dtype();
|
||||
}
|
||||
} else
|
||||
if (iv>=0 && io>=0) {
|
||||
ASSERT(v.is_slice());
|
||||
jk << _CS("][VS") << JK::hex1(i) << ':';
|
||||
jk << "«VS" << JK::hex1(i) << ':';
|
||||
if (std::abs(v.slice.step) <= 1)
|
||||
jk << JK::shex1(v.slice.step);
|
||||
else
|
||||
|
@ -495,11 +495,10 @@ void GetitemOp::jit_prepare(JK& jk) {
|
|||
int tdims[6];
|
||||
cuda_loop_schedule(o_shape, masks, tdims);
|
||||
for (int i=0; i<no; i++) {
|
||||
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
||||
jk << "«LO" << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -39,8 +39,8 @@ RandomOp::RandomOp(NanoVector shape, NanoString dtype, NanoString type) {
|
|||
}
|
||||
|
||||
void RandomOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << output->dtype();
|
||||
jk << _CS("][R:") << type << ']';
|
||||
jk << "«T:" << output->dtype();
|
||||
jk << "«R:" << type;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -338,12 +338,12 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void ReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][Ty:") << y->dtype()
|
||||
<< _CS("][Tz:") << y->dtype()
|
||||
<< _CS("][OP:") << ns
|
||||
<< _CS("][DIM=") << JK::hex1(x->shape.size())
|
||||
<< _CS("][REDUCE=") << JK::hex(reduce_mask) << ']';
|
||||
jk << "«Tx:" << x->dtype()
|
||||
<< "«Ty:" << y->dtype()
|
||||
<< "«Tz:" << y->dtype()
|
||||
<< "«OP:" << ns
|
||||
<< "«DIM=" << JK::hex1(x->shape.size())
|
||||
<< "«REDUCE=" << JK::hex(reduce_mask);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -91,21 +91,20 @@ void ReindexOp::infer_shape() {
|
|||
}
|
||||
|
||||
void ReindexOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][XDIM=") << JK::hex1(x->shape.size())
|
||||
<< _CS("][YDIM=") << JK::hex1(y->shape.size())
|
||||
<< _CS("][OVERFLOW:") << overflow_value;
|
||||
jk << "«Tx:" << x->dtype()
|
||||
<< "«XDIM=" << JK::hex1(x->shape.size())
|
||||
<< "«YDIM=" << JK::hex1(y->shape.size())
|
||||
<< "«OVERFLOW:" << overflow_value;
|
||||
for (uint i=0; i<indexes.size(); i++)
|
||||
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i];
|
||||
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size());
|
||||
jk << "«INDEX" << JK::hex1(i) << ':' << indexes[i];
|
||||
jk << "«OSIZE=" << JK::hex1(overflow_conditions.size());
|
||||
for (uint i=0; i<overflow_conditions.size(); i++)
|
||||
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i];
|
||||
jk << _CS("][ESIZE=") << JK::hex1(extras.size());
|
||||
jk << "«OFD" << JK::hex1(i) << ':' << overflow_conditions[i];
|
||||
jk << "«ESIZE=" << JK::hex1(extras.size());
|
||||
for (uint i=0; i<extras.size(); i++) {
|
||||
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
||||
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype();
|
||||
jk << "«EDIM" << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
||||
jk << "«Te" << JK::hex1(i) << ':' << extras[i]->dtype();
|
||||
}
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -73,21 +73,20 @@ void ReindexReduceOp::infer_shape() {
|
|||
}
|
||||
|
||||
void ReindexReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][OP:") << ns
|
||||
<< _CS("][YDIM=") << JK::hex1(y->shape.size())
|
||||
<< _CS("][XDIM=") << JK::hex1(x->shape.size());
|
||||
jk << "«Tx:" << x->dtype()
|
||||
<< "«OP:" << ns
|
||||
<< "«YDIM=" << JK::hex1(y->shape.size())
|
||||
<< "«XDIM=" << JK::hex1(x->shape.size());
|
||||
for (uint i=0; i<indexes.size(); i++)
|
||||
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i];
|
||||
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size());
|
||||
jk << "«INDEX" << JK::hex1(i) << ':' << indexes[i];
|
||||
jk << "«OSIZE=" << JK::hex1(overflow_conditions.size());
|
||||
for (uint i=0; i<overflow_conditions.size(); i++)
|
||||
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i];
|
||||
jk << _CS("][ESIZE=") << JK::hex1(extras.size());
|
||||
jk << "«OFD" << JK::hex1(i) << ':' << overflow_conditions[i];
|
||||
jk << "«ESIZE=" << JK::hex1(extras.size());
|
||||
for (uint i=0; i<extras.size(); i++) {
|
||||
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
||||
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype();
|
||||
jk << "«EDIM" << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
||||
jk << "«Te" << JK::hex1(i) << ':' << extras[i]->dtype();
|
||||
}
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -30,7 +30,7 @@ void SafeClipOp::infer_shape() {
|
|||
}
|
||||
|
||||
void SafeClipOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() <<']';
|
||||
jk << "«Tx:" << x->dtype() <<"«";
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -201,32 +201,32 @@ void SetitemOp::jit_prepare(JK& jk) {
|
|||
break;
|
||||
}
|
||||
auto data = input(1);
|
||||
jk << _CS("[OP:") << op
|
||||
<< _CS("][Td:") << data->dtype()
|
||||
<< _CS("][BMASK=") << JK::hex(bmask);
|
||||
jk << "«OP:" << op
|
||||
<< "«Td:" << data->dtype()
|
||||
<< "«BMASK=" << JK::hex(bmask);
|
||||
// TODO: merge code
|
||||
auto in = inputs().front();
|
||||
int idim = i_to_vs.size();
|
||||
jk << _CS("][Ti:") << in->dtype();
|
||||
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
|
||||
jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
|
||||
jk << "«Ti:" << in->dtype();
|
||||
jk << "«IDIM=" << JK::hex1(i_to_vs.size());
|
||||
jk << "«ODIM=" << JK::hex1(o_shape.size());
|
||||
if (first_oid_of_var>=0) {
|
||||
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
|
||||
jk << _CS("][VD=") << JK::hex1(var_dim);
|
||||
jk << "«FOV=" << JK::hex1(first_oid_of_var);
|
||||
jk << "«VD=" << JK::hex1(var_dim);
|
||||
}
|
||||
for (int i=0; i<idim; i++) {
|
||||
auto iv = i_to_vs[i];
|
||||
auto io = i_to_o[i];
|
||||
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
|
||||
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
|
||||
jk << "«IV" << JK::hex1(i) << ':' << JK::shex1(iv);
|
||||
jk << "«IO" << JK::hex1(i) << ':' << JK::shex1(io);
|
||||
auto& v = vs.slices[iv];
|
||||
if (iv>=0 && io==-1) {
|
||||
if (v.is_int()) {
|
||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
|
||||
jk << "«VS" << JK::hex1(i) << ":-1";
|
||||
} else
|
||||
if (v.is_str()) {
|
||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
|
||||
jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
|
||||
jk << "«VS" << JK::hex1(i) << ":-5";
|
||||
jk << "«VSS" << JK::hex1(i) << ":" << v.get_str();
|
||||
} else {
|
||||
ASSERT(v.is_var());
|
||||
auto var = v.var;
|
||||
|
@ -238,13 +238,13 @@ void SetitemOp::jit_prepare(JK& jk) {
|
|||
if (vshape[j] == o_shape[k])
|
||||
vsmask |= 1<<(j+var_dim-vdim);
|
||||
}
|
||||
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
|
||||
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
|
||||
jk << "«VS" << JK::hex1(i) << '=' << JK::hex(vsmask);
|
||||
jk << "«VST" << JK::hex1(i) << ':' << var->dtype();
|
||||
}
|
||||
} else
|
||||
if (iv>=0 && io>=0) {
|
||||
ASSERT(v.is_slice());
|
||||
jk << _CS("][VS") << JK::hex1(i) << ':';
|
||||
jk << "«VS" << JK::hex1(i) << ':';
|
||||
if (std::abs(v.slice.step) <= 1)
|
||||
jk << JK::shex1(v.slice.step);
|
||||
else
|
||||
|
@ -258,11 +258,10 @@ void SetitemOp::jit_prepare(JK& jk) {
|
|||
int tdims[6];
|
||||
cuda_loop_schedule(o_shape, masks, tdims);
|
||||
for (int i=0; i<no; i++) {
|
||||
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
||||
jk << "«LO" << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
void SetitemOp::compile_optimize(string& src) {
|
||||
|
|
|
@ -64,10 +64,10 @@ void TernaryOp::infer_shape() {
|
|||
}
|
||||
|
||||
void TernaryOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tc:") << cond->dtype();
|
||||
jk << _CS("][Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tz:") << z->dtype() << ']';
|
||||
jk << "«Tc:" << cond->dtype();
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«Ty:" << y->dtype();
|
||||
jk << "«Tz:" << z->dtype();
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -79,11 +79,10 @@ VarPtr TransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
|
||||
void TransposeOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][DIM=") << JK::hex1(axes.size());
|
||||
jk << "«Tx:" << x->dtype();
|
||||
jk << "«DIM=" << JK::hex1(axes.size());
|
||||
for (uint i=0; i<axes.size(); i++)
|
||||
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i);
|
||||
jk << ']';
|
||||
jk << "«AXES" << JK::hex1(axes[i]) << '=' << JK::hex1(i);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -681,9 +681,9 @@ void UnaryOp::infer_shape() {
|
|||
}
|
||||
|
||||
void UnaryOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][Ty:") << y->dtype()
|
||||
<< _CS("][OP:") << ns << ']';
|
||||
jk << "«Tx:" << x->dtype()
|
||||
<< "«Ty:" << y->dtype()
|
||||
<< "«OP:" << ns;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -54,10 +54,9 @@ void WhereOp::infer_shape() {
|
|||
}
|
||||
|
||||
void WhereOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Ti:") << cond->dtype();
|
||||
jk << _CS("][To:") << outs[0]->dtype();
|
||||
jk << _CS("][NDIM=") << JK::hex1(cond->shape.size());
|
||||
jk << ']';
|
||||
jk << "«Ti:" << cond->dtype();
|
||||
jk << "«To:" << outs[0]->dtype();
|
||||
jk << "«NDIM=" << JK::hex1(cond->shape.size());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -19,16 +19,16 @@ JIT_TEST(jit_key) {
|
|||
});
|
||||
std::cerr << "get segfault, ok" << std::endl;
|
||||
|
||||
jk << JK::key << "key" << JK::val << "value" << JK::end;
|
||||
jk << JK::key << "key" << JK::val << JK::hex(0x123123) << JK::end;
|
||||
jk << JK::key << "key" << JK::val << JK::hex1(0x123123) << JK::end;
|
||||
jk << JK::key << "key" << JK::val << JK::hex2(0x123123) << JK::end;
|
||||
jk << JK::key << "key" << JK::val << JK::Oxhex(0x123123) << JK::end;
|
||||
jk << JK::key << "key" << JK::val << JK::Oxhex1(0x123123) << JK::end;
|
||||
jk << JK::key << "key" << JK::val << JK::Oxhex2(0x123123) << JK::end;
|
||||
string key = "[key:value][key:123123][key:3][key:23][key:0x123123][key:0x3][key:0x23]";
|
||||
jk << JK::key << "key" << JK::val << "value";
|
||||
jk << JK::key << "key" << JK::val << JK::hex(0x123123);
|
||||
jk << JK::key << "key" << JK::val << JK::hex1(0x123123);
|
||||
jk << JK::key << "key" << JK::val << JK::hex2(0x123123);
|
||||
jk << JK::key << "key" << JK::val << JK::Oxhex(0x123123);
|
||||
jk << JK::key << "key" << JK::val << JK::Oxhex1(0x123123);
|
||||
jk << JK::key << "key" << JK::val << JK::Oxhex2(0x123123);
|
||||
string key = "«key:value«key:123123«key:3«key:23«key:0x123123«key:0x3«key:0x23";
|
||||
ASSERTop(jk.to_string(),==,key);
|
||||
auto keys = parse_jit_keys("[a:11][b:22][a[3]:b::[x]][x=11][f=itof(0x0)]");
|
||||
auto keys = parse_jit_keys("«a:11«b:22«a[3]:b::[x]«x=11«f=itof(0x0)");
|
||||
vector<pair<string,string>> k2 =
|
||||
{{"a","11"},{"b","22"},{"a[3]","b::[x]"},{"x","17"},{"f","0"}};
|
||||
ASSERTop(keys,==,k2);
|
||||
|
|
|
@ -94,7 +94,7 @@ JIT_TEST(fused_op_relay_matmul) {
|
|||
auto allocator = get_allocator();
|
||||
for (auto& v : fop.vars)
|
||||
if (v.type!=1) v.var->alloc(allocator);
|
||||
auto entry = oc.compile("[OP:_fused_op_relay_matmul]", oc.src);
|
||||
auto entry = oc.compile("«OP:_fused_op_relay_matmul", oc.src);
|
||||
for (uint i=0; i<a->num; i++)
|
||||
a->ptr<float>()[i] = b->ptr<float>()[i] = 1;
|
||||
entry(&fop);
|
||||
|
|
|
@ -46,7 +46,7 @@ def test(shape, op1, op2):
|
|||
with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
|
||||
d__ = d.data
|
||||
logs = find_log_with_re(logs,
|
||||
"Jit (fused )?op key (not )?found: \[opkey0:array\[T:float32")
|
||||
"Jit (fused )?op key (not )?found: «opkey0:array«T:float32")
|
||||
assert(len(logs)==1), logs
|
||||
|
||||
a_ = a.data
|
||||
|
|
|
@ -81,7 +81,7 @@ def test_case(box_num, out_size, time_limit):
|
|||
for i in range(1, len(rep)):
|
||||
t += float(rep[i][3]) / 1e9
|
||||
name = rep[i][0]
|
||||
if name.startswith('[') and (not '[graph:]' in name):
|
||||
if name.startswith('«') and (not '«graph:«' in name):
|
||||
fused_op_num += 1
|
||||
assert fused_op_num == 1, fused_op_num
|
||||
assert t <= time_limit, t
|
||||
|
|
|
@ -428,6 +428,15 @@ class TestSetitem(unittest.TestCase):
|
|||
np.arange(4)[None,::-1]]
|
||||
np.testing.assert_allclose(nb, b.data)
|
||||
|
||||
def test_cuda_slice_migrate_bug(self):
|
||||
a = jt.array([1,2,3,4,5])
|
||||
jt.sync_all()
|
||||
if not jt.has_cuda: return
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
b = a[0]
|
||||
b.sync(True)
|
||||
assert b.item() == 1
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue