mirror of https://github.com/Jittor/Jittor
optimize jit_prepare
This commit is contained in:
parent
b747cfd9bf
commit
f1b7af155f
|
@ -55,10 +55,14 @@ void CubArgReduceOp::infer_shape() {
|
|||
y_key->set_shape(shape);
|
||||
}
|
||||
|
||||
void CubArgReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Toffsets", offsets->dtype());
|
||||
add_jit_define("FUNC", op==ns_minimum ? "ArgMin" : "ArgMax");
|
||||
void CubArgReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Toffsets:") << offsets->dtype();
|
||||
jk << _CS("][FUNC:");
|
||||
if (op==ns_minimum)
|
||||
jk << _CS("ArgMin]");
|
||||
else
|
||||
jk << _CS("ArgMax]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -47,12 +47,16 @@ void CubArgsortOp::infer_shape() {
|
|||
y_key->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void CubArgsortOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Tindexes", indexes->dtype());
|
||||
add_jit_define("Toffsets", offsets->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("FUNC", descending ? "SortPairsDescending" : "SortPairs");
|
||||
void CubArgsortOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Tindexes:") << indexes->dtype();
|
||||
jk << _CS("][Toffsets:") << offsets->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][FUNC:");
|
||||
if (descending)
|
||||
jk << _CS("SortPairsDescending]");
|
||||
else
|
||||
jk << _CS("SortPairs]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -22,8 +22,8 @@ CubTestOp::CubTestOp(string cmd) : cmd(cmd) {
|
|||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void CubTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
void CubTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -76,11 +76,12 @@ void CublasBatchedMatmulOp::infer_shape(){
|
|||
c->set_shape(c_shape);
|
||||
}
|
||||
|
||||
void CublasBatchedMatmulOp::jit_prepare() {
|
||||
add_jit_define("T", a->dtype());
|
||||
add_jit_define("Trans_a", trans_a ? "T" : "N");
|
||||
add_jit_define("Trans_b", trans_b ? "T" : "N");
|
||||
add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
|
||||
void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -41,11 +41,12 @@ void CublasMatmulOp::infer_shape() {
|
|||
c->set_shape({n, k});
|
||||
}
|
||||
|
||||
void CublasMatmulOp::jit_prepare() {
|
||||
add_jit_define("T", a->dtype());
|
||||
add_jit_define("Trans_a", trans_a ? "T" : "N");
|
||||
add_jit_define("Trans_b", trans_b ? "T" : "N");
|
||||
add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
|
||||
void CublasMatmulOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -17,8 +17,8 @@ CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) {
|
|||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void CublasTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
void CublasTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -62,13 +62,14 @@ void CudnnConvBackwardWOp::infer_shape() {
|
|||
set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
|
||||
}
|
||||
|
||||
void CudnnConvBackwardWOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Ty", dy->dtype());
|
||||
add_jit_define("Tw", dw->dtype());
|
||||
add_jit_define("XFORMAT", xformat);
|
||||
add_jit_define("WFORMAT", wformat);
|
||||
add_jit_define("YFORMAT", yformat);
|
||||
void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << dy->dtype();
|
||||
jk << _CS("][Tw:") << dw->dtype();
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
}
|
||||
unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||
|
||||
|
|
|
@ -62,13 +62,14 @@ void CudnnConvBackwardXOp::infer_shape() {
|
|||
set_shape(dx, "abcd", xformat, xn, xc, xh, xw);
|
||||
}
|
||||
|
||||
void CudnnConvBackwardXOp::jit_prepare() {
|
||||
add_jit_define("Ty", dy->dtype());
|
||||
add_jit_define("Tw", w->dtype());
|
||||
add_jit_define("Tx", dx->dtype());
|
||||
add_jit_define("XFORMAT", xformat);
|
||||
add_jit_define("WFORMAT", wformat);
|
||||
add_jit_define("YFORMAT", yformat);
|
||||
void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << dx->dtype();
|
||||
jk << _CS("][Ty:") << dy->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
}
|
||||
unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||
|
||||
|
|
|
@ -64,13 +64,14 @@ void CudnnConvOp::infer_shape() {
|
|||
set_shape(y, "abcd", yformat, yn, yc, yh, yw);
|
||||
}
|
||||
|
||||
void CudnnConvOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Tw", w->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("XFORMAT", xformat);
|
||||
add_jit_define("WFORMAT", wformat);
|
||||
add_jit_define("YFORMAT", yformat);
|
||||
void CudnnConvOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
}
|
||||
unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@ CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) {
|
|||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void CudnnTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
void CudnnTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -23,9 +23,9 @@ CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype, NanoString ty
|
|||
ASSERT(type == ns_normal || type == ns_uniform);
|
||||
}
|
||||
|
||||
void CurandRandomOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
add_jit_define("R", type);
|
||||
void CurandRandomOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << output->dtype();
|
||||
jk << _CS("][R:") << type << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -20,8 +20,8 @@ CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) {
|
|||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void CuttTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
void CuttTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -58,11 +58,12 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return make_transpose(dout, reverse);
|
||||
}
|
||||
|
||||
void CuttTransposeOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("DIM", JK::hex1(axes.size()));
|
||||
void CuttTransposeOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][DIM=") << JK::hex1(axes.size());
|
||||
for (uint i=0; i<axes.size(); i++)
|
||||
add_jit_define("AXES", JK::hex1(axes[i]), S(i));
|
||||
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i);
|
||||
jk << ']';
|
||||
}
|
||||
unordered_map<string, unsigned int> cutt_plan_cache;
|
||||
|
||||
|
|
|
@ -36,8 +36,8 @@ VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return nccl_all_reduce(dout);
|
||||
}
|
||||
|
||||
void NcclAllReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
void NcclAllReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -34,8 +34,8 @@ VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return nccl_reduce(dout,root);
|
||||
}
|
||||
|
||||
void NcclBroadcastOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
void NcclBroadcastOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -34,8 +34,8 @@ VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return nccl_broadcast(dout,root);
|
||||
}
|
||||
|
||||
void NcclReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
void NcclReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -19,8 +19,8 @@ NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) {
|
|||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void NcclTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
void NcclTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -77,16 +77,17 @@ static const char* short_type(Var* x) {
|
|||
}
|
||||
}
|
||||
|
||||
void MklConvBackwardWOp::jit_prepare() {
|
||||
add_jit_define("Txd", x->dtype());
|
||||
add_jit_define("Tyd", dy->dtype());
|
||||
add_jit_define("Twd", dw->dtype());
|
||||
add_jit_define("Tx", short_type(x));
|
||||
add_jit_define("Tw", short_type(dw));
|
||||
add_jit_define("Ty", short_type(dy));
|
||||
add_jit_define("XFORMAT", xformat);
|
||||
add_jit_define("WFORMAT", wformat);
|
||||
add_jit_define("YFORMAT", yformat);
|
||||
void MklConvBackwardWOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Txd:") << x->dtype();
|
||||
jk << _CS("][Tyd:") << dy->dtype();
|
||||
jk << _CS("][Twd:") << dw->dtype();
|
||||
jk << _CS("][Tx:") << short_type(x);
|
||||
jk << _CS("][Tw:") << short_type(dw);
|
||||
jk << _CS("][Ty:") << short_type(dy);
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -75,16 +75,17 @@ static const char* short_type(Var* x) {
|
|||
}
|
||||
}
|
||||
|
||||
void MklConvBackwardXOp::jit_prepare() {
|
||||
add_jit_define("Tyd", dy->dtype());
|
||||
add_jit_define("Twd", w->dtype());
|
||||
add_jit_define("Txd", dx->dtype());
|
||||
add_jit_define("Tx", short_type(dx));
|
||||
add_jit_define("Tw", short_type(w));
|
||||
add_jit_define("Ty", short_type(dy));
|
||||
add_jit_define("XFORMAT", xformat);
|
||||
add_jit_define("WFORMAT", wformat);
|
||||
add_jit_define("YFORMAT", yformat);
|
||||
void MklConvBackwardXOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tyd:") << dy->dtype();
|
||||
jk << _CS("][Twd:") << w->dtype();
|
||||
jk << _CS("][Txd:") << dx->dtype();
|
||||
jk << _CS("][Tx:") << short_type(dx);
|
||||
jk << _CS("][Tw:") << short_type(w);
|
||||
jk << _CS("][Ty:") << short_type(dy);
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -79,13 +79,17 @@ static const char* short_type(Var* x) {
|
|||
}
|
||||
}
|
||||
|
||||
void MklConvOp::jit_prepare() {
|
||||
add_jit_define("Tx", short_type(x));
|
||||
add_jit_define("Tw", short_type(w));
|
||||
add_jit_define("Ty", short_type(y));
|
||||
add_jit_define("XFORMAT", xformat);
|
||||
add_jit_define("WFORMAT", wformat);
|
||||
add_jit_define("YFORMAT", yformat);
|
||||
void MklConvOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Txd:") << x->dtype();
|
||||
jk << _CS("][Tyd:") << y->dtype();
|
||||
jk << _CS("][Twd:") << w->dtype();
|
||||
jk << _CS("][Tx:") << short_type(x);
|
||||
jk << _CS("][Tw:") << short_type(w);
|
||||
jk << _CS("][Ty:") << short_type(y);
|
||||
jk << _CS("][XFORMAT:") << xformat;
|
||||
jk << _CS("][WFORMAT:") << wformat;
|
||||
jk << _CS("][YFORMAT:") << yformat;
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -42,10 +42,10 @@ void MklMatmulOp::infer_shape() {
|
|||
c->set_shape({n, k});
|
||||
}
|
||||
|
||||
void MklMatmulOp::jit_prepare() {
|
||||
add_jit_define("T", a->dtype());
|
||||
add_jit_define("Trans_a", trans_a ? "T" : "N");
|
||||
add_jit_define("Trans_b", trans_b ? "T" : "N");
|
||||
void MklMatmulOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N') << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -17,8 +17,8 @@ MklTestOp::MklTestOp() {
|
|||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void MklTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
void MklTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -61,9 +61,9 @@ VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return mpi_all_reduce(dout, ns_add);
|
||||
}
|
||||
|
||||
void MpiAllReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("OP", op.to_cstring());
|
||||
void MpiAllReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][OP:") << op << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -48,8 +48,8 @@ VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return mpi_reduce(dout, ns_add, root);
|
||||
}
|
||||
|
||||
void MpiBroadcastOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
void MpiBroadcastOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -61,9 +61,9 @@ VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return mpi_broadcast(dout,root);
|
||||
}
|
||||
|
||||
void MpiReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("OP", op.to_cstring());
|
||||
void MpiReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][OP:") << op << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -16,8 +16,8 @@ MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
|
|||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void MpiTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
void MpiTestOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:float32]");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -54,8 +54,8 @@ CustomOp::CustomOp(NanoVector shape, NanoString dtype) {
|
|||
output = create_output(shape, dtype);
|
||||
}
|
||||
|
||||
void CustomOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
void CustomOp::jit_prepare(JK& jk) {
|
||||
add_jit_define(jk, "T", output->dtype());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.0.6'
|
||||
__version__ = '1.2.0.7
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
|
|
@ -40,8 +40,8 @@ class TestCuda(unittest.TestCase):
|
|||
output = create_output(shape, dtype);
|
||||
}
|
||||
|
||||
void NoCudaOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
void NoCudaOp::jit_prepare(JK& jk) {
|
||||
add_jit_define(jk, "T", output->dtype());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
@ -70,8 +70,8 @@ class TestCuda(unittest.TestCase):
|
|||
output = create_output(shape, dtype);
|
||||
}
|
||||
|
||||
void MyCudaOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
void MyCudaOp::jit_prepare(JK& jk) {
|
||||
add_jit_define(jk, "T", output->dtype());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -35,8 +35,8 @@ CustomOp::CustomOp(NanoVector shape, NanoString dtype) {
|
|||
output = create_output(shape, dtype);
|
||||
}
|
||||
|
||||
void CustomOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
void CustomOp::jit_prepare(JK& jk) {
|
||||
add_jit_define(jk, "T", output->dtype());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
@ -86,8 +86,8 @@ class TestCustomOp(unittest.TestCase):
|
|||
output = create_output(shape, dtype);
|
||||
}
|
||||
|
||||
void MyOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
void MyOp::jit_prepare(JK& jk) {
|
||||
add_jit_define(jk, "T", output->dtype());
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -385,9 +385,8 @@ class Gray:
|
|||
img_ = transform(img)
|
||||
'''
|
||||
def __call__(self, img:Image.Image):
|
||||
img = np.array(img.convert('L'))
|
||||
img = img[np.newaxis, :]
|
||||
return np.array((img / 255.0), dtype = np.float32)
|
||||
img = np.float32(img.convert('L')) / np.float32(255.0)
|
||||
return img[np.newaxis, :]
|
||||
|
||||
class RandomCrop:
|
||||
'''
|
||||
|
|
|
@ -378,6 +378,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
#ifdef HAS_CUDA
|
||||
int sync_times = 0;
|
||||
#endif
|
||||
auto& jkl = jk;
|
||||
for (uint rid=0; rid<queue.size(); rid++) {
|
||||
int root = queue[rid];
|
||||
Op* op = ops[root];
|
||||
|
@ -396,7 +397,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
for (auto* var : op->outputs())
|
||||
var->alloc(allocator);
|
||||
LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
|
||||
op->do_prepare();
|
||||
op->do_prepare(jkl);
|
||||
bool is_cuda = op->flags.get(NodeFlags::_cuda);
|
||||
#ifdef HAS_CUDA
|
||||
if (!is_cuda) {
|
||||
|
@ -422,7 +423,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
}
|
||||
#endif
|
||||
last_is_cuda = is_cuda;
|
||||
op->do_run_after_prepare();
|
||||
op->do_run_after_prepare(jkl);
|
||||
LOGvvv << "Finished Op(" >> op->name() << rid >>
|
||||
"/" >> queue.size() >> ") output:" << op->outputs();
|
||||
if (is_fused_op) {
|
||||
|
@ -454,7 +455,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
// log memory info
|
||||
display_memory_info(__FILELINE__, false, true);
|
||||
// log jit_key and file location
|
||||
op->do_prepare();
|
||||
op->do_prepare(jkl);
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
LOGe << "[Error] source file location:" << jit_src_path;
|
||||
if (is_fused_op) {
|
||||
|
|
|
@ -32,7 +32,7 @@ loop_options_t& FusedOp::get_loop_options_tuned() {
|
|||
|
||||
void FusedOp::update_jit_key() {
|
||||
jk.clear();
|
||||
do_jit_prepare();
|
||||
do_jit_prepare(jk);
|
||||
}
|
||||
|
||||
void FusedOp::update_ops() {
|
||||
|
@ -41,7 +41,6 @@ void FusedOp::update_ops() {
|
|||
loop_options = loop_options_origin = nullptr;
|
||||
|
||||
_outputs.clear();
|
||||
jk.clear();
|
||||
for (Op* op : ops) {
|
||||
for (Var* o : op->outputs()) {
|
||||
if (o->loop_options) {
|
||||
|
@ -156,13 +155,13 @@ void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
|
|||
}
|
||||
}
|
||||
|
||||
void FusedOp::do_jit_prepare() {
|
||||
void FusedOp::do_jit_prepare(JK& jk) {
|
||||
jk.clear();
|
||||
int8 flags = 3;
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
Op* op = ops[i];
|
||||
jk << JK::key << "opkey" << i << JK::val;
|
||||
op->do_jit_prepare();
|
||||
jk << "[opkey" << i << JK::val;
|
||||
op->do_jit_prepare(jk);
|
||||
jk << JK::end;
|
||||
if (op->flags.get(NodeFlags::_cpu))
|
||||
flags &= 1; // only cpu
|
||||
|
@ -170,39 +169,39 @@ void FusedOp::do_jit_prepare() {
|
|||
flags &= 2; // only gpu
|
||||
}
|
||||
ASSERT(flags) << "FusedOp cannot contain both cpu and cuda ops.";
|
||||
add_jit_define("JIT", "1");
|
||||
jk << _CS("[JIT:1]");
|
||||
if (flags==1) {
|
||||
// only cpu
|
||||
add_jit_define("JIT_cpu", "1");
|
||||
jk << _CS("[JIT_cpu:1]");
|
||||
this->flags.set(NodeFlags::_cuda, 0);
|
||||
this->flags.set(NodeFlags::_cpu, 1);
|
||||
} else {
|
||||
add_jit_define("JIT_cuda", "1");
|
||||
jk << _CS("[JIT_cuda:1]");
|
||||
this->flags.set(NodeFlags::_cpu, 0);
|
||||
this->flags.set(NodeFlags::_cuda, 1);
|
||||
}
|
||||
jk << JK::key << "graph" << JK::val;
|
||||
jk << _CS("[graph:");
|
||||
for (auto& t : edges) {
|
||||
uint i,j,k,l;
|
||||
std::tie(i,j,k,l) = t;
|
||||
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
|
||||
}
|
||||
jk << JK::end << JK::key << "var_info" << JK::val;
|
||||
for (auto& vi : vars)
|
||||
jk << _CS("][var_info:") << JK::val;
|
||||
for (auto& vi : vars)
|
||||
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
|
||||
jk << JK::end;
|
||||
if (loop_options->size()) {
|
||||
if (get_loop_option("compile_shapes")) {
|
||||
jk << JK::key << "shapes" << JK::val;
|
||||
jk << _CS("[shapes:");
|
||||
for (auto& vi : vars) {
|
||||
jk << '[';
|
||||
for (auto a : vi.var->shape)
|
||||
jk << a << ',';
|
||||
jk << "],";
|
||||
jk << _CS("],");
|
||||
}
|
||||
jk << JK::end;
|
||||
}
|
||||
jk << JK::key << "choices" << JK::val;
|
||||
jk << _CS("[choices:");
|
||||
for (auto& kv : *loop_options)
|
||||
jk << kv.first << ':' << kv.second << ',';
|
||||
jk << JK::end;
|
||||
|
@ -210,11 +209,11 @@ void FusedOp::do_jit_prepare() {
|
|||
jk.finilize();
|
||||
}
|
||||
|
||||
void FusedOp::do_prepare() {
|
||||
do_jit_prepare();
|
||||
void FusedOp::do_prepare(JK& jk) {
|
||||
do_jit_prepare(jk);
|
||||
}
|
||||
|
||||
void FusedOp::do_run_after_prepare() {
|
||||
void FusedOp::do_run_after_prepare(JK& jk) {
|
||||
const char* jit_key = jk.to_cstring();
|
||||
auto iter = jit_fused_ops.find(string_view(jit_key, jk.size));
|
||||
if (iter != jit_fused_ops.end()) {
|
||||
|
@ -230,7 +229,7 @@ void FusedOp::do_run_after_prepare() {
|
|||
context->setup(this);
|
||||
string prev_jit_key = jit_key;
|
||||
context->entry = OpCompiler::do_compile(this);
|
||||
string new_jit_key = get_jit_key();
|
||||
string new_jit_key = get_jit_key(jk);
|
||||
jit_fused_ops[new_jit_key] = jit_fused_ops[prev_jit_key] = context;
|
||||
jit_key_mapper[prev_jit_key] = new_jit_key;
|
||||
LOGvv << "Get jit op entry:" << (void*)(context->entry);
|
||||
|
@ -257,8 +256,8 @@ int FusedOp::has(Node* node) {
|
|||
}
|
||||
|
||||
void FusedOp::do_run(){
|
||||
do_prepare();
|
||||
do_run_after_prepare();
|
||||
do_prepare(jk);
|
||||
do_run_after_prepare(jk);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -50,9 +50,9 @@ struct FusedOp final : Op {
|
|||
void statistics(uint64_t& in, uint64_t& out, uint64_t& compute) override;
|
||||
bool shape_infered() override;
|
||||
void infer_shape() override;
|
||||
void do_jit_prepare() override;
|
||||
void do_prepare() override;
|
||||
void do_run_after_prepare() override;
|
||||
void do_jit_prepare(JK& jk) override;
|
||||
void do_prepare(JK& jk) override;
|
||||
void do_run_after_prepare(JK& jk) override;
|
||||
void do_run() override;
|
||||
#ifdef JIT
|
||||
void jit_run();
|
||||
|
|
|
@ -108,7 +108,7 @@ vector<pair<string,string>> parse_jit_keys(const string& s) {
|
|||
val += c;
|
||||
}
|
||||
}
|
||||
ASSERT(presum==0);
|
||||
ASSERT(presum==0) << s;
|
||||
return jit_keys;
|
||||
}
|
||||
|
||||
|
|
130
src/jit_key.h
130
src/jit_key.h
|
@ -70,18 +70,31 @@ struct JitKey {
|
|||
};
|
||||
};
|
||||
|
||||
struct __jk_int128 {
|
||||
int64 a,b;
|
||||
};
|
||||
struct __jk_int256 {
|
||||
int64 a,b,c,d;
|
||||
};
|
||||
|
||||
extern thread_local JitKey jk;
|
||||
typedef JitKey JK;
|
||||
|
||||
inline JK& operator<<(JK& jk, const char* s) {
|
||||
while (*s) jk.buffer[jk.size++] = *s, s++;
|
||||
int i;
|
||||
for (i=0; s[i]; i++)
|
||||
jk.buffer[jk.size+i] = s[i];
|
||||
jk.size += i;
|
||||
return jk;
|
||||
}
|
||||
|
||||
inline JK& operator<<(JK& jk, const string& s) {
|
||||
for (uint64 i=0; i<s.size(); i++)
|
||||
jk.buffer[jk.size+i] = s[i];
|
||||
jk.size += s.size();
|
||||
auto a = (__jk_int256*)(jk.buffer+jk.size);
|
||||
auto b = (__jk_int256*)(&s[0]);
|
||||
auto len = s.size();
|
||||
for (uint64 i=0; i*32<len; i++)
|
||||
a[i] = b[i];
|
||||
jk.size += len;
|
||||
return jk;
|
||||
}
|
||||
|
||||
|
@ -166,55 +179,138 @@ static inline float64 itof(uint64 a) { return *(float64*)&a; }
|
|||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
inline JK& operator<<(JK& jk, float64 f) {
|
||||
return jk << "itof(0x" << JK::hex(ftoi(f)) << ')';
|
||||
}
|
||||
|
||||
inline JK& operator<<(JK& jk, const NanoString& ns) {
|
||||
return jk << ns.to_cstring();
|
||||
auto a = (__jk_int128*)(jk.buffer+jk.size);
|
||||
auto b = (__jk_int128*)(ns.to_cstring());
|
||||
auto len = ns.len();
|
||||
a[0] = b[0];
|
||||
jk.size += len;
|
||||
return jk;
|
||||
}
|
||||
|
||||
vector<pair<string,string>> parse_jit_keys(const string& s);
|
||||
|
||||
template <class Ta, class Tb>
|
||||
void add_jit_define(const Ta& key, const Tb& val) {
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& val) {
|
||||
jk << JK::key << key << JK::val << val << JK::end;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb, class Tc>
|
||||
void add_jit_define(const Ta& key, const Tb& i, const Tc& val) {
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) {
|
||||
jk << JK::key << key << i << JK::val << val << JK::end;
|
||||
}
|
||||
|
||||
|
||||
template <class Ta>
|
||||
void add_jit_define(const Ta& key, const JK::hex& val) {
|
||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) {
|
||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb>
|
||||
void add_jit_define(const Ta& key, const Tb& i, const JK::hex& val) {
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) {
|
||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
||||
}
|
||||
|
||||
template <class Ta>
|
||||
void add_jit_define(const Ta& key, const JK::hex1& val) {
|
||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) {
|
||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb>
|
||||
void add_jit_define(const Ta& key, const Tb& i, const JK::hex1& val) {
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) {
|
||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
||||
}
|
||||
|
||||
template <class Ta>
|
||||
void add_jit_define(const Ta& key, const JK::hex2& val) {
|
||||
void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) {
|
||||
jk << JK::key << key << JK::hex_val << val << JK::end;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb>
|
||||
void add_jit_define(const Ta& key, const Tb& i, const JK::hex2& val) {
|
||||
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) {
|
||||
jk << JK::key << key << i << JK::hex_val << val << JK::end;
|
||||
}
|
||||
|
||||
|
||||
// begin of const string
|
||||
#define MAX_CONST_CHAR 32
|
||||
|
||||
#define _CS_MIN(a,b) (a)<(b)?(a):(b)
|
||||
|
||||
#define _CS_T(s)\
|
||||
getChr(s,0),\
|
||||
getChr(s,1),\
|
||||
getChr(s,2),\
|
||||
getChr(s,3),\
|
||||
getChr(s,4),\
|
||||
getChr(s,5),\
|
||||
getChr(s,6),\
|
||||
getChr(s,7),\
|
||||
getChr(s,8),\
|
||||
getChr(s,9),\
|
||||
getChr(s,10),\
|
||||
getChr(s,11),\
|
||||
getChr(s,12),\
|
||||
getChr(s,13),\
|
||||
getChr(s,14),\
|
||||
getChr(s,15),\
|
||||
getChr(s,16),\
|
||||
getChr(s,17),\
|
||||
getChr(s,18),\
|
||||
getChr(s,19),\
|
||||
getChr(s,20),\
|
||||
getChr(s,21),\
|
||||
getChr(s,22),\
|
||||
getChr(s,23),\
|
||||
getChr(s,24),\
|
||||
getChr(s,25),\
|
||||
getChr(s,26),\
|
||||
getChr(s,27),\
|
||||
getChr(s,28),\
|
||||
getChr(s,29),\
|
||||
getChr(s,30),\
|
||||
getChr(s,31),\
|
||||
getChr(s,32),\
|
||||
getChr(s,33),\
|
||||
getChr(s,34),\
|
||||
getChr(s,35)
|
||||
|
||||
#define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
|
||||
|
||||
#define _CS(str) _CS_G<_CS_T(str)>()
|
||||
|
||||
template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
|
||||
};
|
||||
|
||||
template<> struct _CS_G<0,0,0,0> {};
|
||||
|
||||
template <char c1, char c2, char c3, char c4, char... Chars_>
|
||||
inline JK& operator<<(JK& jk, const _CS_G<c1,c2,c3,c4,Chars_...>& _) {
|
||||
((int*)(jk.buffer+jk.size))[0] = c4*(1<<24)+c3*(1<<16)+c2*(1<<8)+c1;
|
||||
if (c4) {
|
||||
jk.size += 4;
|
||||
jk << _CS_G<Chars_...>();
|
||||
} else
|
||||
if (c3) {
|
||||
jk.size += 3;
|
||||
} else
|
||||
if (c2) {
|
||||
jk.size += 2;
|
||||
} else
|
||||
if (c1) {
|
||||
jk.size += 1;
|
||||
}
|
||||
return jk;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) {
|
||||
return jk;
|
||||
}
|
||||
|
||||
|
||||
inline JK& operator<<(JK& jk, float64 f) {
|
||||
return jk << _CS("itof(0x") << JK::hex(ftoi(f)) << ')';
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -115,7 +115,8 @@ static unordered_set<string> binary_ops = {
|
|||
FOR_ALL_NS(DEFINE_NS);
|
||||
|
||||
unordered_map<string, NanoString> NanoString::__string_to_ns;
|
||||
vector<const char*> NanoString::__ns_to_string;
|
||||
char NanoString::__ns_to_string[ns_max_size*ns_max_len];
|
||||
int NanoString::__ns_len[ns_max_size];
|
||||
|
||||
static void init_ns() {
|
||||
NanoString::ns_t i=0;
|
||||
|
@ -140,11 +141,17 @@ static void init_ns() {
|
|||
ns.set(NanoString::_bool, is_bool.count(name));
|
||||
}
|
||||
NanoString::__string_to_ns[name] = ns;
|
||||
NanoString::__ns_to_string.push_back(name);
|
||||
auto name2 = ns.to_cstring();
|
||||
int len=0;
|
||||
for (;;len++) {
|
||||
name2[len] = name[len];
|
||||
if (!name[len]) break;
|
||||
}
|
||||
NanoString::__ns_len[i-1] = len;
|
||||
};
|
||||
#define INIT_NS(T) func(#T, ns_##T);
|
||||
FOR_ALL_NS(INIT_NS);
|
||||
ASSERT(NanoString::__ns_to_string.size()<=(1<<NanoString::_index_nbits));
|
||||
ASSERT(i<=(1<<NanoString::_index_nbits));
|
||||
NanoString::__string_to_ns["sum"] = ns_add;
|
||||
NanoString::__string_to_ns["min"] = ns_minimum;
|
||||
NanoString::__string_to_ns["max"] = ns_maximum;
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
constexpr int ns_max_size = 256;
|
||||
constexpr int ns_max_len = 16;
|
||||
|
||||
#define FOR_ALL_NS(m) \
|
||||
\
|
||||
|
@ -107,7 +109,8 @@ struct NanoString {
|
|||
ns_t data=0;
|
||||
|
||||
static unordered_map<string, NanoString> __string_to_ns;
|
||||
static vector<const char*> __ns_to_string;
|
||||
static char __ns_to_string[];
|
||||
static int __ns_len[];
|
||||
|
||||
inline void set(Flags f, ns_t a=1, ns_t nbits=1) {
|
||||
ns_t mask = (((1u<<nbits)-1)<<f);
|
||||
|
@ -118,6 +121,7 @@ struct NanoString {
|
|||
return (data>>f) & ((1u<<nbits)-1);
|
||||
}
|
||||
inline ns_t index() const { return get(_index, _index_nbits); }
|
||||
inline int len() const { return __ns_len[index()]; }
|
||||
inline ns_t type() const { return get(_type, _type_nbits); }
|
||||
inline ns_t is_bool() const { return get(_bool); }
|
||||
inline ns_t is_int() const { return get(_int); }
|
||||
|
@ -140,7 +144,9 @@ struct NanoString {
|
|||
inline NanoString(const string& s) : NanoString(s.c_str()) {}
|
||||
// @pyjt(__repr__)
|
||||
inline const char* to_cstring() const
|
||||
{ return __ns_to_string[index()]; }
|
||||
{ return __ns_to_string+index()*ns_max_len; }
|
||||
inline char* to_cstring()
|
||||
{ return __ns_to_string+index()*ns_max_len; }
|
||||
};
|
||||
|
||||
// force_type = 1 for int, 2 for float
|
||||
|
|
|
@ -26,6 +26,7 @@ struct string_view_map {
|
|||
iter_t end() { return umap.end(); }
|
||||
|
||||
const T& at(string_view sv) { return umap.at(sv); }
|
||||
size_t size() { return umap.size(); }
|
||||
|
||||
T& operator[](string_view sv) {
|
||||
auto iter = find(sv);
|
||||
|
|
35
src/op.cc
35
src/op.cc
|
@ -79,7 +79,7 @@ void Op::compile_optimize(string& src) {}
|
|||
|
||||
void Op::infer_shape() {}
|
||||
void Op::run() {}
|
||||
void Op::jit_prepare() {}
|
||||
void Op::jit_prepare(JK& jk) {}
|
||||
void Op::graph_optimize() {}
|
||||
|
||||
string Op::name_ex() const {
|
||||
|
@ -91,20 +91,20 @@ string Op::name_ex() const {
|
|||
return a;
|
||||
}
|
||||
|
||||
string Op::get_jit_key() {
|
||||
string Op::get_jit_key(JK& jk) {
|
||||
jk.clear();
|
||||
do_jit_prepare();
|
||||
do_jit_prepare(jk);
|
||||
return jk.to_string();
|
||||
}
|
||||
|
||||
vector<pair<string,string>> Op::get_jit_define() {
|
||||
return parse_jit_keys(get_jit_key());
|
||||
return parse_jit_keys(get_jit_key(jk));
|
||||
}
|
||||
|
||||
void Op::do_jit_prepare() {
|
||||
void Op::do_jit_prepare(JK& jk) {
|
||||
memcheck_all_exist();
|
||||
jk << name();
|
||||
jit_prepare();
|
||||
jit_prepare(jk);
|
||||
if (jk.empty()) {
|
||||
// not a jit op
|
||||
bool has_cuda = flags.get(NodeFlags::_cuda);
|
||||
|
@ -144,9 +144,9 @@ void Op::do_jit_prepare() {
|
|||
use_int64_t = true;
|
||||
out_id ++;
|
||||
}
|
||||
add_jit_define("JIT", "1");
|
||||
jk << _CS("[JIT:1]");
|
||||
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
|
||||
add_jit_define("JIT_cuda", "1");
|
||||
jk << _CS("[JIT_cuda:1]");
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
// TODO: 64bit index in CUDA
|
||||
use_int64_t = false;
|
||||
|
@ -159,21 +159,24 @@ void Op::do_jit_prepare() {
|
|||
}
|
||||
ASSERT(flags.get(NodeFlags::_cpu))
|
||||
<< "Op" << name() << "doesn't have cpu version";
|
||||
add_jit_define("JIT_cpu", "1");
|
||||
jk << _CS("[JIT_cpu:1]");
|
||||
flags.set(NodeFlags::_cuda, 0);
|
||||
}
|
||||
if (try_use_32bit_index) use_int64_t = false;
|
||||
add_jit_define("index_t", use_int64_t ? "int64" : "int32");
|
||||
if (use_int64_t)
|
||||
jk << _CS("[index_t:int64]");
|
||||
else
|
||||
jk << _CS("[index_t:int32]");
|
||||
}
|
||||
jk.finilize();
|
||||
}
|
||||
|
||||
void Op::do_prepare(){
|
||||
void Op::do_prepare(JK& jk){
|
||||
jk.clear();
|
||||
do_jit_prepare();
|
||||
do_jit_prepare(jk);
|
||||
}
|
||||
|
||||
void Op::do_run_after_prepare() {
|
||||
void Op::do_run_after_prepare(JK& jk) {
|
||||
if (!jk.empty())
|
||||
jit_run();
|
||||
else
|
||||
|
@ -181,8 +184,8 @@ void Op::do_run_after_prepare() {
|
|||
}
|
||||
|
||||
void Op::do_run() {
|
||||
do_prepare();
|
||||
do_run_after_prepare();
|
||||
do_prepare(jk);
|
||||
do_run_after_prepare(jk);
|
||||
}
|
||||
|
||||
string Op::get_filename_from_jit_key(const string& jit_key, const string& suffix) {
|
||||
|
@ -245,7 +248,7 @@ void Op::jit_run() {
|
|||
// compile JIT op
|
||||
string prev_jit_key = jit_key;
|
||||
auto op_entry = OpCompiler::do_compile(this);
|
||||
string new_jit_key = get_jit_key();
|
||||
string new_jit_key = get_jit_key(jk);
|
||||
jit_ops[new_jit_key] = jit_ops[prev_jit_key] = op_entry;
|
||||
jit_key_mapper[prev_jit_key] = new_jit_key;
|
||||
LOGvv << "Get jit op entry:" << (void*)op_entry;
|
||||
|
|
12
src/op.h
12
src/op.h
|
@ -40,12 +40,12 @@ struct Op : Node {
|
|||
virtual void grads(Var** douts, VarPtr* dins);
|
||||
virtual void infer_shape();
|
||||
virtual void run();
|
||||
virtual void jit_prepare();
|
||||
virtual void do_jit_prepare();
|
||||
virtual void jit_prepare(JK& jk);
|
||||
virtual void do_jit_prepare(JK& jk);
|
||||
virtual const char* name() const = 0;
|
||||
virtual void statistics(uint64_t& in, uint64_t& out, uint64_t& compute);
|
||||
virtual void do_prepare();
|
||||
virtual void do_run_after_prepare();
|
||||
virtual void do_prepare(JK& jk);
|
||||
virtual void do_run_after_prepare(JK& jk);
|
||||
virtual void do_run();
|
||||
virtual VarPtr duplicate();
|
||||
virtual void compile_optimize(string& src);
|
||||
|
@ -53,7 +53,7 @@ struct Op : Node {
|
|||
void jit_run();
|
||||
|
||||
string name_ex() const;
|
||||
string get_jit_key();
|
||||
string get_jit_key(JK& jk);
|
||||
vector<pair<string,string>> get_jit_define();
|
||||
};
|
||||
|
||||
|
@ -66,7 +66,7 @@ extern string_view_map<string> jit_key_mapper;
|
|||
#ifdef JIT
|
||||
#define DECLARE_jit_run void jit_run();
|
||||
#else
|
||||
#define DECLARE_jit_run void jit_prepare() override;
|
||||
#define DECLARE_jit_run void jit_prepare(JK& jk) override;
|
||||
#endif
|
||||
|
||||
} // jittor
|
|
@ -1000,7 +1000,7 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
|
|||
src = &src_after_passes;
|
||||
}
|
||||
op->compile_optimize(*src);
|
||||
auto ret = oc.compile(op->get_jit_key(), *src);
|
||||
auto ret = oc.compile(op->get_jit_key(jk), *src);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -155,14 +155,15 @@ void ArgReduceOp::infer_shape() {
|
|||
y_key->set_shape(shape);
|
||||
}
|
||||
|
||||
void ArgReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
add_jit_define("YDIM", JK::hex1(y->shape.size()));
|
||||
add_jit_define("KEEPDIMS", keepdims ? 1 : 0);
|
||||
add_jit_define("DIM", JK::hex1(dim));
|
||||
add_jit_define("CMP", op==ns_minimum ? "<" : ">");
|
||||
void ArgReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size());
|
||||
jk << _CS("][YDIM=") << JK::hex1(y->shape.size());
|
||||
jk << _CS("][KEEPDIMS:") << (keepdims ? '1' : '0');
|
||||
jk << _CS("][DIM=") << JK::hex1(dim);
|
||||
jk << _CS("][CMP:") << (op==ns_minimum ? "<" : ">");
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -127,12 +127,13 @@ void ArgsortOp::infer_shape() {
|
|||
y_key->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void ArgsortOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
add_jit_define("DIM", JK::hex1(dim));
|
||||
add_jit_define("CMP", descending ? ">" : "<");
|
||||
void ArgsortOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size());
|
||||
jk << _CS("][DIM=") << JK::hex1(dim);
|
||||
jk << _CS("][CMP:") << (descending ? '>' : '<');
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -74,9 +74,9 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
|
|||
std::memcpy(allocation.ptr, args.ptr, output->size);
|
||||
}
|
||||
|
||||
void ArrayOp::jit_prepare() {
|
||||
void ArrayOp::jit_prepare(JK& jk) {
|
||||
if (output->flags.get(NodeFlags::_force_fuse))
|
||||
add_jit_define("T", output->dtype());
|
||||
jk << _CS("[T:") << output->dtype() << ']';
|
||||
}
|
||||
|
||||
void ArrayOp::run() {
|
||||
|
|
|
@ -28,7 +28,7 @@ struct ArrayOp : Op {
|
|||
|
||||
const char* name() const override { return "array"; }
|
||||
void run() override;
|
||||
void jit_prepare() override;
|
||||
void jit_prepare(JK& jk) override;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -180,11 +180,11 @@ void BinaryOp::infer_shape() {
|
|||
z->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void BinaryOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("Tz", z->dtype());
|
||||
add_jit_define("OP", ns.to_cstring());
|
||||
void BinaryOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][Ty:") << y->dtype()
|
||||
<< _CS("][Tz:") << z->dtype()
|
||||
<< _CS("][OP:") << ns << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -164,10 +164,10 @@ void BroadcastToOp::infer_shape() {
|
|||
LOGvvv << "Broadcast x(" >> x >> ") shape" << yshapes << "-> z(" >> z >> ")";
|
||||
}
|
||||
|
||||
void BroadcastToOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("DIM", JK::hex1(z->shape.size()));
|
||||
add_jit_define("BCAST", JK::hex(bcast_mask));
|
||||
void BroadcastToOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][DIM=") << JK::hex1(z->shape.size())
|
||||
<< _CS("][BCAST=") << JK::hex(bcast_mask) << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -23,11 +23,11 @@ void CandidateOp::infer_shape() {
|
|||
y->set_shape({-std::abs(x->shape[0])});
|
||||
}
|
||||
|
||||
void CandidateOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("FUNC", fail_cond);
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
void CandidateOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][FUNC:") << fail_cond;
|
||||
jk << _CS("][XDIM=") << JK::hex1(x->shape.size()) << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -111,28 +111,28 @@ VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
);
|
||||
}
|
||||
|
||||
void CodeOp::jit_prepare() {
|
||||
void CodeOp::jit_prepare(JK& jk) {
|
||||
|
||||
// forward: in0 in1 in2 -> out0 out1
|
||||
// backward: in0 in1 in2 in3(pout0) in4(pout1)
|
||||
add_jit_define("IN_SIZE", JK::hex1(_inputs.size()));
|
||||
jk << _CS("[IN_SIZE=") << JK::hex(_inputs.size());
|
||||
for (uint i=0; i<_inputs.size(); i++) {
|
||||
jk << JK::key << "in" << JK::hex1(i) << "_dim" <<
|
||||
JK::val << JK::hex1(_inputs[i]->shape.size()) << JK::end;
|
||||
jk << JK::key << "in" << JK::hex1(i) << "_type" <<
|
||||
JK::val << _inputs[i]->dtype() << JK::end;
|
||||
jk << _CS("][in") << JK::hex(i) << _CS("_dim=")
|
||||
<< JK::hex1(_inputs[i]->shape.size());
|
||||
jk << _CS("][in") << JK::hex(i) << _CS("_type:")
|
||||
<< _inputs[i]->dtype();
|
||||
}
|
||||
add_jit_define("OUT_SIZE", JK::hex1(_outputs.size()));
|
||||
jk << _CS("][OUT_SIZE=") << JK::hex(_outputs.size());
|
||||
for (uint i=0; i<_outputs.size(); i++) {
|
||||
jk << JK::key << "out" << JK::hex1(i) << "_dim" <<
|
||||
JK::val << JK::hex1(_outputs[i]->shape.size()) << JK::end;
|
||||
jk << JK::key << "out" << JK::hex1(i) << "_type" <<
|
||||
JK::val << _outputs[i]->dtype() << JK::end;
|
||||
jk << _CS("][out") << JK::hex(i) << _CS("_dim=")
|
||||
<< JK::hex1(_outputs[i]->shape.size());
|
||||
jk << _CS("][out") << JK::hex(i) << _CS("_type:")
|
||||
<< _outputs[i]->dtype();
|
||||
}
|
||||
if (flags.get(NodeFlags::_cuda)) {
|
||||
jk << JK::key << "HEADER" << JK::val << cuda_header;
|
||||
jk << _CS("][HEADER:") << cuda_header;
|
||||
ASSERT(cuda_src.size());
|
||||
jk << "\nnamespace jittor {\n";
|
||||
jk << _CS("\nnamespace jittor {\n");
|
||||
int i=0;
|
||||
// move cuda kernel function into header
|
||||
for (; i<cuda_src.size(); i++) {
|
||||
|
@ -153,13 +153,12 @@ void CodeOp::jit_prepare() {
|
|||
}
|
||||
} else break;
|
||||
}
|
||||
jk << "}" << JK::end << JK::key << "CODE" << JK::val;
|
||||
jk << _CS("}][CODE:");
|
||||
for (; i<cuda_src.size(); i++) jk << cuda_src[i];
|
||||
jk << JK::end;
|
||||
jk << ']';
|
||||
} else {
|
||||
add_jit_define("HEADER", cpu_header);
|
||||
jk << JK::key << "CODE" << JK::val;
|
||||
jk << cpu_src << JK::end;
|
||||
jk << _CS("][HEADER:") << cpu_header;
|
||||
jk << _CS("][CODE:") << cpu_src << ']';
|
||||
ASSERT(cpu_src.size());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,8 +55,8 @@ void ConcatOp::infer_shape() {
|
|||
}
|
||||
y->set_shape(shape);
|
||||
}
|
||||
void ConcatOp::jit_prepare() {
|
||||
add_jit_define("T", "int");
|
||||
void ConcatOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:int]");
|
||||
}
|
||||
|
||||
VarPtr ConcatOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
|
|
|
@ -134,7 +134,7 @@ void GetitemOp::infer_slices(
|
|||
i_to_o[i] = out_shape.size();
|
||||
if (in_shape_i > 0) {
|
||||
slice.fill(in_shape_i);
|
||||
if (abs(slice.step) <= 1)
|
||||
if (std::abs(slice.step) <= 1)
|
||||
out_shape_j = (slice.stop - slice.start) * slice.step;
|
||||
else if (slice.step>0)
|
||||
out_shape_j = (slice.stop - slice.start - 1) / slice.step + 1;
|
||||
|
@ -376,25 +376,25 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return make_setitem(zeros, VarSlices(vs), dout, ns_void);
|
||||
}
|
||||
|
||||
void GetitemOp::jit_prepare() {
|
||||
void GetitemOp::jit_prepare(JK& jk) {
|
||||
auto in = inputs().front();
|
||||
int idim = i_to_vs.size();
|
||||
add_jit_define("Ti", in->dtype());
|
||||
add_jit_define("IDIM", JK::hex1(i_to_vs.size()));
|
||||
add_jit_define("ODIM", JK::hex1(o_shape.size()));
|
||||
jk << _CS("[Ti:") << in->dtype();
|
||||
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
|
||||
jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
|
||||
if (first_oid_of_var>=0) {
|
||||
add_jit_define("FOV", JK::hex1(first_oid_of_var));
|
||||
add_jit_define("VD", JK::hex1(var_dim));
|
||||
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
|
||||
jk << _CS("][VD=") << JK::hex1(var_dim);
|
||||
}
|
||||
for (int i=0; i<idim; i++) {
|
||||
auto iv = i_to_vs[i];
|
||||
auto io = i_to_o[i];
|
||||
add_jit_define("IV", JK::hex1(i), JK::shex1(iv));
|
||||
add_jit_define("IO", JK::hex1(i), JK::shex1(io));
|
||||
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
|
||||
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
|
||||
auto& v = vs.slices[iv];
|
||||
if (iv>=0 && io==-1) {
|
||||
if (v.is_int()) {
|
||||
add_jit_define("VS", JK::hex1(i), "-1");
|
||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
|
||||
} else {
|
||||
ASSERT(v.is_var());
|
||||
auto var = v.var;
|
||||
|
@ -406,16 +406,17 @@ void GetitemOp::jit_prepare() {
|
|||
if (vshape[j] == o_shape[k])
|
||||
vsmask |= 1<<(j+var_dim-vdim);
|
||||
}
|
||||
add_jit_define("VS", JK::hex1(i), JK::hex(vsmask));
|
||||
add_jit_define("VST", JK::hex1(i), var->dtype());
|
||||
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
|
||||
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
|
||||
}
|
||||
} else
|
||||
if (iv>=0 && io>=0) {
|
||||
ASSERT(v.is_slice());
|
||||
jk << _CS("][VS") << JK::hex1(i) << ':';
|
||||
if (std::abs(v.slice.step) <= 1)
|
||||
add_jit_define("VS", JK::hex1(i), JK::shex1(v.slice.step));
|
||||
jk << JK::shex1(v.slice.step);
|
||||
else
|
||||
add_jit_define("VS", JK::hex1(i), "0");
|
||||
jk << '0';
|
||||
}
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
|
@ -425,10 +426,11 @@ void GetitemOp::jit_prepare() {
|
|||
int tdims[6];
|
||||
cuda_loop_schedule(o_shape, masks, tdims);
|
||||
for (int i=0; i<no; i++) {
|
||||
add_jit_define("LO", JK::hex1(i), JK::hex(masks[i]));
|
||||
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -50,10 +50,10 @@ void IndexOp::infer_shape() {
|
|||
o->set_shape(a->shape);
|
||||
}
|
||||
|
||||
void IndexOp::jit_prepare() {
|
||||
add_jit_define("T", x[0]->dtype());
|
||||
add_jit_define("DIM", JK::hex1(dim));
|
||||
add_jit_define("XDIM", JK::hex1(x[0]->shape.size()));
|
||||
void IndexOp::jit_prepare(JK& jk) {
|
||||
add_jit_define(jk, "T", x[0]->dtype());
|
||||
add_jit_define(jk, "DIM", JK::hex1(dim));
|
||||
add_jit_define(jk, "XDIM", JK::hex1(x[0]->shape.size()));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -37,9 +37,9 @@ RandomOp::RandomOp(NanoVector shape, NanoString dtype, NanoString type) {
|
|||
ASSERT(type == ns_normal || type == ns_uniform);
|
||||
}
|
||||
|
||||
void RandomOp::jit_prepare() {
|
||||
add_jit_define("T", output->dtype());
|
||||
add_jit_define("R", type);
|
||||
void RandomOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:") << output->dtype();
|
||||
jk << _CS("][R:") << type << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -134,13 +134,13 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void ReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("Tz", y->dtype());
|
||||
add_jit_define("OP", ns.to_cstring());
|
||||
add_jit_define("DIM", JK::hex1(x->shape.size()));
|
||||
add_jit_define("REDUCE", JK::hex(reduce_mask));
|
||||
void ReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][Ty:") << y->dtype()
|
||||
<< _CS("][Tz:") << y->dtype()
|
||||
<< _CS("][OP:") << ns
|
||||
<< _CS("][DIM=") << JK::hex1(x->shape.size())
|
||||
<< _CS("][REDUCE=") << JK::hex(reduce_mask) << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -86,21 +86,22 @@ void ReindexOp::infer_shape() {
|
|||
CHECK(y->shape.size()) << "Number of shape should greater than 0.";
|
||||
}
|
||||
|
||||
void ReindexOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
add_jit_define("YDIM", JK::hex1(y->shape.size()));
|
||||
add_jit_define("OVERFLOW", overflow_value);
|
||||
void ReindexOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][XDIM=") << JK::hex1(x->shape.size())
|
||||
<< _CS("][YDIM=") << JK::hex1(y->shape.size())
|
||||
<< _CS("][OVERFLOW:") << overflow_value;
|
||||
for (uint i=0; i<indexes.size(); i++)
|
||||
add_jit_define("INDEX", JK::hex1(i), indexes[i]);
|
||||
add_jit_define("OSIZE", JK::hex1(overflow_conditions.size()));
|
||||
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i];
|
||||
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size());
|
||||
for (uint i=0; i<overflow_conditions.size(); i++)
|
||||
add_jit_define("OFD", JK::hex1(i), overflow_conditions[i]);
|
||||
add_jit_define("ESIZE", JK::hex1(extras.size()));
|
||||
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i];
|
||||
jk << _CS("][ESIZE=") << JK::hex1(extras.size());
|
||||
for (uint i=0; i<extras.size(); i++) {
|
||||
add_jit_define("EDIM", JK::hex1(i), JK::hex1(extras[i]->shape.size()));
|
||||
add_jit_define("Te", JK::hex1(i), extras[i]->dtype());
|
||||
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
||||
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype();
|
||||
}
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -63,21 +63,22 @@ void ReindexReduceOp::infer_shape() {
|
|||
CHECKop(y->size,>=,0u);
|
||||
}
|
||||
|
||||
void ReindexReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("OP", ns.to_cstring());
|
||||
add_jit_define("YDIM", JK::hex1(y->shape.size()));
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
void ReindexReduceOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][OP:") << ns
|
||||
<< _CS("][YDIM=") << JK::hex1(y->shape.size())
|
||||
<< _CS("][XDIM=") << JK::hex1(x->shape.size());
|
||||
for (uint i=0; i<indexes.size(); i++)
|
||||
add_jit_define("INDEX", JK::hex1(i), indexes[i]);
|
||||
add_jit_define("OSIZE", JK::hex1(overflow_conditions.size()));
|
||||
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i];
|
||||
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size());
|
||||
for (uint i=0; i<overflow_conditions.size(); i++)
|
||||
add_jit_define("OFD", JK::hex1(i), overflow_conditions[i]);
|
||||
add_jit_define("ESIZE", JK::hex1(extras.size()));
|
||||
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i];
|
||||
jk << _CS("][ESIZE=") << JK::hex1(extras.size());
|
||||
for (uint i=0; i<extras.size(); i++) {
|
||||
add_jit_define("EDIM", JK::hex1(i), JK::hex1(extras[i]->shape.size()));
|
||||
add_jit_define("Te", JK::hex1(i), extras[i]->dtype());
|
||||
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
|
||||
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype();
|
||||
}
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -173,7 +173,7 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void SetitemOp::jit_prepare() {
|
||||
void SetitemOp::jit_prepare(JK& jk) {
|
||||
for (int i=0; i<o_shape.size(); i++)
|
||||
if (o_shape[i]<0) {
|
||||
// because output shape is inferd, check in
|
||||
|
@ -183,28 +183,28 @@ void SetitemOp::jit_prepare() {
|
|||
break;
|
||||
}
|
||||
auto data = input(1);
|
||||
add_jit_define("OP", op);
|
||||
add_jit_define("Td", data->dtype());
|
||||
add_jit_define("BMASK", JK::hex(bmask));
|
||||
jk << _CS("[OP:") << op
|
||||
<< _CS("][Td:") << data->dtype()
|
||||
<< _CS("][BMASK=") << JK::hex(bmask);
|
||||
// TODO: merge code
|
||||
auto in = inputs().front();
|
||||
int idim = i_to_vs.size();
|
||||
add_jit_define("Ti", in->dtype());
|
||||
add_jit_define("IDIM", JK::hex1(i_to_vs.size()));
|
||||
add_jit_define("ODIM", JK::hex1(o_shape.size()));
|
||||
jk << _CS("][Ti:") << in->dtype();
|
||||
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
|
||||
jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
|
||||
if (first_oid_of_var>=0) {
|
||||
add_jit_define("FOV", JK::hex1(first_oid_of_var));
|
||||
add_jit_define("VD", JK::hex1(var_dim));
|
||||
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
|
||||
jk << _CS("][VD=") << JK::hex1(var_dim);
|
||||
}
|
||||
for (int i=0; i<idim; i++) {
|
||||
auto iv = i_to_vs[i];
|
||||
auto io = i_to_o[i];
|
||||
add_jit_define("IV", JK::hex1(i), JK::shex1(iv));
|
||||
add_jit_define("IO", JK::hex1(i), JK::shex1(io));
|
||||
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
|
||||
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
|
||||
auto& v = vs.slices[iv];
|
||||
if (iv>=0 && io==-1) {
|
||||
if (v.is_int()) {
|
||||
add_jit_define("VS", JK::hex1(i), "-1");
|
||||
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
|
||||
} else {
|
||||
ASSERT(v.is_var());
|
||||
auto var = v.var;
|
||||
|
@ -216,16 +216,17 @@ void SetitemOp::jit_prepare() {
|
|||
if (vshape[j] == o_shape[k])
|
||||
vsmask |= 1<<(j+var_dim-vdim);
|
||||
}
|
||||
add_jit_define("VS", JK::hex1(i), JK::hex(vsmask));
|
||||
add_jit_define("VST", JK::hex1(i), var->dtype());
|
||||
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
|
||||
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
|
||||
}
|
||||
} else
|
||||
if (iv>=0 && io>=0) {
|
||||
ASSERT(v.is_slice());
|
||||
jk << _CS("][VS") << JK::hex1(i) << ':';
|
||||
if (std::abs(v.slice.step) <= 1)
|
||||
add_jit_define("VS", JK::hex1(i), JK::shex1(v.slice.step));
|
||||
jk << JK::shex1(v.slice.step);
|
||||
else
|
||||
add_jit_define("VS", JK::hex1(i), "0");
|
||||
jk << '0';
|
||||
}
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
|
@ -235,10 +236,11 @@ void SetitemOp::jit_prepare() {
|
|||
int tdims[6];
|
||||
cuda_loop_schedule(o_shape, masks, tdims);
|
||||
for (int i=0; i<no; i++) {
|
||||
add_jit_define("LO", JK::hex1(i), JK::hex(masks[i]));
|
||||
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
void SetitemOp::compile_optimize(string& src) {
|
||||
|
|
|
@ -45,11 +45,11 @@ void TernaryOp::infer_shape() {
|
|||
z->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void TernaryOp::jit_prepare() {
|
||||
add_jit_define("Tc", cond->dtype());
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("Tz", z->dtype());
|
||||
void TernaryOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tc:") << cond->dtype();
|
||||
jk << _CS("][Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tz:") << z->dtype() << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -105,11 +105,12 @@ VarPtr TransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return make_transpose(dout, reverse);
|
||||
}
|
||||
|
||||
void TransposeOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("DIM", JK::hex1(axes.size()));
|
||||
void TransposeOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][DIM=") << JK::hex1(axes.size());
|
||||
for (uint i=0; i<axes.size(); i++)
|
||||
add_jit_define("AXES", JK::hex1(axes[i]), S(i));
|
||||
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i);
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -198,10 +198,10 @@ void UnaryOp::infer_shape() {
|
|||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void UnaryOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Ty", y->dtype());
|
||||
add_jit_define("OP", ns.to_cstring());
|
||||
void UnaryOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype()
|
||||
<< _CS("][Ty:") << y->dtype()
|
||||
<< _CS("][OP:") << ns << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -31,10 +31,10 @@ void WhereOp::infer_shape() {
|
|||
outs[i]->set_shape({num});
|
||||
}
|
||||
|
||||
void WhereOp::jit_prepare() {
|
||||
add_jit_define("Ti", cond->dtype());
|
||||
add_jit_define("To", outs[0]->dtype());
|
||||
add_jit_define("NDIM", JK::hex1(cond->shape.size()));
|
||||
void WhereOp::jit_prepare(JK& jk) {
|
||||
add_jit_define(jk, "Ti", cond->dtype());
|
||||
add_jit_define(jk, "To", outs[0]->dtype());
|
||||
add_jit_define(jk, "NDIM", JK::hex1(cond->shape.size()));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -1040,11 +1040,11 @@ void KernelIR::split_loop(int i, int j) {
|
|||
inner[2]->attrs["code"] = lvalue+"+="+rvalue2+";";
|
||||
push_back("for ("+dtype+" id"+sj+"=0; id"+sj+"<range"+sj+"; id"+sj+"++) {}");
|
||||
auto& sloop = children.back();
|
||||
int range, stride;
|
||||
int range=0, stride=0;
|
||||
if (get_number("range"+si, range) && get_number("stride"+si, stride) && (range%stride==0))
|
||||
push_back(dtype+" range"+sj+" = "+S(stride)+";", &inner);
|
||||
else {
|
||||
ASSERT(range != -1 && stride != -1) << range << stride;
|
||||
ASSERT(range != -1 && stride != -1) << range << stride << si;
|
||||
push_back(dtype+" range"+sj+" = ::min(range"+si+"-id"+si+", stride"+si+");", &inner);
|
||||
}
|
||||
sloop->attrs["loop_id"] = sj;
|
||||
|
@ -1081,7 +1081,7 @@ void KernelIR::resplit() {
|
|||
// define
|
||||
push_front(dtype+" "+lvalue+" = 0;", &before);
|
||||
|
||||
int num;
|
||||
int num=0;
|
||||
if (get_number(rvalue2, num)) {
|
||||
// range = number;
|
||||
inner[3]->attrs["rvalue"] = S(num);
|
||||
|
|
|
@ -41,7 +41,7 @@ void UnrollPass::run() {
|
|||
if (choice==1)
|
||||
loop->push_back("#pragma unroll", &loop->before);
|
||||
else {
|
||||
int num;
|
||||
int num=0;
|
||||
auto& split_id = loop->get_attr("split_id");
|
||||
auto& loop_id = loop->get_attr("loop_id");
|
||||
auto& rvalue = loop->get_attr("rvalue");
|
||||
|
|
|
@ -31,7 +31,7 @@ void VectorizePass::run() {
|
|||
if (choice == 1) {
|
||||
loop->push_back("#pragma vector", &loop->before);
|
||||
} else if (choice > 1) {
|
||||
int num;
|
||||
int num=0;
|
||||
if (!loop->get_number(loop->get_attr("rvalue"), num)) {
|
||||
if (loop->has_attr("split_id")) {
|
||||
string si = loop->attrs["split_id"];
|
||||
|
|
|
@ -233,7 +233,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
continue;
|
||||
Op* ops[3] = {op, bop->x->input(), bop->y->input()};
|
||||
int ok = 0;
|
||||
LOGvvvv << "conv like op" << fop << fop->get_jit_key();
|
||||
LOGvvvv << "conv like op" << fop << fop->get_jit_key(jk);
|
||||
for (int y_id=0; y_id<3; y_id++)
|
||||
for (int x_id=0; x_id<3; x_id++)
|
||||
for (int w_id=0; w_id<3; w_id++) {
|
||||
|
|
|
@ -68,7 +68,7 @@ int VarRelayManager::add_relay_group(const vector<pair<Var*, Var*>>& group) {
|
|||
if (node->is_var())
|
||||
continue;
|
||||
Op* op = node->op();
|
||||
op->do_jit_prepare();
|
||||
op->do_jit_prepare(jk);
|
||||
list<Node*> new_inputs;
|
||||
int removed = 0;
|
||||
for (Var* v : op->inputs())
|
||||
|
|
|
@ -89,14 +89,27 @@ struct SimpleThreads {
|
|||
}
|
||||
};
|
||||
|
||||
static int last_compiled_op_num = 0;
|
||||
static int not_compile_window = 0;
|
||||
|
||||
void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int64 tt) {
|
||||
// jit_search_kernel require compile at runtime
|
||||
if (jit_search_kernel || !use_parallel_op_compiler)
|
||||
if (jit_search_kernel || !use_parallel_op_compiler || not_compile_window > 1000)
|
||||
return;
|
||||
|
||||
// try not use parallel compile if no op needs compile
|
||||
if (last_compiled_op_num != jit_key_mapper.size()) {
|
||||
not_compile_window = 0;
|
||||
last_compiled_op_num = jit_key_mapper.size();
|
||||
} else {
|
||||
not_compile_window += queue.size();
|
||||
}
|
||||
|
||||
|
||||
vector<int> op_needs_compile;
|
||||
string_view_map<int> map;
|
||||
vector<unique_ptr<FusedOp>> fop_needs_compile;
|
||||
auto& jkl = jk;
|
||||
|
||||
for (uint rid=0; rid<queue.size(); rid++) {
|
||||
int root = queue[rid];
|
||||
|
@ -111,7 +124,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
|||
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
|
||||
}
|
||||
LOGvvv << "Check op needs compile:" << op;
|
||||
op->do_prepare();
|
||||
op->do_prepare(jkl);
|
||||
if (jk.empty()) continue;
|
||||
|
||||
const char* jit_key = jk.to_cstring();
|
||||
|
@ -133,8 +146,8 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
|||
LOGvv << "Op needs compile:" << op;
|
||||
} catch (const std::exception& e) {
|
||||
// log jit_key and file location
|
||||
op->do_prepare();
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
op->do_prepare(jkl);
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
|
||||
LOGe << "[Error] source file location:" << jit_src_path;
|
||||
if (is_fused_op) {
|
||||
LOGf << "Compile fused operator(" >> rid >> '/' >> queue.size() >> ")"
|
||||
|
@ -173,6 +186,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
|||
auto func = [&](int tid) {
|
||||
auto& entrys = op_entrys.at(tid);
|
||||
entrys.clear();
|
||||
auto& jkl = jk;
|
||||
while (!has_error && !segfault_happen) {
|
||||
int i = ai++;
|
||||
if (i >= n) break;
|
||||
|
@ -184,9 +198,9 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
|||
int root = queue[rid];
|
||||
op = ops[root];
|
||||
LOGvv << "Compile Op:" << op;
|
||||
op->do_prepare();
|
||||
op->do_prepare(jkl);
|
||||
auto op_entry = OpCompiler::do_compile(op);
|
||||
entrys.emplace_back(std::make_tuple(i, 0, (void*)op_entry, op->get_jit_key()));
|
||||
entrys.emplace_back(std::make_tuple(i, 0, (void*)op_entry, op->get_jit_key(jkl)));
|
||||
} else {
|
||||
FusedOp& fused_op = *fop_needs_compile[-rid-1];
|
||||
op = &fused_op;
|
||||
|
@ -194,15 +208,15 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
|||
LOGV(11) << "FusedOps:" << fused_op.ops;
|
||||
fused_op.context = new FusedOpContext();
|
||||
fused_op.context->setup(&fused_op);
|
||||
fused_op.do_prepare();
|
||||
fused_op.do_prepare(jkl);
|
||||
auto op_entry = OpCompiler::do_compile(op);
|
||||
fused_op.context->entry = op_entry;
|
||||
entrys.emplace_back(std::make_tuple(i, 1, (void*)fused_op.context, op->get_jit_key()));
|
||||
entrys.emplace_back(std::make_tuple(i, 1, (void*)fused_op.context, op->get_jit_key(jkl)));
|
||||
|
||||
// compile relay operators
|
||||
for (auto& vrg : fused_op.context->vrm.relay_groups) {
|
||||
for (auto& orc : vrg.oprcs) {
|
||||
orc.op->do_prepare();
|
||||
orc.op->do_prepare(jkl);
|
||||
bool needs_compile;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(entry_lock);
|
||||
|
@ -224,7 +238,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
|||
}
|
||||
} catch (const std::exception& e) {
|
||||
// log jit_key and file location
|
||||
op->do_prepare();
|
||||
op->do_prepare(jkl);
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
LOGe << "[Error] source file location:" << jit_src_path;
|
||||
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
static inline int _lzcnt(int64 v) {
|
||||
#ifdef __clang__
|
||||
#if __has_feature(__builtin_ia32_lzcnt_u64)
|
||||
return __builtin_ia32_lzcnt_u64(v);
|
||||
#else
|
||||
return v ? __builtin_clzll(v) : 64;
|
||||
#endif
|
||||
#else
|
||||
return __builtin_clzll(v);
|
||||
#endif
|
||||
}
|
||||
|
||||
struct SimpleProfiler {
|
||||
string name;
|
||||
int64 cnt;
|
||||
int64 total_ns;
|
||||
int64 pcnt[7] = {0};
|
||||
int64 pns[7] = {0};
|
||||
int64 last[7] = {0};
|
||||
|
||||
inline SimpleProfiler(string&& name): name(move(name)), cnt(0), total_ns(0) {}
|
||||
inline ~SimpleProfiler() {
|
||||
std::cerr << "=============================\nSimpleProfiler [" << name << "] cnt: " << cnt << " total: ";
|
||||
if (total_ns < 1.e3)
|
||||
std::cerr << total_ns << " ns" << std::endl;
|
||||
else if (total_ns < 1.e6)
|
||||
std::cerr << std::setprecision(3) << total_ns/1.e3 << " us" << std::endl;
|
||||
else if (total_ns < 1.e9)
|
||||
std::cerr << std::setprecision(3) << total_ns/1.e6 << " ms" << std::endl;
|
||||
else
|
||||
std::cerr << std::setprecision(3) << total_ns/1.e9 << " s" << std::endl;
|
||||
std::cerr << " <32ns <1us <32us <1ms <32ms <1s >1s\n";
|
||||
std::cerr << "cnt: ";
|
||||
for (int i=0; i<7; i++) std::cerr << std::setw(9) << pcnt[i];
|
||||
std::cerr << "\n ";
|
||||
for (int i=0; i<7; i++) std::cerr << std::setw(9) << std::setprecision(3) << pcnt[i]*1.0/cnt;
|
||||
std::cerr << "\ntime:";
|
||||
for (int i=0; i<7; i++) std::cerr << std::setw(9) << std::setprecision(3) << pns[i]*1.0/total_ns;
|
||||
std::cerr << "\nlast:";
|
||||
for (int i=0; i<7; i++) std::cerr << std::setw(9) << last[i];
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
inline void add(int64 time) {
|
||||
auto nbit = 64 - _lzcnt(time);
|
||||
auto i = (nbit-1) / 5;
|
||||
if (i>6) i=6;
|
||||
cnt ++;
|
||||
total_ns += time;
|
||||
pcnt[i] ++;
|
||||
pns[i] += time;
|
||||
last[i] = cnt;
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
example:
|
||||
{
|
||||
static SimpleProfiler _("array");
|
||||
SimpleProfilerGuard __(_);
|
||||
......
|
||||
}
|
||||
*/
|
||||
struct SimpleProfilerGuard {
|
||||
SimpleProfiler* p;
|
||||
std::chrono::high_resolution_clock::time_point start;
|
||||
inline SimpleProfilerGuard(SimpleProfiler& p) : p(&p) {
|
||||
start = std::chrono::high_resolution_clock::now();
|
||||
}
|
||||
inline ~SimpleProfilerGuard() {
|
||||
auto finish = std::chrono::high_resolution_clock::now();
|
||||
auto total_ns = (int64_t)std::chrono::duration_cast<std::chrono::nanoseconds>(finish-start).count();
|
||||
p->add(total_ns);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // jittor
|
|
@ -13,6 +13,7 @@
|
|||
#include "misc/hash.h"
|
||||
#include "misc/nano_string.h"
|
||||
#include "misc/fast_shared_ptr.h"
|
||||
#include "profiler/simple_profiler.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include "misc/cuda_flags.h"
|
||||
#endif
|
||||
|
|
|
@ -42,11 +42,11 @@ JIT_TEST(jit_key) {
|
|||
}
|
||||
|
||||
jk.clear();
|
||||
add_jit_define("f", 0.01);
|
||||
add_jit_define("f", 0.5);
|
||||
add_jit_define("f", 1.0/0);
|
||||
add_jit_define("f", -1.0/0);
|
||||
add_jit_define("f", 0.0/0);
|
||||
add_jit_define(jk, "f", 0.01);
|
||||
add_jit_define(jk, "f", 0.5);
|
||||
add_jit_define(jk, "f", 1.0/0);
|
||||
add_jit_define(jk, "f", -1.0/0);
|
||||
add_jit_define(jk, "f", 0.0/0);
|
||||
keys = parse_jit_keys(jk.to_string());
|
||||
k2 = {{"f","0x1.47ae147ae147bp-7"},
|
||||
{"f","0x1p-1"},
|
||||
|
|
|
@ -275,7 +275,7 @@ a[2]++;
|
|||
ir.move_out_children();
|
||||
ir.push_back("T x=1;");
|
||||
ir.push_back("T y=n;");
|
||||
int num;
|
||||
int num=0;
|
||||
CHECK(ir.get_number("x", num) && num==1);
|
||||
CHECK(!ir.get_number("z", num) && num==-1);
|
||||
CHECK(!ir.get_number("y", num) && num==-2);
|
||||
|
|
|
@ -45,7 +45,7 @@ JIT_TEST(fused_op_relay_matmul) {
|
|||
});
|
||||
CHECKop(q.size(),==,10);
|
||||
CHECKop(ops.size(),==,4);
|
||||
for (auto op : ops) op->do_jit_prepare();
|
||||
for (auto op : ops) op->do_jit_prepare(jk);
|
||||
FusedOp fop;
|
||||
FusedOpContext context;
|
||||
fop.context = &context;
|
||||
|
|
Loading…
Reference in New Issue