mirror of https://github.com/Jittor/Jittor
polish jit key
This commit is contained in:
parent
91fe1fac85
commit
c5ccdaf330
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.3.4.7'
|
__version__ = '1.3.4.8'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -59,13 +59,13 @@ void CubArgReduceOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CubArgReduceOp::jit_prepare(JK& jk) {
|
void CubArgReduceOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Toffsets:") << offsets->dtype();
|
jk << "«Toffsets:" << offsets->dtype();
|
||||||
jk << _CS("][FUNC:");
|
jk << "«FUNC:";
|
||||||
if (op==ns_minimum)
|
if (op==ns_minimum)
|
||||||
jk << _CS("ArgMin]");
|
jk << "ArgMin";
|
||||||
else
|
else
|
||||||
jk << _CS("ArgMax]");
|
jk << "ArgMax";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -51,15 +51,15 @@ void CubArgsortOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CubArgsortOp::jit_prepare(JK& jk) {
|
void CubArgsortOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Tindexes:") << indexes->dtype();
|
jk << "«Tindexes:" << indexes->dtype();
|
||||||
jk << _CS("][Toffsets:") << offsets->dtype();
|
jk << "«Toffsets:" << offsets->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][FUNC:");
|
jk << "«FUNC:";
|
||||||
if (descending)
|
if (descending)
|
||||||
jk << _CS("SortPairsDescending]");
|
jk << "SortPairsDescending";
|
||||||
else
|
else
|
||||||
jk << _CS("SortPairs]");
|
jk << "SortPairs";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -38,10 +38,9 @@ void CubCumsumOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CubCumsumOp::jit_prepare(JK& jk) {
|
void CubCumsumOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][reverse:") << reverse;
|
jk << "«reverse:" << reverse;
|
||||||
jk << _CS("]");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
|
|
|
@ -24,7 +24,7 @@ CubTestOp::CubTestOp(string cmd) : cmd(cmd) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CubTestOp::jit_prepare(JK& jk) {
|
void CubTestOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:float32]");
|
jk << "«T:float32";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -40,10 +40,9 @@ void CubWhereOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CubWhereOp::jit_prepare(JK& jk) {
|
void CubWhereOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Ti:") << cond->dtype();
|
jk << "«Ti:" << cond->dtype();
|
||||||
jk << _CS("][To:") << outs[0]->dtype();
|
jk << "«To:" << outs[0]->dtype();
|
||||||
jk << _CS("][NDIM=") << JK::hex1(cond->shape.size());
|
jk << "«NDIM=" << JK::hex1(cond->shape.size());
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -54,11 +54,10 @@ void CublasAccMatmulOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CublasAccMatmulOp::jit_prepare(JK& jk) {
|
void CublasAccMatmulOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:") << a->dtype();
|
jk << "«T:" << a->dtype();
|
||||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
|
||||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
|
||||||
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -89,11 +89,10 @@ void CublasBatchedMatmulOp::infer_shape(){
|
||||||
}
|
}
|
||||||
|
|
||||||
void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
|
void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:") << a->dtype();
|
jk << "«T:" << a->dtype();
|
||||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
|
||||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
|
||||||
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -50,11 +50,10 @@ void CublasMatmulOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CublasMatmulOp::jit_prepare(JK& jk) {
|
void CublasMatmulOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:") << a->dtype();
|
jk << "«T:" << a->dtype();
|
||||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
|
||||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
|
||||||
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -19,7 +19,7 @@ CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CublasTestOp::jit_prepare(JK& jk) {
|
void CublasTestOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:float32]");
|
jk << "«T:float32";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -55,10 +55,9 @@ void CudnnConv3dBackwardWOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) {
|
void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << dy->dtype();
|
jk << "«Ty:" << dy->dtype();
|
||||||
jk << _CS("][Tw:") << dw->dtype();
|
jk << "«Tw:" << dw->dtype();
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static auto make_conv3d = get_op_info("cudnn_conv3d")
|
static auto make_conv3d = get_op_info("cudnn_conv3d")
|
||||||
|
|
|
@ -52,10 +52,9 @@ void CudnnConv3dBackwardXOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) {
|
void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << dx->dtype();
|
jk << "«Tx:" << dx->dtype();
|
||||||
jk << _CS("][Ty:") << dy->dtype();
|
jk << "«Ty:" << dy->dtype();
|
||||||
jk << _CS("][Tw:") << w->dtype();
|
jk << "«Tw:" << w->dtype();
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -53,10 +53,9 @@ void CudnnConv3dOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnConv3dOp::jit_prepare(JK& jk) {
|
void CudnnConv3dOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][Tw:") << w->dtype();
|
jk << "«Tw:" << w->dtype();
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")
|
static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")
|
||||||
|
|
|
@ -71,13 +71,12 @@ void CudnnConvBackwardWOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
|
void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << dy->dtype();
|
jk << "«Ty:" << dy->dtype();
|
||||||
jk << _CS("][Tw:") << dw->dtype();
|
jk << "«Tw:" << dw->dtype();
|
||||||
jk << _CS("][XFORMAT:") << xformat;
|
jk << "«XFORMAT:" << xformat;
|
||||||
jk << _CS("][WFORMAT:") << wformat;
|
jk << "«WFORMAT:" << wformat;
|
||||||
jk << _CS("][YFORMAT:") << yformat;
|
jk << "«YFORMAT:" << yformat;
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static auto make_conv = get_op_info("cudnn_conv")
|
static auto make_conv = get_op_info("cudnn_conv")
|
||||||
|
|
|
@ -70,13 +70,12 @@ void CudnnConvBackwardXOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
|
void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << dx->dtype();
|
jk << "«Tx:" << dx->dtype();
|
||||||
jk << _CS("][Ty:") << dy->dtype();
|
jk << "«Ty:" << dy->dtype();
|
||||||
jk << _CS("][Tw:") << w->dtype();
|
jk << "«Tw:" << w->dtype();
|
||||||
jk << _CS("][XFORMAT:") << xformat;
|
jk << "«XFORMAT:" << xformat;
|
||||||
jk << _CS("][WFORMAT:") << wformat;
|
jk << "«WFORMAT:" << wformat;
|
||||||
jk << _CS("][YFORMAT:") << yformat;
|
jk << "«YFORMAT:" << yformat;
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static auto make_conv = get_op_info("cudnn_conv")
|
static auto make_conv = get_op_info("cudnn_conv")
|
||||||
|
|
|
@ -72,13 +72,12 @@ void CudnnConvOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnConvOp::jit_prepare(JK& jk) {
|
void CudnnConvOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][Tw:") << w->dtype();
|
jk << "«Tw:" << w->dtype();
|
||||||
jk << _CS("][XFORMAT:") << xformat;
|
jk << "«XFORMAT:" << xformat;
|
||||||
jk << _CS("][WFORMAT:") << wformat;
|
jk << "«WFORMAT:" << wformat;
|
||||||
jk << _CS("][YFORMAT:") << yformat;
|
jk << "«YFORMAT:" << yformat;
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
|
static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
|
||||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
|
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
|
||||||
|
|
|
@ -79,10 +79,9 @@ void CudnnRnnBackwardXOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnRnnBackwardXOp::jit_prepare(JK& jk) {
|
void CudnnRnnBackwardXOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << hx->dtype();
|
jk << "«Tx:" << hx->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][Tw:") << w->dtype();
|
jk << "«Tw:" << w->dtype();
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -98,10 +98,9 @@ void CudnnRnnOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnRnnOp::jit_prepare(JK& jk) {
|
void CudnnRnnOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][Tw:") << w->dtype();
|
jk << "«Tw:" << w->dtype();
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x")
|
static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x")
|
||||||
|
|
|
@ -20,7 +20,7 @@ CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudnnTestOp::jit_prepare(JK& jk) {
|
void CudnnTestOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:float32]");
|
jk << "«T:float32";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -44,9 +44,9 @@ void CufftFftOp::jit_prepare(JK& jk) {
|
||||||
printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
|
printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
|
||||||
ASSERT(false);
|
ASSERT(false);
|
||||||
}
|
}
|
||||||
jk << _CS("[T:") << y->dtype();
|
jk << "«T:" << y->dtype();
|
||||||
jk << _CS("][I:")<<inverse<<"]";
|
jk << "«I:"<<inverse<<"]";
|
||||||
jk << _CS("[TS:\"")<<y->dtype()<<"\"]";
|
jk << _CS("«TS:\"")<<y->dtype()<<"\"]";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -25,8 +25,8 @@ CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype, NanoString ty
|
||||||
}
|
}
|
||||||
|
|
||||||
void CurandRandomOp::jit_prepare(JK& jk) {
|
void CurandRandomOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:") << output->dtype();
|
jk << "«T:" << output->dtype();
|
||||||
jk << _CS("][R:") << type << ']';
|
jk << "«R:" << type;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -21,7 +21,7 @@ CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CuttTestOp::jit_prepare(JK& jk) {
|
void CuttTestOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:float32]");
|
jk << "«T:float32";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -59,7 +59,7 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
|
|
||||||
void CuttTransposeOp::jit_prepare(JK& jk) {
|
void CuttTransposeOp::jit_prepare(JK& jk) {
|
||||||
// do nothing
|
// do nothing
|
||||||
jk << _CS("[T:1]");
|
jk << "«T:1";
|
||||||
}
|
}
|
||||||
|
|
||||||
unordered_map<string, unsigned int> cutt_plan_cache;
|
unordered_map<string, unsigned int> cutt_plan_cache;
|
||||||
|
|
|
@ -37,7 +37,7 @@ VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void NcclAllReduceOp::jit_prepare(JK& jk) {
|
void NcclAllReduceOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
jk << "«Tx:" << x->dtype();
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -35,7 +35,7 @@ VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void NcclBroadcastOp::jit_prepare(JK& jk) {
|
void NcclBroadcastOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
jk << "«Tx:" << x->dtype();
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -35,7 +35,7 @@ VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void NcclReduceOp::jit_prepare(JK& jk) {
|
void NcclReduceOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
jk << "«Tx:" << x->dtype();
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -20,7 +20,7 @@ NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void NcclTestOp::jit_prepare(JK& jk) {
|
void NcclTestOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:float32]");
|
jk << "«T:float32";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -79,16 +79,15 @@ static const char* short_type(Var* x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MklConvBackwardWOp::jit_prepare(JK& jk) {
|
void MklConvBackwardWOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Txd:") << x->dtype();
|
jk << "«Txd:" << x->dtype();
|
||||||
jk << _CS("][Tyd:") << dy->dtype();
|
jk << "«Tyd:" << dy->dtype();
|
||||||
jk << _CS("][Twd:") << dw->dtype();
|
jk << "«Twd:" << dw->dtype();
|
||||||
jk << _CS("][Tx:") << short_type(x);
|
jk << "«Tx:" << short_type(x);
|
||||||
jk << _CS("][Tw:") << short_type(dw);
|
jk << "«Tw:" << short_type(dw);
|
||||||
jk << _CS("][Ty:") << short_type(dy);
|
jk << "«Ty:" << short_type(dy);
|
||||||
jk << _CS("][XFORMAT:") << xformat;
|
jk << "«XFORMAT:" << xformat;
|
||||||
jk << _CS("][WFORMAT:") << wformat;
|
jk << "«WFORMAT:" << wformat;
|
||||||
jk << _CS("][YFORMAT:") << yformat;
|
jk << "«YFORMAT:" << yformat;
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -77,16 +77,15 @@ static const char* short_type(Var* x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MklConvBackwardXOp::jit_prepare(JK& jk) {
|
void MklConvBackwardXOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tyd:") << dy->dtype();
|
jk << "«Tyd:" << dy->dtype();
|
||||||
jk << _CS("][Twd:") << w->dtype();
|
jk << "«Twd:" << w->dtype();
|
||||||
jk << _CS("][Txd:") << dx->dtype();
|
jk << "«Txd:" << dx->dtype();
|
||||||
jk << _CS("][Tx:") << short_type(dx);
|
jk << "«Tx:" << short_type(dx);
|
||||||
jk << _CS("][Tw:") << short_type(w);
|
jk << "«Tw:" << short_type(w);
|
||||||
jk << _CS("][Ty:") << short_type(dy);
|
jk << "«Ty:" << short_type(dy);
|
||||||
jk << _CS("][XFORMAT:") << xformat;
|
jk << "«XFORMAT:" << xformat;
|
||||||
jk << _CS("][WFORMAT:") << wformat;
|
jk << "«WFORMAT:" << wformat;
|
||||||
jk << _CS("][YFORMAT:") << yformat;
|
jk << "«YFORMAT:" << yformat;
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -81,16 +81,15 @@ static const char* short_type(Var* x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MklConvOp::jit_prepare(JK& jk) {
|
void MklConvOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Txd:") << x->dtype();
|
jk << "«Txd:" << x->dtype();
|
||||||
jk << _CS("][Tyd:") << y->dtype();
|
jk << "«Tyd:" << y->dtype();
|
||||||
jk << _CS("][Twd:") << w->dtype();
|
jk << "«Twd:" << w->dtype();
|
||||||
jk << _CS("][Tx:") << short_type(x);
|
jk << "«Tx:" << short_type(x);
|
||||||
jk << _CS("][Tw:") << short_type(w);
|
jk << "«Tw:" << short_type(w);
|
||||||
jk << _CS("][Ty:") << short_type(y);
|
jk << "«Ty:" << short_type(y);
|
||||||
jk << _CS("][XFORMAT:") << xformat;
|
jk << "«XFORMAT:" << xformat;
|
||||||
jk << _CS("][WFORMAT:") << wformat;
|
jk << "«WFORMAT:" << wformat;
|
||||||
jk << _CS("][YFORMAT:") << yformat;
|
jk << "«YFORMAT:" << yformat;
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -44,9 +44,9 @@ void MklMatmulOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MklMatmulOp::jit_prepare(JK& jk) {
|
void MklMatmulOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:") << a->dtype();
|
jk << "«T:" << a->dtype();
|
||||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
jk << "«Trans_a:" << (trans_a ? 'T' : 'N');
|
||||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N') << ']';
|
jk << "«Trans_b:" << (trans_b ? 'T' : 'N');
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -19,7 +19,7 @@ MklTestOp::MklTestOp() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MklTestOp::jit_prepare(JK& jk) {
|
void MklTestOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:float32]");
|
jk << "«T:float32";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -62,8 +62,8 @@ VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MpiAllReduceOp::jit_prepare(JK& jk) {
|
void MpiAllReduceOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][OP:") << op << ']';
|
jk << "«OP:" << op;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -49,7 +49,7 @@ VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MpiBroadcastOp::jit_prepare(JK& jk) {
|
void MpiBroadcastOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
jk << "«Tx:" << x->dtype();
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -62,8 +62,8 @@ VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MpiReduceOp::jit_prepare(JK& jk) {
|
void MpiReduceOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][OP:") << op << ']';
|
jk << "«OP:" << op;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -17,7 +17,7 @@ MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void MpiTestOp::jit_prepare(JK& jk) {
|
void MpiTestOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:float32]");
|
jk << "«T:float32";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -585,6 +585,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
|
||||||
if (!v->allocator->is_cuda())
|
if (!v->allocator->is_cuda())
|
||||||
migrate_to_gpu(v, allocator);
|
migrate_to_gpu(v, allocator);
|
||||||
}
|
}
|
||||||
|
for (Var* v : op->outputs()) {
|
||||||
|
if (!v->allocator->is_cuda())
|
||||||
|
migrate_to_gpu(v, allocator);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifdef NODE_MEMCHECK
|
#ifdef NODE_MEMCHECK
|
||||||
|
|
|
@ -159,55 +159,51 @@ void FusedOp::do_jit_prepare(JK& jk) {
|
||||||
jk.clear();
|
jk.clear();
|
||||||
for (uint i=0; i<ops.size(); i++) {
|
for (uint i=0; i<ops.size(); i++) {
|
||||||
Op* op = ops[i];
|
Op* op = ops[i];
|
||||||
jk << "[opkey" << i << JK::val;
|
jk << "«opkey" << i << JK::val;
|
||||||
jk << op->name();
|
jk << op->name();
|
||||||
op->jit_prepare(jk);
|
op->jit_prepare(jk);
|
||||||
jk << JK::end;
|
|
||||||
}
|
}
|
||||||
jk << _CS("[JIT:1]");
|
jk << "«JIT:1";
|
||||||
if (!use_cuda) {
|
if (!use_cuda) {
|
||||||
// only cpu
|
// only cpu
|
||||||
jk << _CS("[JIT_cpu:1]");
|
jk << "«JIT_cpu:1";
|
||||||
this->flags.set(NodeFlags::_cuda, 0);
|
this->flags.set(NodeFlags::_cuda, 0);
|
||||||
this->flags.set(NodeFlags::_cpu, 1);
|
this->flags.set(NodeFlags::_cpu, 1);
|
||||||
} else {
|
} else {
|
||||||
jk << _CS("[JIT_cuda:1]");
|
jk << "«JIT_cuda:1";
|
||||||
this->flags.set(NodeFlags::_cpu, 0);
|
this->flags.set(NodeFlags::_cpu, 0);
|
||||||
this->flags.set(NodeFlags::_cuda, 1);
|
this->flags.set(NodeFlags::_cuda, 1);
|
||||||
}
|
}
|
||||||
jk << _CS("[graph:");
|
jk << "«graph:";
|
||||||
for (auto& t : edges) {
|
for (auto& t : edges) {
|
||||||
uint i,j,k,l;
|
uint i,j,k,l;
|
||||||
std::tie(i,j,k,l) = t;
|
std::tie(i,j,k,l) = t;
|
||||||
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
|
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
|
||||||
}
|
}
|
||||||
jk << _CS("][var_info:") << JK::val;
|
jk << "«var_info:" << JK::val;
|
||||||
bool use_int64_t = false;
|
bool use_int64_t = false;
|
||||||
for (auto& vi : vars) {
|
for (auto& vi : vars) {
|
||||||
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
|
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
|
||||||
if (vi.type != 1 && vi.var->num >= std::numeric_limits<int32_t>::max())
|
if (vi.type != 1 && vi.var->num >= std::numeric_limits<int32_t>::max())
|
||||||
use_int64_t = true;
|
use_int64_t = true;
|
||||||
}
|
}
|
||||||
jk << JK::end;
|
|
||||||
if (use_int64_t)
|
if (use_int64_t)
|
||||||
jk << _CS("[index_t:int64]");
|
jk << "«index_t:int64";
|
||||||
else
|
else
|
||||||
jk << _CS("[index_t:int32]");
|
jk << "«index_t:int32";
|
||||||
if (loop_options->size()) {
|
if (loop_options->size()) {
|
||||||
if (get_loop_option("compile_shapes")) {
|
if (get_loop_option("compile_shapes")) {
|
||||||
jk << _CS("[shapes:");
|
jk << "«shapes:";
|
||||||
for (auto& vi : vars) {
|
for (auto& vi : vars) {
|
||||||
jk << '[';
|
jk << '[';
|
||||||
for (auto a : vi.var->shape)
|
for (auto a : vi.var->shape)
|
||||||
jk << a << ',';
|
jk << a << ',';
|
||||||
jk << _CS("],");
|
jk << "],";
|
||||||
}
|
}
|
||||||
jk << JK::end;
|
|
||||||
}
|
}
|
||||||
jk << _CS("[choices:");
|
jk << "«choices:";
|
||||||
for (auto& kv : *loop_options)
|
for (auto& kv : *loop_options)
|
||||||
jk << kv.first << ':' << kv.second << ',';
|
jk << kv.first << ':' << kv.second << ',';
|
||||||
jk << JK::end;
|
|
||||||
}
|
}
|
||||||
jk.finilize();
|
jk.finilize();
|
||||||
}
|
}
|
||||||
|
|
|
@ -194,7 +194,7 @@ void run_cmd(string cmd, string cwd="") {
|
||||||
|
|
||||||
static string get_symbol_name(const string& jit_key) {
|
static string get_symbol_name(const string& jit_key) {
|
||||||
int i=0;
|
int i=0;
|
||||||
while (i<jit_key.size() && jit_key[i]!='[') i++;
|
while (i<jit_key.size() && jit_key[i]>=0) i++;
|
||||||
string op_name = i ? jit_key.substr(0, i) : "fused";
|
string op_name = i ? jit_key.substr(0, i) : "fused";
|
||||||
op_name = Op::file_name_to_class_name(op_name);
|
op_name = Op::file_name_to_class_name(op_name);
|
||||||
// _ZN7jittorXyyyyyy7jit_runEv
|
// _ZN7jittorXyyyyyy7jit_runEv
|
||||||
|
|
|
@ -80,42 +80,26 @@ static void convert_itof(string& s) {
|
||||||
|
|
||||||
vector<pair<string,string>> parse_jit_keys(const string& s) {
|
vector<pair<string,string>> parse_jit_keys(const string& s) {
|
||||||
vector<pair<string,string>> jit_keys;
|
vector<pair<string,string>> jit_keys;
|
||||||
int presum = 0;
|
auto sp = split(s, JitKey::key);
|
||||||
char state=0;
|
for (auto& ss : sp) {
|
||||||
string key, val;
|
if (!ss.size()) continue;
|
||||||
for (char c : s) {
|
string key, val;
|
||||||
if (c==JK::key) {
|
char state=0;
|
||||||
presum++;
|
for (auto c : ss) {
|
||||||
if (presum==1) {
|
if (state == 0 &&
|
||||||
|
(c==JK::val || c==JK::hex_val)) {
|
||||||
state = c;
|
state = c;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
} else
|
if (state == 0) key += c;
|
||||||
if (c==JK::val || c==JK::hex_val) {
|
else val += c;
|
||||||
if (presum==1 && state==JK::key) {
|
|
||||||
state = c;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} else
|
|
||||||
if (c==JK::end) {
|
|
||||||
presum--;
|
|
||||||
if (presum==0) {
|
|
||||||
if (state == JK::hex_val)
|
|
||||||
hex_to_dec(val);
|
|
||||||
if (startswith(val, "itof"))
|
|
||||||
convert_itof(val);
|
|
||||||
jit_keys.emplace_back(move(key), move(val));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (presum) {
|
|
||||||
if (state==JK::key)
|
|
||||||
key += c;
|
|
||||||
if (state==JK::val || state==JK::hex_val)
|
|
||||||
val += c;
|
|
||||||
}
|
}
|
||||||
|
if (state == JK::hex_val)
|
||||||
|
hex_to_dec(val);
|
||||||
|
if (startswith(val, "itof"))
|
||||||
|
convert_itof(val);
|
||||||
|
jit_keys.emplace_back(move(key), move(val));
|
||||||
}
|
}
|
||||||
ASSERT(presum==0) << s;
|
|
||||||
return jit_keys;
|
return jit_keys;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
// file 'LICENSE.txt', which is part of this source code package.
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#include <cstring>
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "misc/nano_string.h"
|
#include "misc/nano_string.h"
|
||||||
#include "misc/nano_vector.h"
|
#include "misc/nano_vector.h"
|
||||||
|
@ -13,11 +14,10 @@ namespace jittor {
|
||||||
|
|
||||||
struct JitKey {
|
struct JitKey {
|
||||||
static constexpr size_t buffer_size = 2*1024*1024;
|
static constexpr size_t buffer_size = 2*1024*1024;
|
||||||
static constexpr char
|
static constexpr const char
|
||||||
key = '[',
|
*key = "«",
|
||||||
val = ':',
|
val = ':',
|
||||||
hex_val = '=',
|
hex_val = '=';
|
||||||
end = ']';
|
|
||||||
int64 size=0;
|
int64 size=0;
|
||||||
uint64 flags=0;
|
uint64 flags=0;
|
||||||
char buffer[buffer_size];
|
char buffer[buffer_size];
|
||||||
|
@ -27,7 +27,7 @@ struct JitKey {
|
||||||
|
|
||||||
inline void clear() {size = flags = 0;}
|
inline void clear() {size = flags = 0;}
|
||||||
inline void finilize() { buffer[size] = 0; }
|
inline void finilize() { buffer[size] = 0; }
|
||||||
inline bool empty() { return buffer[size-1] != end; }
|
inline bool empty() { return !size; }
|
||||||
inline const char* to_cstring() {
|
inline const char* to_cstring() {
|
||||||
return &buffer[0];
|
return &buffer[0];
|
||||||
}
|
}
|
||||||
|
@ -81,11 +81,38 @@ struct __jk_int256 {
|
||||||
typedef JitKey JK;
|
typedef JitKey JK;
|
||||||
EXTERN_LIB JK& get_jk();
|
EXTERN_LIB JK& get_jk();
|
||||||
|
|
||||||
|
inline void jk_put_str_with_len(JK& jk, const char* a, int n) {
|
||||||
|
char* xx = &jk.buffer[jk.size];
|
||||||
|
int i=0;
|
||||||
|
while (i+32<=n) {
|
||||||
|
((__jk_int256*)(xx+i))[0] = ((const __jk_int256*)(a+i))[0];
|
||||||
|
i+=32;
|
||||||
|
}
|
||||||
|
while (i+16<=n) {
|
||||||
|
((__jk_int128*)(xx+i))[0] = ((const __jk_int128*)(a+i))[0];
|
||||||
|
i+=16;
|
||||||
|
}
|
||||||
|
while (i+8<=n) {
|
||||||
|
((long long*)(xx+i))[0] = ((const long long*)(a+i))[0];
|
||||||
|
i+=8;
|
||||||
|
}
|
||||||
|
while (i+4<=n) {
|
||||||
|
((int*)(xx+i))[0] = ((const int*)(a+i))[0];
|
||||||
|
i+=4;
|
||||||
|
}
|
||||||
|
while (i+2<=n) {
|
||||||
|
((int16_t*)(xx+i))[0] = ((const int16_t*)(a+i))[0];
|
||||||
|
i+=2;
|
||||||
|
}
|
||||||
|
while (i+1<=n) {
|
||||||
|
((char*)(xx+i))[0] = ((const char*)(a+i))[0];
|
||||||
|
i+=1;
|
||||||
|
}
|
||||||
|
jk.size += n;
|
||||||
|
}
|
||||||
|
|
||||||
inline JK& operator<<(JK& jk, const char* s) {
|
inline JK& operator<<(JK& jk, const char* s) {
|
||||||
int i;
|
jk_put_str_with_len(jk, s, strlen(s));
|
||||||
for (i=0; s[i]; i++)
|
|
||||||
jk.buffer[jk.size+i] = s[i];
|
|
||||||
jk.size += i;
|
|
||||||
return jk;
|
return jk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,129 +226,133 @@ vector<pair<string,string>> parse_jit_keys(const string& s);
|
||||||
|
|
||||||
template <class Ta, class Tb>
|
template <class Ta, class Tb>
|
||||||
void add_jit_define(JK& jk, const Ta& key, const Tb& val) {
|
void add_jit_define(JK& jk, const Ta& key, const Tb& val) {
|
||||||
jk << JK::key << key << JK::val << val << JK::end;
|
jk << JK::key << key << JK::val << val;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta, class Tb, class Tc>
|
template <class Ta, class Tb, class Tc>
|
||||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) {
|
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) {
|
||||||
jk << JK::key << key << i << JK::val << val << JK::end;
|
jk << JK::key << key << i << JK::val << val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <class Ta>
|
template <class Ta>
|
||||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) {
|
void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) {
|
||||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
jk << JK::key << key << JK::hex_val << val;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta, class Tb>
|
template <class Ta, class Tb>
|
||||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) {
|
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) {
|
||||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
jk << JK::key << key << i << JK::hex_val << val;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta>
|
template <class Ta>
|
||||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) {
|
void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) {
|
||||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
jk << JK::key << key << JK::hex_val << val;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta, class Tb>
|
template <class Ta, class Tb>
|
||||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) {
|
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) {
|
||||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
jk << JK::key << key << i << JK::hex_val << val;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta>
|
template <class Ta>
|
||||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) {
|
void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) {
|
||||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
jk << JK::key << key << JK::hex_val << val;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta, class Tb>
|
template <class Ta, class Tb>
|
||||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) {
|
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) {
|
||||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
jk << JK::key << key << i << JK::hex_val << val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define _CS(x) x
|
||||||
|
// // begin of const string
|
||||||
|
// #define MAX_CONST_CHAR 32
|
||||||
|
|
||||||
// begin of const string
|
// #define _CS_MIN(a,b) (a)<(b)?(a):(b)
|
||||||
#define MAX_CONST_CHAR 32
|
|
||||||
|
|
||||||
#define _CS_MIN(a,b) (a)<(b)?(a):(b)
|
// #define _CS_T(s)\
|
||||||
|
// getChr(s,0),\
|
||||||
|
// getChr(s,1),\
|
||||||
|
// getChr(s,2),\
|
||||||
|
// getChr(s,3),\
|
||||||
|
// getChr(s,4),\
|
||||||
|
// getChr(s,5),\
|
||||||
|
// getChr(s,6),\
|
||||||
|
// getChr(s,7),\
|
||||||
|
// getChr(s,8),\
|
||||||
|
// getChr(s,9),\
|
||||||
|
// getChr(s,10),\
|
||||||
|
// getChr(s,11),\
|
||||||
|
// getChr(s,12),\
|
||||||
|
// getChr(s,13),\
|
||||||
|
// getChr(s,14),\
|
||||||
|
// getChr(s,15),\
|
||||||
|
// getChr(s,16),\
|
||||||
|
// getChr(s,17),\
|
||||||
|
// getChr(s,18),\
|
||||||
|
// getChr(s,19),\
|
||||||
|
// getChr(s,20),\
|
||||||
|
// getChr(s,21),\
|
||||||
|
// getChr(s,22),\
|
||||||
|
// getChr(s,23),\
|
||||||
|
// getChr(s,24),\
|
||||||
|
// getChr(s,25),\
|
||||||
|
// getChr(s,26),\
|
||||||
|
// getChr(s,27),\
|
||||||
|
// getChr(s,28),\
|
||||||
|
// getChr(s,29),\
|
||||||
|
// getChr(s,30),\
|
||||||
|
// getChr(s,31),\
|
||||||
|
// getChr(s,32),\
|
||||||
|
// getChr(s,33),\
|
||||||
|
// getChr(s,34),\
|
||||||
|
// getChr(s,35)
|
||||||
|
|
||||||
#define _CS_T(s)\
|
// #define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
|
||||||
getChr(s,0),\
|
|
||||||
getChr(s,1),\
|
|
||||||
getChr(s,2),\
|
|
||||||
getChr(s,3),\
|
|
||||||
getChr(s,4),\
|
|
||||||
getChr(s,5),\
|
|
||||||
getChr(s,6),\
|
|
||||||
getChr(s,7),\
|
|
||||||
getChr(s,8),\
|
|
||||||
getChr(s,9),\
|
|
||||||
getChr(s,10),\
|
|
||||||
getChr(s,11),\
|
|
||||||
getChr(s,12),\
|
|
||||||
getChr(s,13),\
|
|
||||||
getChr(s,14),\
|
|
||||||
getChr(s,15),\
|
|
||||||
getChr(s,16),\
|
|
||||||
getChr(s,17),\
|
|
||||||
getChr(s,18),\
|
|
||||||
getChr(s,19),\
|
|
||||||
getChr(s,20),\
|
|
||||||
getChr(s,21),\
|
|
||||||
getChr(s,22),\
|
|
||||||
getChr(s,23),\
|
|
||||||
getChr(s,24),\
|
|
||||||
getChr(s,25),\
|
|
||||||
getChr(s,26),\
|
|
||||||
getChr(s,27),\
|
|
||||||
getChr(s,28),\
|
|
||||||
getChr(s,29),\
|
|
||||||
getChr(s,30),\
|
|
||||||
getChr(s,31),\
|
|
||||||
getChr(s,32),\
|
|
||||||
getChr(s,33),\
|
|
||||||
getChr(s,34),\
|
|
||||||
getChr(s,35)
|
|
||||||
|
|
||||||
#define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
|
// #ifdef _MSC_VER
|
||||||
|
// #define _CS(str) str
|
||||||
|
// #else
|
||||||
|
// #define _CS(str) _CS_G<_CS_T(str)>()
|
||||||
|
// #endif
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
// template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
|
||||||
#define _CS(str) str
|
// };
|
||||||
#else
|
|
||||||
#define _CS(str) _CS_G<_CS_T(str)>()
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
|
// template<> struct _CS_G<0,0,0,0> {};
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct _CS_G<0,0,0,0> {};
|
// template <char c1, char c2, char c3, char c4, char... Chars_>
|
||||||
|
// inline JK& operator<<(JK& jk, const _CS_G<c1,c2,c3,c4,Chars_...>& _) {
|
||||||
|
// ((uint32*)(jk.buffer+jk.size))[0] =
|
||||||
|
// (uint32((uint8)(c4))<<24)+
|
||||||
|
// (uint32((uint8)(c3))<<16)+
|
||||||
|
// (uint32((uint8)(c2))<<8)+
|
||||||
|
// uint32((uint8)(c1));
|
||||||
|
// if (c4) {
|
||||||
|
// jk.size += 4;
|
||||||
|
// jk << _CS_G<Chars_...>();
|
||||||
|
// } else
|
||||||
|
// if (c3) {
|
||||||
|
// jk.size += 3;
|
||||||
|
// } else
|
||||||
|
// if (c2) {
|
||||||
|
// jk.size += 2;
|
||||||
|
// } else
|
||||||
|
// if (c1) {
|
||||||
|
// jk.size += 1;
|
||||||
|
// }
|
||||||
|
// return jk;
|
||||||
|
// }
|
||||||
|
|
||||||
template <char c1, char c2, char c3, char c4, char... Chars_>
|
// template <>
|
||||||
inline JK& operator<<(JK& jk, const _CS_G<c1,c2,c3,c4,Chars_...>& _) {
|
// inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) {
|
||||||
((int*)(jk.buffer+jk.size))[0] = c4*(1<<24)+c3*(1<<16)+c2*(1<<8)+c1;
|
// return jk;
|
||||||
if (c4) {
|
// }
|
||||||
jk.size += 4;
|
|
||||||
jk << _CS_G<Chars_...>();
|
|
||||||
} else
|
|
||||||
if (c3) {
|
|
||||||
jk.size += 3;
|
|
||||||
} else
|
|
||||||
if (c2) {
|
|
||||||
jk.size += 2;
|
|
||||||
} else
|
|
||||||
if (c1) {
|
|
||||||
jk.size += 1;
|
|
||||||
}
|
|
||||||
return jk;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) {
|
|
||||||
return jk;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline JK& operator<<(JK& jk, float64 f) {
|
inline JK& operator<<(JK& jk, float64 f) {
|
||||||
return jk << _CS("itof(0x") << JK::hex(ftoi(f)) << ')';
|
return jk << "itof(0x" << JK::hex(ftoi(f)) << ')';
|
||||||
}
|
}
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -128,14 +128,16 @@ string Op::get_hash_name() {
|
||||||
void Op::do_jit_prepare(JK& jk) {
|
void Op::do_jit_prepare(JK& jk) {
|
||||||
memcheck_all_exist();
|
memcheck_all_exist();
|
||||||
jk << name();
|
jk << name();
|
||||||
|
auto pre_size = jk.size;
|
||||||
jit_prepare(jk);
|
jit_prepare(jk);
|
||||||
if (jk.empty()) {
|
if (jk.size == pre_size) {
|
||||||
// not a jit op
|
// not a jit op
|
||||||
bool has_cuda = flags.get(NodeFlags::_cuda);
|
bool has_cuda = flags.get(NodeFlags::_cuda);
|
||||||
bool has_cpu = flags.get(NodeFlags::_cpu);
|
bool has_cpu = flags.get(NodeFlags::_cpu);
|
||||||
CHECK(has_cuda || has_cpu);
|
CHECK(has_cuda || has_cpu);
|
||||||
if (has_cuda && has_cpu && !use_cuda)
|
if (has_cuda && has_cpu && !use_cuda)
|
||||||
flags.set(NodeFlags::_cuda, 0);
|
flags.set(NodeFlags::_cuda, 0);
|
||||||
|
jk.clear();
|
||||||
} else {
|
} else {
|
||||||
bool use_int64_t = false;
|
bool use_int64_t = false;
|
||||||
// TODO: fused op do not have inputs,
|
// TODO: fused op do not have inputs,
|
||||||
|
@ -149,9 +151,9 @@ void Op::do_jit_prepare(JK& jk) {
|
||||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||||
use_int64_t = true;
|
use_int64_t = true;
|
||||||
}
|
}
|
||||||
jk << _CS("[JIT:1]");
|
jk << "«JIT:1";
|
||||||
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
|
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
|
||||||
jk << _CS("[JIT_cuda:1]");
|
jk << "«JIT_cuda:1";
|
||||||
flags.set(NodeFlags::_cpu, 0);
|
flags.set(NodeFlags::_cpu, 0);
|
||||||
// TODO: 64bit index in CUDA
|
// TODO: 64bit index in CUDA
|
||||||
// use_int64_t = false;
|
// use_int64_t = false;
|
||||||
|
@ -164,14 +166,14 @@ void Op::do_jit_prepare(JK& jk) {
|
||||||
}
|
}
|
||||||
ASSERT(flags.get(NodeFlags::_cpu))
|
ASSERT(flags.get(NodeFlags::_cpu))
|
||||||
<< "Op" << name() << "doesn't have cpu version";
|
<< "Op" << name() << "doesn't have cpu version";
|
||||||
jk << _CS("[JIT_cpu:1]");
|
jk << "«JIT_cpu:1";
|
||||||
flags.set(NodeFlags::_cuda, 0);
|
flags.set(NodeFlags::_cuda, 0);
|
||||||
}
|
}
|
||||||
if (try_use_32bit_index) use_int64_t = false;
|
if (try_use_32bit_index) use_int64_t = false;
|
||||||
if (use_int64_t)
|
if (use_int64_t)
|
||||||
jk << _CS("[index_t:int64]");
|
jk << "«index_t:int64";
|
||||||
else
|
else
|
||||||
jk << _CS("[index_t:int32]");
|
jk << "«index_t:int32";
|
||||||
}
|
}
|
||||||
jk.finilize();
|
jk.finilize();
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,14 +158,13 @@ void ArgReduceOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ArgReduceOp::jit_prepare(JK& jk) {
|
void ArgReduceOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size());
|
jk << "«XDIM=" << JK::hex1(x->shape.size());
|
||||||
jk << _CS("][YDIM=") << JK::hex1(y->shape.size());
|
jk << "«YDIM=" << JK::hex1(y->shape.size());
|
||||||
jk << _CS("][KEEPDIMS:") << (keepdims ? '1' : '0');
|
jk << "«KEEPDIMS:" << (keepdims ? '1' : '0');
|
||||||
jk << _CS("][DIM=") << JK::hex1(dim);
|
jk << "«DIM=" << JK::hex1(dim);
|
||||||
jk << _CS("][CMP:") << (op==ns_minimum ? "<" : ">");
|
jk << "«CMP:" << (op==ns_minimum ? "<" : ">");
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -130,12 +130,11 @@ void ArgsortOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ArgsortOp::jit_prepare(JK& jk) {
|
void ArgsortOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size());
|
jk << "«XDIM=" << JK::hex1(x->shape.size());
|
||||||
jk << _CS("][DIM=") << JK::hex1(dim);
|
jk << "«DIM=" << JK::hex1(dim);
|
||||||
jk << _CS("][CMP:") << (descending ? '>' : '<');
|
jk << "«CMP:" << (descending ? '>' : '<');
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -78,7 +78,7 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
|
||||||
|
|
||||||
void ArrayOp::jit_prepare(JK& jk) {
|
void ArrayOp::jit_prepare(JK& jk) {
|
||||||
if (output->flags.get(NodeFlags::_force_fuse)) {
|
if (output->flags.get(NodeFlags::_force_fuse)) {
|
||||||
jk << _CS("[T:") << output->dtype() << ']';
|
jk << "«T:" << output->dtype();
|
||||||
|
|
||||||
// fill or find cbuffer for const var pass
|
// fill or find cbuffer for const var pass
|
||||||
if (output->dtype().dsize() == 4) {
|
if (output->dtype().dsize() == 4) {
|
||||||
|
@ -86,7 +86,7 @@ void ArrayOp::jit_prepare(JK& jk) {
|
||||||
auto y = std::abs(ptr<float32>()[0]);
|
auto y = std::abs(ptr<float32>()[0]);
|
||||||
auto z = ptr<uint32>()[0];
|
auto z = ptr<uint32>()[0];
|
||||||
if ((x<=2) || (y==1.0f || y==2.0f))
|
if ((x<=2) || (y==1.0f || y==2.0f))
|
||||||
jk << _CS("[o:") << z << ']';
|
jk << "«o:" << z;
|
||||||
}
|
}
|
||||||
// end of fill cbuffer
|
// end of fill cbuffer
|
||||||
}
|
}
|
||||||
|
|
|
@ -540,10 +540,10 @@ void BinaryOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void BinaryOp::jit_prepare(JK& jk) {
|
void BinaryOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype()
|
jk << "«Tx:" << x->dtype()
|
||||||
<< _CS("][Ty:") << y->dtype()
|
<< "«Ty:" << y->dtype()
|
||||||
<< _CS("][Tz:") << z->dtype()
|
<< "«Tz:" << z->dtype()
|
||||||
<< _CS("][OP:") << ns << ']';
|
<< "«OP:" << ns;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -167,9 +167,9 @@ void BroadcastToOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void BroadcastToOp::jit_prepare(JK& jk) {
|
void BroadcastToOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype()
|
jk << "«Tx:" << x->dtype()
|
||||||
<< _CS("][DIM=") << JK::hex1(z->shape.size())
|
<< "«DIM=" << JK::hex1(z->shape.size())
|
||||||
<< _CS("][BCAST=") << JK::hex(bcast_mask) << ']';
|
<< "«BCAST=" << JK::hex(bcast_mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -25,10 +25,10 @@ void CandidateOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CandidateOp::jit_prepare(JK& jk) {
|
void CandidateOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][FUNC:") << fail_cond;
|
jk << "«FUNC:" << fail_cond;
|
||||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size()) << ']';
|
jk << "«XDIM=" << JK::hex1(x->shape.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -123,18 +123,18 @@ void CodeOp::jit_prepare(JK& jk) {
|
||||||
|
|
||||||
// forward: in0 in1 in2 -> out0 out1
|
// forward: in0 in1 in2 -> out0 out1
|
||||||
// backward: in0 in1 in2 in3(pout0) in4(pout1)
|
// backward: in0 in1 in2 in3(pout0) in4(pout1)
|
||||||
jk << _CS("[IN_SIZE=") << JK::hex(_inputs.size());
|
jk << "«IN_SIZE=" << JK::hex(_inputs.size());
|
||||||
for (uint i=0; i<_inputs.size(); i++) {
|
for (uint i=0; i<_inputs.size(); i++) {
|
||||||
jk << _CS("][in") << JK::hex(i) << _CS("_dim=")
|
jk << "«in" << JK::hex(i) << "_dim="
|
||||||
<< JK::hex1(_inputs[i]->shape.size());
|
<< JK::hex1(_inputs[i]->shape.size());
|
||||||
jk << _CS("][in") << JK::hex(i) << _CS("_type:")
|
jk << "«in" << JK::hex(i) << "_type:"
|
||||||
<< _inputs[i]->dtype();
|
<< _inputs[i]->dtype();
|
||||||
}
|
}
|
||||||
jk << _CS("][OUT_SIZE=") << JK::hex(_outputs.size());
|
jk << "«OUT_SIZE=" << JK::hex(_outputs.size());
|
||||||
for (uint i=0; i<_outputs.size(); i++) {
|
for (uint i=0; i<_outputs.size(); i++) {
|
||||||
jk << _CS("][out") << JK::hex(i) << _CS("_dim=")
|
jk << "«out" << JK::hex(i) << "_dim="
|
||||||
<< JK::hex1(_outputs[i]->shape.size());
|
<< JK::hex1(_outputs[i]->shape.size());
|
||||||
jk << _CS("][out") << JK::hex(i) << _CS("_type:")
|
jk << "«out" << JK::hex(i) << "_type:"
|
||||||
<< _outputs[i]->dtype();
|
<< _outputs[i]->dtype();
|
||||||
}
|
}
|
||||||
string& header = flags.get(NodeFlags::_cuda) ?
|
string& header = flags.get(NodeFlags::_cuda) ?
|
||||||
|
@ -142,9 +142,9 @@ void CodeOp::jit_prepare(JK& jk) {
|
||||||
string& src = flags.get(NodeFlags::_cuda) ?
|
string& src = flags.get(NodeFlags::_cuda) ?
|
||||||
cuda_src : cpu_src;
|
cuda_src : cpu_src;
|
||||||
|
|
||||||
jk << _CS("][HEADER:") << header;
|
jk << "«HEADER:" << header;
|
||||||
CHECK(src.size());
|
CHECK(src.size());
|
||||||
jk << _CS("\nnamespace jittor {\n");
|
jk << "\nnamespace jittor {\n";
|
||||||
int i=0;
|
int i=0;
|
||||||
// move cuda kernel function into header
|
// move cuda kernel function into header
|
||||||
for (; i<src.size(); i++) {
|
for (; i<src.size(); i++) {
|
||||||
|
@ -165,9 +165,8 @@ void CodeOp::jit_prepare(JK& jk) {
|
||||||
}
|
}
|
||||||
} else break;
|
} else break;
|
||||||
}
|
}
|
||||||
jk << _CS("}][CODE:");
|
jk << "}«CODE:";
|
||||||
for (; i<src.size(); i++) jk << src[i];
|
for (; i<src.size(); i++) jk << src[i];
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
||||||
|
|
|
@ -74,12 +74,11 @@ VarPtr FuseTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
void FuseTransposeOp::jit_prepare(JK& jk) {
|
void FuseTransposeOp::jit_prepare(JK& jk) {
|
||||||
auto bc = type()==OpType::broadcast;
|
auto bc = type()==OpType::broadcast;
|
||||||
auto ax = bc ? axes : get_reverse(axes);
|
auto ax = bc ? axes : get_reverse(axes);
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][DIM=") << JK::hex1(axes.size());
|
jk << "«DIM=" << JK::hex1(axes.size());
|
||||||
jk << _CS("][BC:") << JK::hex1(bc);
|
jk << "«BC:" << JK::hex1(bc);
|
||||||
for (uint i=0; i<ax.size(); i++)
|
for (uint i=0; i<ax.size(); i++)
|
||||||
jk << _CS("][AXES") << JK::hex1(ax[i]) << '=' << JK::hex1(i);
|
jk << "«AXES" << JK::hex1(ax[i]) << '=' << JK::hex1(i);
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -444,26 +444,26 @@ void GetitemOp::grads(Var** dout, VarPtr* dins) {
|
||||||
void GetitemOp::jit_prepare(JK& jk) {
|
void GetitemOp::jit_prepare(JK& jk) {
|
||||||
auto in = inputs().front();
|
auto in = inputs().front();
|
||||||
int idim = i_to_vs.size();
|
int idim = i_to_vs.size();
|
||||||
jk << _CS("[Ti:") << in->dtype();
|
jk << "«Ti:" << in->dtype();
|
||||||
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
|
jk << "«IDIM=" << JK::hex1(i_to_vs.size());
|
||||||
jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
|
jk << "«ODIM=" << JK::hex1(o_shape.size());
|
||||||
if (first_oid_of_var>=0) {
|
if (first_oid_of_var>=0) {
|
||||||
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
|
jk << "«FOV=" << JK::hex1(first_oid_of_var);
|
||||||
jk << _CS("][VD=") << JK::hex1(var_dim);
|
jk << "«VD=" << JK::hex1(var_dim);
|
||||||
}
|
}
|
||||||
for (int i=0; i<idim; i++) {
|
for (int i=0; i<idim; i++) {
|
||||||
auto iv = i_to_vs[i];
|
auto iv = i_to_vs[i];
|
||||||
auto io = i_to_o[i];
|
auto io = i_to_o[i];
|
||||||
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
|
jk << "«IV" << JK::hex1(i) << ':' << JK::shex1(iv);
|
||||||
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
|
jk << "«IO" << JK::hex1(i) << ':' << JK::shex1(io);
|
||||||
auto& v = vs.slices[iv];
|
auto& v = vs.slices[iv];
|
||||||
if (iv>=0 && io==-1) {
|
if (iv>=0 && io==-1) {
|
||||||
if (v.is_int()) {
|
if (v.is_int()) {
|
||||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
|
jk << "«VS" << JK::hex1(i) << ":-1";
|
||||||
} else
|
} else
|
||||||
if (v.is_str()) {
|
if (v.is_str()) {
|
||||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
|
jk << "«VS" << JK::hex1(i) << ":-5";
|
||||||
jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
|
jk << "«VSS" << JK::hex1(i) << ":" << v.get_str();
|
||||||
} else {
|
} else {
|
||||||
ASSERT(v.is_var());
|
ASSERT(v.is_var());
|
||||||
auto var = v.var;
|
auto var = v.var;
|
||||||
|
@ -475,13 +475,13 @@ void GetitemOp::jit_prepare(JK& jk) {
|
||||||
if (vshape[j] == o_shape[k])
|
if (vshape[j] == o_shape[k])
|
||||||
vsmask |= 1<<(j+var_dim-vdim);
|
vsmask |= 1<<(j+var_dim-vdim);
|
||||||
}
|
}
|
||||||
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
|
jk << "«VS" << JK::hex1(i) << '=' << JK::hex(vsmask);
|
||||||
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
|
jk << "«VST" << JK::hex1(i) << ':' << var->dtype();
|
||||||
}
|
}
|
||||||
} else
|
} else
|
||||||
if (iv>=0 && io>=0) {
|
if (iv>=0 && io>=0) {
|
||||||
ASSERT(v.is_slice());
|
ASSERT(v.is_slice());
|
||||||
jk << _CS("][VS") << JK::hex1(i) << ':';
|
jk << "«VS" << JK::hex1(i) << ':';
|
||||||
if (std::abs(v.slice.step) <= 1)
|
if (std::abs(v.slice.step) <= 1)
|
||||||
jk << JK::shex1(v.slice.step);
|
jk << JK::shex1(v.slice.step);
|
||||||
else
|
else
|
||||||
|
@ -495,11 +495,10 @@ void GetitemOp::jit_prepare(JK& jk) {
|
||||||
int tdims[6];
|
int tdims[6];
|
||||||
cuda_loop_schedule(o_shape, masks, tdims);
|
cuda_loop_schedule(o_shape, masks, tdims);
|
||||||
for (int i=0; i<no; i++) {
|
for (int i=0; i<no; i++) {
|
||||||
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
jk << "«LO" << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -39,8 +39,8 @@ RandomOp::RandomOp(NanoVector shape, NanoString dtype, NanoString type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void RandomOp::jit_prepare(JK& jk) {
|
void RandomOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[T:") << output->dtype();
|
jk << "«T:" << output->dtype();
|
||||||
jk << _CS("][R:") << type << ']';
|
jk << "«R:" << type;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -338,12 +338,12 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReduceOp::jit_prepare(JK& jk) {
|
void ReduceOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype()
|
jk << "«Tx:" << x->dtype()
|
||||||
<< _CS("][Ty:") << y->dtype()
|
<< "«Ty:" << y->dtype()
|
||||||
<< _CS("][Tz:") << y->dtype()
|
<< "«Tz:" << y->dtype()
|
||||||
<< _CS("][OP:") << ns
|
<< "«OP:" << ns
|
||||||
<< _CS("][DIM=") << JK::hex1(x->shape.size())
|
<< "«DIM=" << JK::hex1(x->shape.size())
|
||||||
<< _CS("][REDUCE=") << JK::hex(reduce_mask) << ']';
|
<< "«REDUCE=" << JK::hex(reduce_mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -91,21 +91,20 @@ void ReindexOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReindexOp::jit_prepare(JK& jk) {
|
void ReindexOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype()
|
jk << "«Tx:" << x->dtype()
|
||||||
<< _CS("][XDIM=") << JK::hex1(x->shape.size())
|
<< "«XDIM=" << JK::hex1(x->shape.size())
|
||||||
<< _CS("][YDIM=") << JK::hex1(y->shape.size())
|
<< "«YDIM=" << JK::hex1(y->shape.size())
|
||||||
<< _CS("][OVERFLOW:") << overflow_value;
|
<< "«OVERFLOW:" << overflow_value;
|
||||||
for (uint i=0; i<indexes.size(); i++)
|
for (uint i=0; i<indexes.size(); i++)
|
||||||
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i];
|
jk << "«INDEX" << JK::hex1(i) << ':' << indexes[i];
|
||||||
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size());
|
jk << "«OSIZE=" << JK::hex1(overflow_conditions.size());
|
||||||
for (uint i=0; i<overflow_conditions.size(); i++)
|
for (uint i=0; i<overflow_conditions.size(); i++)
|
||||||
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i];
|
jk << "«OFD" << JK::hex1(i) << ':' << overflow_conditions[i];
|
||||||
jk << _CS("][ESIZE=") << JK::hex1(extras.size());
|
jk << "«ESIZE=" << JK::hex1(extras.size());
|
||||||
for (uint i=0; i<extras.size(); i++) {
|
for (uint i=0; i<extras.size(); i++) {
|
||||||
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
jk << "«EDIM" << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
||||||
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype();
|
jk << "«Te" << JK::hex1(i) << ':' << extras[i]->dtype();
|
||||||
}
|
}
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -73,21 +73,20 @@ void ReindexReduceOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReindexReduceOp::jit_prepare(JK& jk) {
|
void ReindexReduceOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype()
|
jk << "«Tx:" << x->dtype()
|
||||||
<< _CS("][OP:") << ns
|
<< "«OP:" << ns
|
||||||
<< _CS("][YDIM=") << JK::hex1(y->shape.size())
|
<< "«YDIM=" << JK::hex1(y->shape.size())
|
||||||
<< _CS("][XDIM=") << JK::hex1(x->shape.size());
|
<< "«XDIM=" << JK::hex1(x->shape.size());
|
||||||
for (uint i=0; i<indexes.size(); i++)
|
for (uint i=0; i<indexes.size(); i++)
|
||||||
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i];
|
jk << "«INDEX" << JK::hex1(i) << ':' << indexes[i];
|
||||||
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size());
|
jk << "«OSIZE=" << JK::hex1(overflow_conditions.size());
|
||||||
for (uint i=0; i<overflow_conditions.size(); i++)
|
for (uint i=0; i<overflow_conditions.size(); i++)
|
||||||
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i];
|
jk << "«OFD" << JK::hex1(i) << ':' << overflow_conditions[i];
|
||||||
jk << _CS("][ESIZE=") << JK::hex1(extras.size());
|
jk << "«ESIZE=" << JK::hex1(extras.size());
|
||||||
for (uint i=0; i<extras.size(); i++) {
|
for (uint i=0; i<extras.size(); i++) {
|
||||||
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
jk << "«EDIM" << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
||||||
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype();
|
jk << "«Te" << JK::hex1(i) << ':' << extras[i]->dtype();
|
||||||
}
|
}
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -30,7 +30,7 @@ void SafeClipOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SafeClipOp::jit_prepare(JK& jk) {
|
void SafeClipOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype() <<']';
|
jk << "«Tx:" << x->dtype() <<"«";
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -201,32 +201,32 @@ void SetitemOp::jit_prepare(JK& jk) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
auto data = input(1);
|
auto data = input(1);
|
||||||
jk << _CS("[OP:") << op
|
jk << "«OP:" << op
|
||||||
<< _CS("][Td:") << data->dtype()
|
<< "«Td:" << data->dtype()
|
||||||
<< _CS("][BMASK=") << JK::hex(bmask);
|
<< "«BMASK=" << JK::hex(bmask);
|
||||||
// TODO: merge code
|
// TODO: merge code
|
||||||
auto in = inputs().front();
|
auto in = inputs().front();
|
||||||
int idim = i_to_vs.size();
|
int idim = i_to_vs.size();
|
||||||
jk << _CS("][Ti:") << in->dtype();
|
jk << "«Ti:" << in->dtype();
|
||||||
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
|
jk << "«IDIM=" << JK::hex1(i_to_vs.size());
|
||||||
jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
|
jk << "«ODIM=" << JK::hex1(o_shape.size());
|
||||||
if (first_oid_of_var>=0) {
|
if (first_oid_of_var>=0) {
|
||||||
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
|
jk << "«FOV=" << JK::hex1(first_oid_of_var);
|
||||||
jk << _CS("][VD=") << JK::hex1(var_dim);
|
jk << "«VD=" << JK::hex1(var_dim);
|
||||||
}
|
}
|
||||||
for (int i=0; i<idim; i++) {
|
for (int i=0; i<idim; i++) {
|
||||||
auto iv = i_to_vs[i];
|
auto iv = i_to_vs[i];
|
||||||
auto io = i_to_o[i];
|
auto io = i_to_o[i];
|
||||||
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
|
jk << "«IV" << JK::hex1(i) << ':' << JK::shex1(iv);
|
||||||
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
|
jk << "«IO" << JK::hex1(i) << ':' << JK::shex1(io);
|
||||||
auto& v = vs.slices[iv];
|
auto& v = vs.slices[iv];
|
||||||
if (iv>=0 && io==-1) {
|
if (iv>=0 && io==-1) {
|
||||||
if (v.is_int()) {
|
if (v.is_int()) {
|
||||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
|
jk << "«VS" << JK::hex1(i) << ":-1";
|
||||||
} else
|
} else
|
||||||
if (v.is_str()) {
|
if (v.is_str()) {
|
||||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-5");
|
jk << "«VS" << JK::hex1(i) << ":-5";
|
||||||
jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str();
|
jk << "«VSS" << JK::hex1(i) << ":" << v.get_str();
|
||||||
} else {
|
} else {
|
||||||
ASSERT(v.is_var());
|
ASSERT(v.is_var());
|
||||||
auto var = v.var;
|
auto var = v.var;
|
||||||
|
@ -238,13 +238,13 @@ void SetitemOp::jit_prepare(JK& jk) {
|
||||||
if (vshape[j] == o_shape[k])
|
if (vshape[j] == o_shape[k])
|
||||||
vsmask |= 1<<(j+var_dim-vdim);
|
vsmask |= 1<<(j+var_dim-vdim);
|
||||||
}
|
}
|
||||||
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
|
jk << "«VS" << JK::hex1(i) << '=' << JK::hex(vsmask);
|
||||||
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
|
jk << "«VST" << JK::hex1(i) << ':' << var->dtype();
|
||||||
}
|
}
|
||||||
} else
|
} else
|
||||||
if (iv>=0 && io>=0) {
|
if (iv>=0 && io>=0) {
|
||||||
ASSERT(v.is_slice());
|
ASSERT(v.is_slice());
|
||||||
jk << _CS("][VS") << JK::hex1(i) << ':';
|
jk << "«VS" << JK::hex1(i) << ':';
|
||||||
if (std::abs(v.slice.step) <= 1)
|
if (std::abs(v.slice.step) <= 1)
|
||||||
jk << JK::shex1(v.slice.step);
|
jk << JK::shex1(v.slice.step);
|
||||||
else
|
else
|
||||||
|
@ -258,11 +258,10 @@ void SetitemOp::jit_prepare(JK& jk) {
|
||||||
int tdims[6];
|
int tdims[6];
|
||||||
cuda_loop_schedule(o_shape, masks, tdims);
|
cuda_loop_schedule(o_shape, masks, tdims);
|
||||||
for (int i=0; i<no; i++) {
|
for (int i=0; i<no; i++) {
|
||||||
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
jk << "«LO" << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetitemOp::compile_optimize(string& src) {
|
void SetitemOp::compile_optimize(string& src) {
|
||||||
|
|
|
@ -64,10 +64,10 @@ void TernaryOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TernaryOp::jit_prepare(JK& jk) {
|
void TernaryOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tc:") << cond->dtype();
|
jk << "«Tc:" << cond->dtype();
|
||||||
jk << _CS("][Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][Ty:") << y->dtype();
|
jk << "«Ty:" << y->dtype();
|
||||||
jk << _CS("][Tz:") << z->dtype() << ']';
|
jk << "«Tz:" << z->dtype();
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -79,11 +79,10 @@ VarPtr TransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TransposeOp::jit_prepare(JK& jk) {
|
void TransposeOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype();
|
jk << "«Tx:" << x->dtype();
|
||||||
jk << _CS("][DIM=") << JK::hex1(axes.size());
|
jk << "«DIM=" << JK::hex1(axes.size());
|
||||||
for (uint i=0; i<axes.size(); i++)
|
for (uint i=0; i<axes.size(); i++)
|
||||||
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i);
|
jk << "«AXES" << JK::hex1(axes[i]) << '=' << JK::hex1(i);
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -681,9 +681,9 @@ void UnaryOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void UnaryOp::jit_prepare(JK& jk) {
|
void UnaryOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Tx:") << x->dtype()
|
jk << "«Tx:" << x->dtype()
|
||||||
<< _CS("][Ty:") << y->dtype()
|
<< "«Ty:" << y->dtype()
|
||||||
<< _CS("][OP:") << ns << ']';
|
<< "«OP:" << ns;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -54,10 +54,9 @@ void WhereOp::infer_shape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void WhereOp::jit_prepare(JK& jk) {
|
void WhereOp::jit_prepare(JK& jk) {
|
||||||
jk << _CS("[Ti:") << cond->dtype();
|
jk << "«Ti:" << cond->dtype();
|
||||||
jk << _CS("][To:") << outs[0]->dtype();
|
jk << "«To:" << outs[0]->dtype();
|
||||||
jk << _CS("][NDIM=") << JK::hex1(cond->shape.size());
|
jk << "«NDIM=" << JK::hex1(cond->shape.size());
|
||||||
jk << ']';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // JIT
|
#else // JIT
|
||||||
|
|
|
@ -19,16 +19,16 @@ JIT_TEST(jit_key) {
|
||||||
});
|
});
|
||||||
std::cerr << "get segfault, ok" << std::endl;
|
std::cerr << "get segfault, ok" << std::endl;
|
||||||
|
|
||||||
jk << JK::key << "key" << JK::val << "value" << JK::end;
|
jk << JK::key << "key" << JK::val << "value";
|
||||||
jk << JK::key << "key" << JK::val << JK::hex(0x123123) << JK::end;
|
jk << JK::key << "key" << JK::val << JK::hex(0x123123);
|
||||||
jk << JK::key << "key" << JK::val << JK::hex1(0x123123) << JK::end;
|
jk << JK::key << "key" << JK::val << JK::hex1(0x123123);
|
||||||
jk << JK::key << "key" << JK::val << JK::hex2(0x123123) << JK::end;
|
jk << JK::key << "key" << JK::val << JK::hex2(0x123123);
|
||||||
jk << JK::key << "key" << JK::val << JK::Oxhex(0x123123) << JK::end;
|
jk << JK::key << "key" << JK::val << JK::Oxhex(0x123123);
|
||||||
jk << JK::key << "key" << JK::val << JK::Oxhex1(0x123123) << JK::end;
|
jk << JK::key << "key" << JK::val << JK::Oxhex1(0x123123);
|
||||||
jk << JK::key << "key" << JK::val << JK::Oxhex2(0x123123) << JK::end;
|
jk << JK::key << "key" << JK::val << JK::Oxhex2(0x123123);
|
||||||
string key = "[key:value][key:123123][key:3][key:23][key:0x123123][key:0x3][key:0x23]";
|
string key = "«key:value«key:123123«key:3«key:23«key:0x123123«key:0x3«key:0x23";
|
||||||
ASSERTop(jk.to_string(),==,key);
|
ASSERTop(jk.to_string(),==,key);
|
||||||
auto keys = parse_jit_keys("[a:11][b:22][a[3]:b::[x]][x=11][f=itof(0x0)]");
|
auto keys = parse_jit_keys("«a:11«b:22«a[3]:b::[x]«x=11«f=itof(0x0)");
|
||||||
vector<pair<string,string>> k2 =
|
vector<pair<string,string>> k2 =
|
||||||
{{"a","11"},{"b","22"},{"a[3]","b::[x]"},{"x","17"},{"f","0"}};
|
{{"a","11"},{"b","22"},{"a[3]","b::[x]"},{"x","17"},{"f","0"}};
|
||||||
ASSERTop(keys,==,k2);
|
ASSERTop(keys,==,k2);
|
||||||
|
|
|
@ -94,7 +94,7 @@ JIT_TEST(fused_op_relay_matmul) {
|
||||||
auto allocator = get_allocator();
|
auto allocator = get_allocator();
|
||||||
for (auto& v : fop.vars)
|
for (auto& v : fop.vars)
|
||||||
if (v.type!=1) v.var->alloc(allocator);
|
if (v.type!=1) v.var->alloc(allocator);
|
||||||
auto entry = oc.compile("[OP:_fused_op_relay_matmul]", oc.src);
|
auto entry = oc.compile("«OP:_fused_op_relay_matmul", oc.src);
|
||||||
for (uint i=0; i<a->num; i++)
|
for (uint i=0; i<a->num; i++)
|
||||||
a->ptr<float>()[i] = b->ptr<float>()[i] = 1;
|
a->ptr<float>()[i] = b->ptr<float>()[i] = 1;
|
||||||
entry(&fop);
|
entry(&fop);
|
||||||
|
|
|
@ -46,7 +46,7 @@ def test(shape, op1, op2):
|
||||||
with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
|
with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
|
||||||
d__ = d.data
|
d__ = d.data
|
||||||
logs = find_log_with_re(logs,
|
logs = find_log_with_re(logs,
|
||||||
"Jit (fused )?op key (not )?found: \[opkey0:array\[T:float32")
|
"Jit (fused )?op key (not )?found: «opkey0:array«T:float32")
|
||||||
assert(len(logs)==1), logs
|
assert(len(logs)==1), logs
|
||||||
|
|
||||||
a_ = a.data
|
a_ = a.data
|
||||||
|
|
|
@ -81,7 +81,7 @@ def test_case(box_num, out_size, time_limit):
|
||||||
for i in range(1, len(rep)):
|
for i in range(1, len(rep)):
|
||||||
t += float(rep[i][3]) / 1e9
|
t += float(rep[i][3]) / 1e9
|
||||||
name = rep[i][0]
|
name = rep[i][0]
|
||||||
if name.startswith('[') and (not '[graph:]' in name):
|
if name.startswith('«') and (not '«graph:«' in name):
|
||||||
fused_op_num += 1
|
fused_op_num += 1
|
||||||
assert fused_op_num == 1, fused_op_num
|
assert fused_op_num == 1, fused_op_num
|
||||||
assert t <= time_limit, t
|
assert t <= time_limit, t
|
||||||
|
|
|
@ -428,6 +428,15 @@ class TestSetitem(unittest.TestCase):
|
||||||
np.arange(4)[None,::-1]]
|
np.arange(4)[None,::-1]]
|
||||||
np.testing.assert_allclose(nb, b.data)
|
np.testing.assert_allclose(nb, b.data)
|
||||||
|
|
||||||
|
def test_cuda_slice_migrate_bug(self):
|
||||||
|
a = jt.array([1,2,3,4,5])
|
||||||
|
jt.sync_all()
|
||||||
|
if not jt.has_cuda: return
|
||||||
|
with jt.flag_scope(use_cuda=1):
|
||||||
|
b = a[0]
|
||||||
|
b.sync(True)
|
||||||
|
assert b.item() == 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue