optimize jit_prepare

This commit is contained in:
Dun Liang 2020-10-26 19:40:04 +08:00
parent b747cfd9bf
commit f1b7af155f
73 changed files with 597 additions and 353 deletions

View File

@ -55,10 +55,14 @@ void CubArgReduceOp::infer_shape() {
y_key->set_shape(shape);
}
void CubArgReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Toffsets", offsets->dtype());
add_jit_define("FUNC", op==ns_minimum ? "ArgMin" : "ArgMax");
void CubArgReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Toffsets:") << offsets->dtype();
jk << _CS("][FUNC:");
if (op==ns_minimum)
jk << _CS("ArgMin]");
else
jk << _CS("ArgMax]");
}
#else // JIT

View File

@ -47,12 +47,16 @@ void CubArgsortOp::infer_shape() {
y_key->set_shape(x->shape);
}
void CubArgsortOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Tindexes", indexes->dtype());
add_jit_define("Toffsets", offsets->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("FUNC", descending ? "SortPairsDescending" : "SortPairs");
void CubArgsortOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Tindexes:") << indexes->dtype();
jk << _CS("][Toffsets:") << offsets->dtype();
jk << _CS("][Ty:") << y->dtype();
jk << _CS("][FUNC:");
if (descending)
jk << _CS("SortPairsDescending]");
else
jk << _CS("SortPairs]");
}
#else // JIT

View File

@ -22,8 +22,8 @@ CubTestOp::CubTestOp(string cmd) : cmd(cmd) {
output = create_output(1, ns_float32);
}
void CubTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
void CubTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]");
}
#else // JIT

View File

@ -76,11 +76,12 @@ void CublasBatchedMatmulOp::infer_shape(){
c->set_shape(c_shape);
}
void CublasBatchedMatmulOp::jit_prepare() {
add_jit_define("T", a->dtype());
add_jit_define("Trans_a", trans_a ? "T" : "N");
add_jit_define("Trans_b", trans_b ? "T" : "N");
add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
jk << ']';
}
#else // JIT

View File

@ -41,11 +41,12 @@ void CublasMatmulOp::infer_shape() {
c->set_shape({n, k});
}
void CublasMatmulOp::jit_prepare() {
add_jit_define("T", a->dtype());
add_jit_define("Trans_a", trans_a ? "T" : "N");
add_jit_define("Trans_b", trans_b ? "T" : "N");
add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
void CublasMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
jk << ']';
}
#else // JIT

View File

@ -17,8 +17,8 @@ CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) {
output = create_output(1, ns_float32);
}
void CublasTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
void CublasTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]");
}
#else // JIT

View File

@ -62,13 +62,14 @@ void CudnnConvBackwardWOp::infer_shape() {
set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
}
void CudnnConvBackwardWOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Ty", dy->dtype());
add_jit_define("Tw", dw->dtype());
add_jit_define("XFORMAT", xformat);
add_jit_define("WFORMAT", wformat);
add_jit_define("YFORMAT", yformat);
void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Ty:") << dy->dtype();
jk << _CS("][Tw:") << dw->dtype();
jk << _CS("][XFORMAT:") << xformat;
jk << _CS("][WFORMAT:") << wformat;
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;

View File

@ -62,13 +62,14 @@ void CudnnConvBackwardXOp::infer_shape() {
set_shape(dx, "abcd", xformat, xn, xc, xh, xw);
}
void CudnnConvBackwardXOp::jit_prepare() {
add_jit_define("Ty", dy->dtype());
add_jit_define("Tw", w->dtype());
add_jit_define("Tx", dx->dtype());
add_jit_define("XFORMAT", xformat);
add_jit_define("WFORMAT", wformat);
add_jit_define("YFORMAT", yformat);
void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << dx->dtype();
jk << _CS("][Ty:") << dy->dtype();
jk << _CS("][Tw:") << w->dtype();
jk << _CS("][XFORMAT:") << xformat;
jk << _CS("][WFORMAT:") << wformat;
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;

View File

@ -64,13 +64,14 @@ void CudnnConvOp::infer_shape() {
set_shape(y, "abcd", yformat, yn, yc, yh, yw);
}
void CudnnConvOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Tw", w->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("XFORMAT", xformat);
add_jit_define("WFORMAT", wformat);
add_jit_define("YFORMAT", yformat);
void CudnnConvOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Ty:") << y->dtype();
jk << _CS("][Tw:") << w->dtype();
jk << _CS("][XFORMAT:") << xformat;
jk << _CS("][WFORMAT:") << wformat;
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;

View File

@ -18,8 +18,8 @@ CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) {
output = create_output(1, ns_float32);
}
void CudnnTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
void CudnnTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]");
}
#else // JIT

View File

@ -23,9 +23,9 @@ CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype, NanoString ty
ASSERT(type == ns_normal || type == ns_uniform);
}
void CurandRandomOp::jit_prepare() {
add_jit_define("T", output->dtype());
add_jit_define("R", type);
void CurandRandomOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << output->dtype();
jk << _CS("][R:") << type << ']';
}
#else // JIT

View File

@ -20,8 +20,8 @@ CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) {
output = create_output(1, ns_float32);
}
void CuttTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
void CuttTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]");
}
#else // JIT

View File

@ -58,11 +58,12 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_transpose(dout, reverse);
}
void CuttTransposeOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("DIM", JK::hex1(axes.size()));
void CuttTransposeOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][DIM=") << JK::hex1(axes.size());
for (uint i=0; i<axes.size(); i++)
add_jit_define("AXES", JK::hex1(axes[i]), S(i));
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i);
jk << ']';
}
unordered_map<string, unsigned int> cutt_plan_cache;

View File

@ -36,8 +36,8 @@ VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return nccl_all_reduce(dout);
}
void NcclAllReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
void NcclAllReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() << ']';
}
#else // JIT

View File

@ -34,8 +34,8 @@ VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return nccl_reduce(dout,root);
}
void NcclBroadcastOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
void NcclBroadcastOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() << ']';
}
#else // JIT

View File

@ -34,8 +34,8 @@ VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return nccl_broadcast(dout,root);
}
void NcclReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
void NcclReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() << ']';
}
#else // JIT

View File

@ -19,8 +19,8 @@ NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) {
output = create_output(1, ns_float32);
}
void NcclTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
void NcclTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]");
}
#else // JIT

View File

@ -77,16 +77,17 @@ static const char* short_type(Var* x) {
}
}
void MklConvBackwardWOp::jit_prepare() {
add_jit_define("Txd", x->dtype());
add_jit_define("Tyd", dy->dtype());
add_jit_define("Twd", dw->dtype());
add_jit_define("Tx", short_type(x));
add_jit_define("Tw", short_type(dw));
add_jit_define("Ty", short_type(dy));
add_jit_define("XFORMAT", xformat);
add_jit_define("WFORMAT", wformat);
add_jit_define("YFORMAT", yformat);
void MklConvBackwardWOp::jit_prepare(JK& jk) {
jk << _CS("[Txd:") << x->dtype();
jk << _CS("][Tyd:") << dy->dtype();
jk << _CS("][Twd:") << dw->dtype();
jk << _CS("][Tx:") << short_type(x);
jk << _CS("][Tw:") << short_type(dw);
jk << _CS("][Ty:") << short_type(dy);
jk << _CS("][XFORMAT:") << xformat;
jk << _CS("][WFORMAT:") << wformat;
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
#else // JIT

View File

@ -75,16 +75,17 @@ static const char* short_type(Var* x) {
}
}
void MklConvBackwardXOp::jit_prepare() {
add_jit_define("Tyd", dy->dtype());
add_jit_define("Twd", w->dtype());
add_jit_define("Txd", dx->dtype());
add_jit_define("Tx", short_type(dx));
add_jit_define("Tw", short_type(w));
add_jit_define("Ty", short_type(dy));
add_jit_define("XFORMAT", xformat);
add_jit_define("WFORMAT", wformat);
add_jit_define("YFORMAT", yformat);
void MklConvBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("[Tyd:") << dy->dtype();
jk << _CS("][Twd:") << w->dtype();
jk << _CS("][Txd:") << dx->dtype();
jk << _CS("][Tx:") << short_type(dx);
jk << _CS("][Tw:") << short_type(w);
jk << _CS("][Ty:") << short_type(dy);
jk << _CS("][XFORMAT:") << xformat;
jk << _CS("][WFORMAT:") << wformat;
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
#else // JIT

View File

@ -79,13 +79,17 @@ static const char* short_type(Var* x) {
}
}
void MklConvOp::jit_prepare() {
add_jit_define("Tx", short_type(x));
add_jit_define("Tw", short_type(w));
add_jit_define("Ty", short_type(y));
add_jit_define("XFORMAT", xformat);
add_jit_define("WFORMAT", wformat);
add_jit_define("YFORMAT", yformat);
void MklConvOp::jit_prepare(JK& jk) {
jk << _CS("[Txd:") << x->dtype();
jk << _CS("][Tyd:") << y->dtype();
jk << _CS("][Twd:") << w->dtype();
jk << _CS("][Tx:") << short_type(x);
jk << _CS("][Tw:") << short_type(w);
jk << _CS("][Ty:") << short_type(y);
jk << _CS("][XFORMAT:") << xformat;
jk << _CS("][WFORMAT:") << wformat;
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
#else // JIT

View File

@ -42,10 +42,10 @@ void MklMatmulOp::infer_shape() {
c->set_shape({n, k});
}
void MklMatmulOp::jit_prepare() {
add_jit_define("T", a->dtype());
add_jit_define("Trans_a", trans_a ? "T" : "N");
add_jit_define("Trans_b", trans_b ? "T" : "N");
void MklMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N') << ']';
}
#else // JIT

View File

@ -17,8 +17,8 @@ MklTestOp::MklTestOp() {
output = create_output(1, ns_float32);
}
void MklTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
void MklTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]");
}
#else // JIT

View File

@ -61,9 +61,9 @@ VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return mpi_all_reduce(dout, ns_add);
}
void MpiAllReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("OP", op.to_cstring());
void MpiAllReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][OP:") << op << ']';
}
#else // JIT

View File

@ -48,8 +48,8 @@ VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return mpi_reduce(dout, ns_add, root);
}
void MpiBroadcastOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
void MpiBroadcastOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() << ']';
}
#else // JIT

View File

@ -61,9 +61,9 @@ VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return mpi_broadcast(dout,root);
}
void MpiReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("OP", op.to_cstring());
void MpiReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][OP:") << op << ']';
}
#else // JIT

View File

@ -16,8 +16,8 @@ MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) {
output = create_output(1, ns_float32);
}
void MpiTestOp::jit_prepare() {
add_jit_define("T", ns_float32);
void MpiTestOp::jit_prepare(JK& jk) {
jk << _CS("[T:float32]");
}
#else // JIT

View File

@ -54,8 +54,8 @@ CustomOp::CustomOp(NanoVector shape, NanoString dtype) {
output = create_output(shape, dtype);
}
void CustomOp::jit_prepare() {
add_jit_define("T", output->dtype());
void CustomOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", output->dtype());
}
#else // JIT

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.0.6'
__version__ = '1.2.0.7'
from . import lock
with lock.lock_scope():
from . import compiler

View File

@ -40,8 +40,8 @@ class TestCuda(unittest.TestCase):
output = create_output(shape, dtype);
}
void NoCudaOp::jit_prepare() {
add_jit_define("T", output->dtype());
void NoCudaOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", output->dtype());
}
#else // JIT
@ -70,8 +70,8 @@ class TestCuda(unittest.TestCase):
output = create_output(shape, dtype);
}
void MyCudaOp::jit_prepare() {
add_jit_define("T", output->dtype());
void MyCudaOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", output->dtype());
}
#else // JIT

View File

@ -35,8 +35,8 @@ CustomOp::CustomOp(NanoVector shape, NanoString dtype) {
output = create_output(shape, dtype);
}
void CustomOp::jit_prepare() {
add_jit_define("T", output->dtype());
void CustomOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", output->dtype());
}
#else // JIT
@ -86,8 +86,8 @@ class TestCustomOp(unittest.TestCase):
output = create_output(shape, dtype);
}
void MyOp::jit_prepare() {
add_jit_define("T", output->dtype());
void MyOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", output->dtype());
}
#else // JIT

View File

@ -385,9 +385,8 @@ class Gray:
img_ = transform(img)
'''
def __call__(self, img:Image.Image):
img = np.array(img.convert('L'))
img = img[np.newaxis, :]
return np.array((img / 255.0), dtype = np.float32)
img = np.float32(img.convert('L')) / np.float32(255.0)
return img[np.newaxis, :]
class RandomCrop:
'''

View File

@ -378,6 +378,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
#ifdef HAS_CUDA
int sync_times = 0;
#endif
auto& jkl = jk;
for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[rid];
Op* op = ops[root];
@ -396,7 +397,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
for (auto* var : op->outputs())
var->alloc(allocator);
LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
op->do_prepare();
op->do_prepare(jkl);
bool is_cuda = op->flags.get(NodeFlags::_cuda);
#ifdef HAS_CUDA
if (!is_cuda) {
@ -422,7 +423,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
}
#endif
last_is_cuda = is_cuda;
op->do_run_after_prepare();
op->do_run_after_prepare(jkl);
LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> queue.size() >> ") output:" << op->outputs();
if (is_fused_op) {
@ -454,7 +455,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
// log memory info
display_memory_info(__FILELINE__, false, true);
// log jit_key and file location
op->do_prepare();
op->do_prepare(jkl);
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
LOGe << "[Error] source file location:" << jit_src_path;
if (is_fused_op) {

View File

@ -32,7 +32,7 @@ loop_options_t& FusedOp::get_loop_options_tuned() {
void FusedOp::update_jit_key() {
jk.clear();
do_jit_prepare();
do_jit_prepare(jk);
}
void FusedOp::update_ops() {
@ -41,7 +41,6 @@ void FusedOp::update_ops() {
loop_options = loop_options_origin = nullptr;
_outputs.clear();
jk.clear();
for (Op* op : ops) {
for (Var* o : op->outputs()) {
if (o->loop_options) {
@ -156,13 +155,13 @@ void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
}
}
void FusedOp::do_jit_prepare() {
void FusedOp::do_jit_prepare(JK& jk) {
jk.clear();
int8 flags = 3;
for (uint i=0; i<ops.size(); i++) {
Op* op = ops[i];
jk << JK::key << "opkey" << i << JK::val;
op->do_jit_prepare();
jk << "[opkey" << i << JK::val;
op->do_jit_prepare(jk);
jk << JK::end;
if (op->flags.get(NodeFlags::_cpu))
flags &= 1; // only cpu
@ -170,39 +169,39 @@ void FusedOp::do_jit_prepare() {
flags &= 2; // only gpu
}
ASSERT(flags) << "FusedOp cannot contain both cpu and cuda ops.";
add_jit_define("JIT", "1");
jk << _CS("[JIT:1]");
if (flags==1) {
// only cpu
add_jit_define("JIT_cpu", "1");
jk << _CS("[JIT_cpu:1]");
this->flags.set(NodeFlags::_cuda, 0);
this->flags.set(NodeFlags::_cpu, 1);
} else {
add_jit_define("JIT_cuda", "1");
jk << _CS("[JIT_cuda:1]");
this->flags.set(NodeFlags::_cpu, 0);
this->flags.set(NodeFlags::_cuda, 1);
}
jk << JK::key << "graph" << JK::val;
jk << _CS("[graph:");
for (auto& t : edges) {
uint i,j,k,l;
std::tie(i,j,k,l) = t;
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
}
jk << JK::end << JK::key << "var_info" << JK::val;
for (auto& vi : vars)
jk << _CS("][var_info:") << JK::val;
for (auto& vi : vars)
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
jk << JK::end;
if (loop_options->size()) {
if (get_loop_option("compile_shapes")) {
jk << JK::key << "shapes" << JK::val;
jk << _CS("[shapes:");
for (auto& vi : vars) {
jk << '[';
for (auto a : vi.var->shape)
jk << a << ',';
jk << "],";
jk << _CS("],");
}
jk << JK::end;
}
jk << JK::key << "choices" << JK::val;
jk << _CS("[choices:");
for (auto& kv : *loop_options)
jk << kv.first << ':' << kv.second << ',';
jk << JK::end;
@ -210,11 +209,11 @@ void FusedOp::do_jit_prepare() {
jk.finilize();
}
void FusedOp::do_prepare() {
do_jit_prepare();
void FusedOp::do_prepare(JK& jk) {
do_jit_prepare(jk);
}
void FusedOp::do_run_after_prepare() {
void FusedOp::do_run_after_prepare(JK& jk) {
const char* jit_key = jk.to_cstring();
auto iter = jit_fused_ops.find(string_view(jit_key, jk.size));
if (iter != jit_fused_ops.end()) {
@ -230,7 +229,7 @@ void FusedOp::do_run_after_prepare() {
context->setup(this);
string prev_jit_key = jit_key;
context->entry = OpCompiler::do_compile(this);
string new_jit_key = get_jit_key();
string new_jit_key = get_jit_key(jk);
jit_fused_ops[new_jit_key] = jit_fused_ops[prev_jit_key] = context;
jit_key_mapper[prev_jit_key] = new_jit_key;
LOGvv << "Get jit op entry:" << (void*)(context->entry);
@ -257,8 +256,8 @@ int FusedOp::has(Node* node) {
}
void FusedOp::do_run(){
do_prepare();
do_run_after_prepare();
do_prepare(jk);
do_run_after_prepare(jk);
}
#else // JIT

View File

@ -50,9 +50,9 @@ struct FusedOp final : Op {
void statistics(uint64_t& in, uint64_t& out, uint64_t& compute) override;
bool shape_infered() override;
void infer_shape() override;
void do_jit_prepare() override;
void do_prepare() override;
void do_run_after_prepare() override;
void do_jit_prepare(JK& jk) override;
void do_prepare(JK& jk) override;
void do_run_after_prepare(JK& jk) override;
void do_run() override;
#ifdef JIT
void jit_run();

View File

@ -108,7 +108,7 @@ vector<pair<string,string>> parse_jit_keys(const string& s) {
val += c;
}
}
ASSERT(presum==0);
ASSERT(presum==0) << s;
return jit_keys;
}

View File

@ -70,18 +70,31 @@ struct JitKey {
};
};
struct __jk_int128 {
int64 a,b;
};
struct __jk_int256 {
int64 a,b,c,d;
};
extern thread_local JitKey jk;
typedef JitKey JK;
inline JK& operator<<(JK& jk, const char* s) {
while (*s) jk.buffer[jk.size++] = *s, s++;
int i;
for (i=0; s[i]; i++)
jk.buffer[jk.size+i] = s[i];
jk.size += i;
return jk;
}
inline JK& operator<<(JK& jk, const string& s) {
for (uint64 i=0; i<s.size(); i++)
jk.buffer[jk.size+i] = s[i];
jk.size += s.size();
auto a = (__jk_int256*)(jk.buffer+jk.size);
auto b = (__jk_int256*)(&s[0]);
auto len = s.size();
for (uint64 i=0; i*32<len; i++)
a[i] = b[i];
jk.size += len;
return jk;
}
@ -166,55 +179,138 @@ static inline float64 itof(uint64 a) { return *(float64*)&a; }
#pragma GCC diagnostic pop
#endif
inline JK& operator<<(JK& jk, float64 f) {
return jk << "itof(0x" << JK::hex(ftoi(f)) << ')';
}
inline JK& operator<<(JK& jk, const NanoString& ns) {
return jk << ns.to_cstring();
auto a = (__jk_int128*)(jk.buffer+jk.size);
auto b = (__jk_int128*)(ns.to_cstring());
auto len = ns.len();
a[0] = b[0];
jk.size += len;
return jk;
}
vector<pair<string,string>> parse_jit_keys(const string& s);
template <class Ta, class Tb>
void add_jit_define(const Ta& key, const Tb& val) {
void add_jit_define(JK& jk, const Ta& key, const Tb& val) {
jk << JK::key << key << JK::val << val << JK::end;
}
template <class Ta, class Tb, class Tc>
void add_jit_define(const Ta& key, const Tb& i, const Tc& val) {
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) {
jk << JK::key << key << i << JK::val << val << JK::end;
}
template <class Ta>
void add_jit_define(const Ta& key, const JK::hex& val) {
void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) {
jk << JK::key << key << JK::hex_val << val << JK::end;
}
template <class Ta, class Tb>
void add_jit_define(const Ta& key, const Tb& i, const JK::hex& val) {
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) {
jk << JK::key << key << i << JK::hex_val << val << JK::end;
}
template <class Ta>
void add_jit_define(const Ta& key, const JK::hex1& val) {
void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) {
jk << JK::key << key << JK::hex_val << val << JK::end;
}
template <class Ta, class Tb>
void add_jit_define(const Ta& key, const Tb& i, const JK::hex1& val) {
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) {
jk << JK::key << key << i << JK::hex_val << val << JK::end;
}
template <class Ta>
void add_jit_define(const Ta& key, const JK::hex2& val) {
void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) {
jk << JK::key << key << JK::hex_val << val << JK::end;
}
template <class Ta, class Tb>
void add_jit_define(const Ta& key, const Tb& i, const JK::hex2& val) {
void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) {
jk << JK::key << key << i << JK::hex_val << val << JK::end;
}
// begin of const string
#define MAX_CONST_CHAR 32
#define _CS_MIN(a,b) (a)<(b)?(a):(b)
#define _CS_T(s)\
getChr(s,0),\
getChr(s,1),\
getChr(s,2),\
getChr(s,3),\
getChr(s,4),\
getChr(s,5),\
getChr(s,6),\
getChr(s,7),\
getChr(s,8),\
getChr(s,9),\
getChr(s,10),\
getChr(s,11),\
getChr(s,12),\
getChr(s,13),\
getChr(s,14),\
getChr(s,15),\
getChr(s,16),\
getChr(s,17),\
getChr(s,18),\
getChr(s,19),\
getChr(s,20),\
getChr(s,21),\
getChr(s,22),\
getChr(s,23),\
getChr(s,24),\
getChr(s,25),\
getChr(s,26),\
getChr(s,27),\
getChr(s,28),\
getChr(s,29),\
getChr(s,30),\
getChr(s,31),\
getChr(s,32),\
getChr(s,33),\
getChr(s,34),\
getChr(s,35)
#define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))<sizeof(name)/sizeof(*name)?name[ii]:0)
#define _CS(str) _CS_G<_CS_T(str)>()
template <char c1, char c2, char c3, char c4, char... Chars_> struct _CS_G {
};
template<> struct _CS_G<0,0,0,0> {};
template <char c1, char c2, char c3, char c4, char... Chars_>
inline JK& operator<<(JK& jk, const _CS_G<c1,c2,c3,c4,Chars_...>& _) {
((int*)(jk.buffer+jk.size))[0] = c4*(1<<24)+c3*(1<<16)+c2*(1<<8)+c1;
if (c4) {
jk.size += 4;
jk << _CS_G<Chars_...>();
} else
if (c3) {
jk.size += 3;
} else
if (c2) {
jk.size += 2;
} else
if (c1) {
jk.size += 1;
}
return jk;
}
template <>
inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) {
return jk;
}
inline JK& operator<<(JK& jk, float64 f) {
return jk << _CS("itof(0x") << JK::hex(ftoi(f)) << ')';
}
} // jittor

View File

@ -115,7 +115,8 @@ static unordered_set<string> binary_ops = {
FOR_ALL_NS(DEFINE_NS);
unordered_map<string, NanoString> NanoString::__string_to_ns;
vector<const char*> NanoString::__ns_to_string;
char NanoString::__ns_to_string[ns_max_size*ns_max_len];
int NanoString::__ns_len[ns_max_size];
static void init_ns() {
NanoString::ns_t i=0;
@ -140,11 +141,17 @@ static void init_ns() {
ns.set(NanoString::_bool, is_bool.count(name));
}
NanoString::__string_to_ns[name] = ns;
NanoString::__ns_to_string.push_back(name);
auto name2 = ns.to_cstring();
int len=0;
for (;;len++) {
name2[len] = name[len];
if (!name[len]) break;
}
NanoString::__ns_len[i-1] = len;
};
#define INIT_NS(T) func(#T, ns_##T);
FOR_ALL_NS(INIT_NS);
ASSERT(NanoString::__ns_to_string.size()<=(1<<NanoString::_index_nbits));
ASSERT(i<=(1<<NanoString::_index_nbits));
NanoString::__string_to_ns["sum"] = ns_add;
NanoString::__string_to_ns["min"] = ns_minimum;
NanoString::__string_to_ns["max"] = ns_maximum;

View File

@ -8,6 +8,8 @@
namespace jittor {
constexpr int ns_max_size = 256;
constexpr int ns_max_len = 16;
#define FOR_ALL_NS(m) \
\
@ -107,7 +109,8 @@ struct NanoString {
ns_t data=0;
static unordered_map<string, NanoString> __string_to_ns;
static vector<const char*> __ns_to_string;
static char __ns_to_string[];
static int __ns_len[];
inline void set(Flags f, ns_t a=1, ns_t nbits=1) {
ns_t mask = (((1u<<nbits)-1)<<f);
@ -118,6 +121,7 @@ struct NanoString {
return (data>>f) & ((1u<<nbits)-1);
}
inline ns_t index() const { return get(_index, _index_nbits); }
inline int len() const { return __ns_len[index()]; }
inline ns_t type() const { return get(_type, _type_nbits); }
inline ns_t is_bool() const { return get(_bool); }
inline ns_t is_int() const { return get(_int); }
@ -140,7 +144,9 @@ struct NanoString {
inline NanoString(const string& s) : NanoString(s.c_str()) {}
// @pyjt(__repr__)
inline const char* to_cstring() const
{ return __ns_to_string[index()]; }
{ return __ns_to_string+index()*ns_max_len; }
inline char* to_cstring()
{ return __ns_to_string+index()*ns_max_len; }
};
// force_type = 1 for int, 2 for float

View File

@ -26,6 +26,7 @@ struct string_view_map {
iter_t end() { return umap.end(); }
const T& at(string_view sv) { return umap.at(sv); }
size_t size() { return umap.size(); }
T& operator[](string_view sv) {
auto iter = find(sv);

View File

@ -79,7 +79,7 @@ void Op::compile_optimize(string& src) {}
void Op::infer_shape() {}
void Op::run() {}
void Op::jit_prepare() {}
void Op::jit_prepare(JK& jk) {}
void Op::graph_optimize() {}
string Op::name_ex() const {
@ -91,20 +91,20 @@ string Op::name_ex() const {
return a;
}
string Op::get_jit_key() {
string Op::get_jit_key(JK& jk) {
jk.clear();
do_jit_prepare();
do_jit_prepare(jk);
return jk.to_string();
}
vector<pair<string,string>> Op::get_jit_define() {
return parse_jit_keys(get_jit_key());
return parse_jit_keys(get_jit_key(jk));
}
void Op::do_jit_prepare() {
void Op::do_jit_prepare(JK& jk) {
memcheck_all_exist();
jk << name();
jit_prepare();
jit_prepare(jk);
if (jk.empty()) {
// not a jit op
bool has_cuda = flags.get(NodeFlags::_cuda);
@ -144,9 +144,9 @@ void Op::do_jit_prepare() {
use_int64_t = true;
out_id ++;
}
add_jit_define("JIT", "1");
jk << _CS("[JIT:1]");
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
add_jit_define("JIT_cuda", "1");
jk << _CS("[JIT_cuda:1]");
flags.set(NodeFlags::_cpu, 0);
// TODO: 64bit index in CUDA
use_int64_t = false;
@ -159,21 +159,24 @@ void Op::do_jit_prepare() {
}
ASSERT(flags.get(NodeFlags::_cpu))
<< "Op" << name() << "doesn't have cpu version";
add_jit_define("JIT_cpu", "1");
jk << _CS("[JIT_cpu:1]");
flags.set(NodeFlags::_cuda, 0);
}
if (try_use_32bit_index) use_int64_t = false;
add_jit_define("index_t", use_int64_t ? "int64" : "int32");
if (use_int64_t)
jk << _CS("[index_t:int64]");
else
jk << _CS("[index_t:int32]");
}
jk.finilize();
}
void Op::do_prepare(){
void Op::do_prepare(JK& jk){
jk.clear();
do_jit_prepare();
do_jit_prepare(jk);
}
void Op::do_run_after_prepare() {
void Op::do_run_after_prepare(JK& jk) {
if (!jk.empty())
jit_run();
else
@ -181,8 +184,8 @@ void Op::do_run_after_prepare() {
}
void Op::do_run() {
do_prepare();
do_run_after_prepare();
do_prepare(jk);
do_run_after_prepare(jk);
}
string Op::get_filename_from_jit_key(const string& jit_key, const string& suffix) {
@ -245,7 +248,7 @@ void Op::jit_run() {
// compile JIT op
string prev_jit_key = jit_key;
auto op_entry = OpCompiler::do_compile(this);
string new_jit_key = get_jit_key();
string new_jit_key = get_jit_key(jk);
jit_ops[new_jit_key] = jit_ops[prev_jit_key] = op_entry;
jit_key_mapper[prev_jit_key] = new_jit_key;
LOGvv << "Get jit op entry:" << (void*)op_entry;

View File

@ -40,12 +40,12 @@ struct Op : Node {
virtual void grads(Var** douts, VarPtr* dins);
virtual void infer_shape();
virtual void run();
virtual void jit_prepare();
virtual void do_jit_prepare();
virtual void jit_prepare(JK& jk);
virtual void do_jit_prepare(JK& jk);
virtual const char* name() const = 0;
virtual void statistics(uint64_t& in, uint64_t& out, uint64_t& compute);
virtual void do_prepare();
virtual void do_run_after_prepare();
virtual void do_prepare(JK& jk);
virtual void do_run_after_prepare(JK& jk);
virtual void do_run();
virtual VarPtr duplicate();
virtual void compile_optimize(string& src);
@ -53,7 +53,7 @@ struct Op : Node {
void jit_run();
string name_ex() const;
string get_jit_key();
string get_jit_key(JK& jk);
vector<pair<string,string>> get_jit_define();
};
@ -66,7 +66,7 @@ extern string_view_map<string> jit_key_mapper;
#ifdef JIT
#define DECLARE_jit_run void jit_run();
#else
#define DECLARE_jit_run void jit_prepare() override;
#define DECLARE_jit_run void jit_prepare(JK& jk) override;
#endif
} // jittor

View File

@ -1000,7 +1000,7 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
src = &src_after_passes;
}
op->compile_optimize(*src);
auto ret = oc.compile(op->get_jit_key(), *src);
auto ret = oc.compile(op->get_jit_key(jk), *src);
return ret;
}

View File

@ -155,14 +155,15 @@ void ArgReduceOp::infer_shape() {
y_key->set_shape(shape);
}
void ArgReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
add_jit_define("YDIM", JK::hex1(y->shape.size()));
add_jit_define("KEEPDIMS", keepdims ? 1 : 0);
add_jit_define("DIM", JK::hex1(dim));
add_jit_define("CMP", op==ns_minimum ? "<" : ">");
void ArgReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Ty:") << y->dtype();
jk << _CS("][XDIM=") << JK::hex1(x->shape.size());
jk << _CS("][YDIM=") << JK::hex1(y->shape.size());
jk << _CS("][KEEPDIMS:") << (keepdims ? '1' : '0');
jk << _CS("][DIM=") << JK::hex1(dim);
jk << _CS("][CMP:") << (op==ns_minimum ? "<" : ">");
jk << ']';
}
#else // JIT

View File

@ -127,12 +127,13 @@ void ArgsortOp::infer_shape() {
y_key->set_shape(x->shape);
}
void ArgsortOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
add_jit_define("DIM", JK::hex1(dim));
add_jit_define("CMP", descending ? ">" : "<");
void ArgsortOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Ty:") << y->dtype();
jk << _CS("][XDIM=") << JK::hex1(x->shape.size());
jk << _CS("][DIM=") << JK::hex1(dim);
jk << _CS("][CMP:") << (descending ? '>' : '<');
jk << ']';
}
#else // JIT

View File

@ -74,9 +74,9 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
std::memcpy(allocation.ptr, args.ptr, output->size);
}
void ArrayOp::jit_prepare() {
void ArrayOp::jit_prepare(JK& jk) {
if (output->flags.get(NodeFlags::_force_fuse))
add_jit_define("T", output->dtype());
jk << _CS("[T:") << output->dtype() << ']';
}
void ArrayOp::run() {

View File

@ -28,7 +28,7 @@ struct ArrayOp : Op {
const char* name() const override { return "array"; }
void run() override;
void jit_prepare() override;
void jit_prepare(JK& jk) override;
};
} // jittor

View File

@ -180,11 +180,11 @@ void BinaryOp::infer_shape() {
z->set_shape(x->shape);
}
void BinaryOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("Tz", z->dtype());
add_jit_define("OP", ns.to_cstring());
void BinaryOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype()
<< _CS("][Ty:") << y->dtype()
<< _CS("][Tz:") << z->dtype()
<< _CS("][OP:") << ns << ']';
}
#else // JIT

View File

@ -164,10 +164,10 @@ void BroadcastToOp::infer_shape() {
LOGvvv << "Broadcast x(" >> x >> ") shape" << yshapes << "-> z(" >> z >> ")";
}
void BroadcastToOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("DIM", JK::hex1(z->shape.size()));
add_jit_define("BCAST", JK::hex(bcast_mask));
void BroadcastToOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype()
<< _CS("][DIM=") << JK::hex1(z->shape.size())
<< _CS("][BCAST=") << JK::hex(bcast_mask) << ']';
}
#else // JIT

View File

@ -23,11 +23,11 @@ void CandidateOp::infer_shape() {
y->set_shape({-std::abs(x->shape[0])});
}
void CandidateOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("FUNC", fail_cond);
add_jit_define("XDIM", JK::hex1(x->shape.size()));
void CandidateOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Ty:") << y->dtype();
jk << _CS("][FUNC:") << fail_cond;
jk << _CS("][XDIM=") << JK::hex1(x->shape.size()) << ']';
}
#else // JIT

View File

@ -111,28 +111,28 @@ VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
);
}
void CodeOp::jit_prepare() {
void CodeOp::jit_prepare(JK& jk) {
// forward: in0 in1 in2 -> out0 out1
// backward: in0 in1 in2 in3(pout0) in4(pout1)
add_jit_define("IN_SIZE", JK::hex1(_inputs.size()));
jk << _CS("[IN_SIZE=") << JK::hex(_inputs.size());
for (uint i=0; i<_inputs.size(); i++) {
jk << JK::key << "in" << JK::hex1(i) << "_dim" <<
JK::val << JK::hex1(_inputs[i]->shape.size()) << JK::end;
jk << JK::key << "in" << JK::hex1(i) << "_type" <<
JK::val << _inputs[i]->dtype() << JK::end;
jk << _CS("][in") << JK::hex(i) << _CS("_dim=")
<< JK::hex1(_inputs[i]->shape.size());
jk << _CS("][in") << JK::hex(i) << _CS("_type:")
<< _inputs[i]->dtype();
}
add_jit_define("OUT_SIZE", JK::hex1(_outputs.size()));
jk << _CS("][OUT_SIZE=") << JK::hex(_outputs.size());
for (uint i=0; i<_outputs.size(); i++) {
jk << JK::key << "out" << JK::hex1(i) << "_dim" <<
JK::val << JK::hex1(_outputs[i]->shape.size()) << JK::end;
jk << JK::key << "out" << JK::hex1(i) << "_type" <<
JK::val << _outputs[i]->dtype() << JK::end;
jk << _CS("][out") << JK::hex(i) << _CS("_dim=")
<< JK::hex1(_outputs[i]->shape.size());
jk << _CS("][out") << JK::hex(i) << _CS("_type:")
<< _outputs[i]->dtype();
}
if (flags.get(NodeFlags::_cuda)) {
jk << JK::key << "HEADER" << JK::val << cuda_header;
jk << _CS("][HEADER:") << cuda_header;
ASSERT(cuda_src.size());
jk << "\nnamespace jittor {\n";
jk << _CS("\nnamespace jittor {\n");
int i=0;
// move cuda kernel function into header
for (; i<cuda_src.size(); i++) {
@ -153,13 +153,12 @@ void CodeOp::jit_prepare() {
}
} else break;
}
jk << "}" << JK::end << JK::key << "CODE" << JK::val;
jk << _CS("}][CODE:");
for (; i<cuda_src.size(); i++) jk << cuda_src[i];
jk << JK::end;
jk << ']';
} else {
add_jit_define("HEADER", cpu_header);
jk << JK::key << "CODE" << JK::val;
jk << cpu_src << JK::end;
jk << _CS("][HEADER:") << cpu_header;
jk << _CS("][CODE:") << cpu_src << ']';
ASSERT(cpu_src.size());
}
}

View File

@ -55,8 +55,8 @@ void ConcatOp::infer_shape() {
}
y->set_shape(shape);
}
void ConcatOp::jit_prepare() {
add_jit_define("T", "int");
void ConcatOp::jit_prepare(JK& jk) {
jk << _CS("[T:int]");
}
VarPtr ConcatOp::grad(Var* out, Var* dout, Var* v, int v_index) {

View File

@ -134,7 +134,7 @@ void GetitemOp::infer_slices(
i_to_o[i] = out_shape.size();
if (in_shape_i > 0) {
slice.fill(in_shape_i);
if (abs(slice.step) <= 1)
if (std::abs(slice.step) <= 1)
out_shape_j = (slice.stop - slice.start) * slice.step;
else if (slice.step>0)
out_shape_j = (slice.stop - slice.start - 1) / slice.step + 1;
@ -376,25 +376,25 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_setitem(zeros, VarSlices(vs), dout, ns_void);
}
void GetitemOp::jit_prepare() {
void GetitemOp::jit_prepare(JK& jk) {
auto in = inputs().front();
int idim = i_to_vs.size();
add_jit_define("Ti", in->dtype());
add_jit_define("IDIM", JK::hex1(i_to_vs.size()));
add_jit_define("ODIM", JK::hex1(o_shape.size()));
jk << _CS("[Ti:") << in->dtype();
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
if (first_oid_of_var>=0) {
add_jit_define("FOV", JK::hex1(first_oid_of_var));
add_jit_define("VD", JK::hex1(var_dim));
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
jk << _CS("][VD=") << JK::hex1(var_dim);
}
for (int i=0; i<idim; i++) {
auto iv = i_to_vs[i];
auto io = i_to_o[i];
add_jit_define("IV", JK::hex1(i), JK::shex1(iv));
add_jit_define("IO", JK::hex1(i), JK::shex1(io));
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
auto& v = vs.slices[iv];
if (iv>=0 && io==-1) {
if (v.is_int()) {
add_jit_define("VS", JK::hex1(i), "-1");
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
} else {
ASSERT(v.is_var());
auto var = v.var;
@ -406,16 +406,17 @@ void GetitemOp::jit_prepare() {
if (vshape[j] == o_shape[k])
vsmask |= 1<<(j+var_dim-vdim);
}
add_jit_define("VS", JK::hex1(i), JK::hex(vsmask));
add_jit_define("VST", JK::hex1(i), var->dtype());
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
}
} else
if (iv>=0 && io>=0) {
ASSERT(v.is_slice());
jk << _CS("][VS") << JK::hex1(i) << ':';
if (std::abs(v.slice.step) <= 1)
add_jit_define("VS", JK::hex1(i), JK::shex1(v.slice.step));
jk << JK::shex1(v.slice.step);
else
add_jit_define("VS", JK::hex1(i), "0");
jk << '0';
}
}
#ifdef HAS_CUDA
@ -425,10 +426,11 @@ void GetitemOp::jit_prepare() {
int tdims[6];
cuda_loop_schedule(o_shape, masks, tdims);
for (int i=0; i<no; i++) {
add_jit_define("LO", JK::hex1(i), JK::hex(masks[i]));
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
}
}
#endif
jk << ']';
}
#else // JIT

View File

@ -50,10 +50,10 @@ void IndexOp::infer_shape() {
o->set_shape(a->shape);
}
void IndexOp::jit_prepare() {
add_jit_define("T", x[0]->dtype());
add_jit_define("DIM", JK::hex1(dim));
add_jit_define("XDIM", JK::hex1(x[0]->shape.size()));
void IndexOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", x[0]->dtype());
add_jit_define(jk, "DIM", JK::hex1(dim));
add_jit_define(jk, "XDIM", JK::hex1(x[0]->shape.size()));
}
#else // JIT

View File

@ -37,9 +37,9 @@ RandomOp::RandomOp(NanoVector shape, NanoString dtype, NanoString type) {
ASSERT(type == ns_normal || type == ns_uniform);
}
void RandomOp::jit_prepare() {
add_jit_define("T", output->dtype());
add_jit_define("R", type);
void RandomOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << output->dtype();
jk << _CS("][R:") << type << ']';
}
#else // JIT

View File

@ -134,13 +134,13 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return nullptr;
}
void ReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("Tz", y->dtype());
add_jit_define("OP", ns.to_cstring());
add_jit_define("DIM", JK::hex1(x->shape.size()));
add_jit_define("REDUCE", JK::hex(reduce_mask));
void ReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype()
<< _CS("][Ty:") << y->dtype()
<< _CS("][Tz:") << y->dtype()
<< _CS("][OP:") << ns
<< _CS("][DIM=") << JK::hex1(x->shape.size())
<< _CS("][REDUCE=") << JK::hex(reduce_mask) << ']';
}
#else // JIT

View File

@ -86,21 +86,22 @@ void ReindexOp::infer_shape() {
CHECK(y->shape.size()) << "Number of shape should greater than 0.";
}
void ReindexOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("XDIM", JK::hex1(x->shape.size()));
add_jit_define("YDIM", JK::hex1(y->shape.size()));
add_jit_define("OVERFLOW", overflow_value);
void ReindexOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype()
<< _CS("][XDIM=") << JK::hex1(x->shape.size())
<< _CS("][YDIM=") << JK::hex1(y->shape.size())
<< _CS("][OVERFLOW:") << overflow_value;
for (uint i=0; i<indexes.size(); i++)
add_jit_define("INDEX", JK::hex1(i), indexes[i]);
add_jit_define("OSIZE", JK::hex1(overflow_conditions.size()));
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i];
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size());
for (uint i=0; i<overflow_conditions.size(); i++)
add_jit_define("OFD", JK::hex1(i), overflow_conditions[i]);
add_jit_define("ESIZE", JK::hex1(extras.size()));
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i];
jk << _CS("][ESIZE=") << JK::hex1(extras.size());
for (uint i=0; i<extras.size(); i++) {
add_jit_define("EDIM", JK::hex1(i), JK::hex1(extras[i]->shape.size()));
add_jit_define("Te", JK::hex1(i), extras[i]->dtype());
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype();
}
jk << ']';
}
#else // JIT

View File

@ -63,21 +63,22 @@ void ReindexReduceOp::infer_shape() {
CHECKop(y->size,>=,0u);
}
void ReindexReduceOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("OP", ns.to_cstring());
add_jit_define("YDIM", JK::hex1(y->shape.size()));
add_jit_define("XDIM", JK::hex1(x->shape.size()));
void ReindexReduceOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype()
<< _CS("][OP:") << ns
<< _CS("][YDIM=") << JK::hex1(y->shape.size())
<< _CS("][XDIM=") << JK::hex1(x->shape.size());
for (uint i=0; i<indexes.size(); i++)
add_jit_define("INDEX", JK::hex1(i), indexes[i]);
add_jit_define("OSIZE", JK::hex1(overflow_conditions.size()));
jk << _CS("][INDEX") << JK::hex1(i) << ':' << indexes[i];
jk << _CS("][OSIZE=") << JK::hex1(overflow_conditions.size());
for (uint i=0; i<overflow_conditions.size(); i++)
add_jit_define("OFD", JK::hex1(i), overflow_conditions[i]);
add_jit_define("ESIZE", JK::hex1(extras.size()));
jk << _CS("][OFD") << JK::hex1(i) << ':' << overflow_conditions[i];
jk << _CS("][ESIZE=") << JK::hex1(extras.size());
for (uint i=0; i<extras.size(); i++) {
add_jit_define("EDIM", JK::hex1(i), JK::hex1(extras[i]->shape.size()));
add_jit_define("Te", JK::hex1(i), extras[i]->dtype());
jk << _CS("][EDIM") << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size());
jk << _CS("][Te") << JK::hex1(i) << ':' << extras[i]->dtype();
}
jk << ']';
}
#else // JIT

View File

@ -173,7 +173,7 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return nullptr;
}
void SetitemOp::jit_prepare() {
void SetitemOp::jit_prepare(JK& jk) {
for (int i=0; i<o_shape.size(); i++)
if (o_shape[i]<0) {
// because output shape is inferd, check in
@ -183,28 +183,28 @@ void SetitemOp::jit_prepare() {
break;
}
auto data = input(1);
add_jit_define("OP", op);
add_jit_define("Td", data->dtype());
add_jit_define("BMASK", JK::hex(bmask));
jk << _CS("[OP:") << op
<< _CS("][Td:") << data->dtype()
<< _CS("][BMASK=") << JK::hex(bmask);
// TODO: merge code
auto in = inputs().front();
int idim = i_to_vs.size();
add_jit_define("Ti", in->dtype());
add_jit_define("IDIM", JK::hex1(i_to_vs.size()));
add_jit_define("ODIM", JK::hex1(o_shape.size()));
jk << _CS("][Ti:") << in->dtype();
jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size());
jk << _CS("][ODIM=") << JK::hex1(o_shape.size());
if (first_oid_of_var>=0) {
add_jit_define("FOV", JK::hex1(first_oid_of_var));
add_jit_define("VD", JK::hex1(var_dim));
jk << _CS("][FOV=") << JK::hex1(first_oid_of_var);
jk << _CS("][VD=") << JK::hex1(var_dim);
}
for (int i=0; i<idim; i++) {
auto iv = i_to_vs[i];
auto io = i_to_o[i];
add_jit_define("IV", JK::hex1(i), JK::shex1(iv));
add_jit_define("IO", JK::hex1(i), JK::shex1(io));
jk << _CS("][IV") << JK::hex1(i) << ':' << JK::shex1(iv);
jk << _CS("][IO") << JK::hex1(i) << ':' << JK::shex1(io);
auto& v = vs.slices[iv];
if (iv>=0 && io==-1) {
if (v.is_int()) {
add_jit_define("VS", JK::hex1(i), "-1");
jk << _CS("][VS") << JK::hex1(i) << _CS(":-1");
} else {
ASSERT(v.is_var());
auto var = v.var;
@ -216,16 +216,17 @@ void SetitemOp::jit_prepare() {
if (vshape[j] == o_shape[k])
vsmask |= 1<<(j+var_dim-vdim);
}
add_jit_define("VS", JK::hex1(i), JK::hex(vsmask));
add_jit_define("VST", JK::hex1(i), var->dtype());
jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask);
jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype();
}
} else
if (iv>=0 && io>=0) {
ASSERT(v.is_slice());
jk << _CS("][VS") << JK::hex1(i) << ':';
if (std::abs(v.slice.step) <= 1)
add_jit_define("VS", JK::hex1(i), JK::shex1(v.slice.step));
jk << JK::shex1(v.slice.step);
else
add_jit_define("VS", JK::hex1(i), "0");
jk << '0';
}
}
#ifdef HAS_CUDA
@ -235,10 +236,11 @@ void SetitemOp::jit_prepare() {
int tdims[6];
cuda_loop_schedule(o_shape, masks, tdims);
for (int i=0; i<no; i++) {
add_jit_define("LO", JK::hex1(i), JK::hex(masks[i]));
jk << _CS("][LO") << JK::hex1(i) << '=' << JK::hex(masks[i]);
}
}
#endif
jk << ']';
}
void SetitemOp::compile_optimize(string& src) {

View File

@ -45,11 +45,11 @@ void TernaryOp::infer_shape() {
z->set_shape(x->shape);
}
void TernaryOp::jit_prepare() {
add_jit_define("Tc", cond->dtype());
add_jit_define("Tx", x->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("Tz", z->dtype());
void TernaryOp::jit_prepare(JK& jk) {
jk << _CS("[Tc:") << cond->dtype();
jk << _CS("][Tx:") << x->dtype();
jk << _CS("][Ty:") << y->dtype();
jk << _CS("][Tz:") << z->dtype() << ']';
}
#else // JIT

View File

@ -105,11 +105,12 @@ VarPtr TransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_transpose(dout, reverse);
}
void TransposeOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("DIM", JK::hex1(axes.size()));
void TransposeOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][DIM=") << JK::hex1(axes.size());
for (uint i=0; i<axes.size(); i++)
add_jit_define("AXES", JK::hex1(axes[i]), S(i));
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i);
jk << ']';
}
#else // JIT

View File

@ -198,10 +198,10 @@ void UnaryOp::infer_shape() {
y->set_shape(x->shape);
}
void UnaryOp::jit_prepare() {
add_jit_define("Tx", x->dtype());
add_jit_define("Ty", y->dtype());
add_jit_define("OP", ns.to_cstring());
void UnaryOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype()
<< _CS("][Ty:") << y->dtype()
<< _CS("][OP:") << ns << ']';
}
#else // JIT

View File

@ -31,10 +31,10 @@ void WhereOp::infer_shape() {
outs[i]->set_shape({num});
}
void WhereOp::jit_prepare() {
add_jit_define("Ti", cond->dtype());
add_jit_define("To", outs[0]->dtype());
add_jit_define("NDIM", JK::hex1(cond->shape.size()));
void WhereOp::jit_prepare(JK& jk) {
add_jit_define(jk, "Ti", cond->dtype());
add_jit_define(jk, "To", outs[0]->dtype());
add_jit_define(jk, "NDIM", JK::hex1(cond->shape.size()));
}
#else // JIT

View File

@ -1040,11 +1040,11 @@ void KernelIR::split_loop(int i, int j) {
inner[2]->attrs["code"] = lvalue+"+="+rvalue2+";";
push_back("for ("+dtype+" id"+sj+"=0; id"+sj+"<range"+sj+"; id"+sj+"++) {}");
auto& sloop = children.back();
int range, stride;
int range=0, stride=0;
if (get_number("range"+si, range) && get_number("stride"+si, stride) && (range%stride==0))
push_back(dtype+" range"+sj+" = "+S(stride)+";", &inner);
else {
ASSERT(range != -1 && stride != -1) << range << stride;
ASSERT(range != -1 && stride != -1) << range << stride << si;
push_back(dtype+" range"+sj+" = ::min(range"+si+"-id"+si+", stride"+si+");", &inner);
}
sloop->attrs["loop_id"] = sj;
@ -1081,7 +1081,7 @@ void KernelIR::resplit() {
// define
push_front(dtype+" "+lvalue+" = 0;", &before);
int num;
int num=0;
if (get_number(rvalue2, num)) {
// range = number;
inner[3]->attrs["rvalue"] = S(num);

View File

@ -41,7 +41,7 @@ void UnrollPass::run() {
if (choice==1)
loop->push_back("#pragma unroll", &loop->before);
else {
int num;
int num=0;
auto& split_id = loop->get_attr("split_id");
auto& loop_id = loop->get_attr("loop_id");
auto& rvalue = loop->get_attr("rvalue");

View File

@ -31,7 +31,7 @@ void VectorizePass::run() {
if (choice == 1) {
loop->push_back("#pragma vector", &loop->before);
} else if (choice > 1) {
int num;
int num=0;
if (!loop->get_number(loop->get_attr("rvalue"), num)) {
if (loop->has_attr("split_id")) {
string si = loop->attrs["split_id"];

View File

@ -233,7 +233,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
continue;
Op* ops[3] = {op, bop->x->input(), bop->y->input()};
int ok = 0;
LOGvvvv << "conv like op" << fop << fop->get_jit_key();
LOGvvvv << "conv like op" << fop << fop->get_jit_key(jk);
for (int y_id=0; y_id<3; y_id++)
for (int x_id=0; x_id<3; x_id++)
for (int w_id=0; w_id<3; w_id++) {

View File

@ -68,7 +68,7 @@ int VarRelayManager::add_relay_group(const vector<pair<Var*, Var*>>& group) {
if (node->is_var())
continue;
Op* op = node->op();
op->do_jit_prepare();
op->do_jit_prepare(jk);
list<Node*> new_inputs;
int removed = 0;
for (Var* v : op->inputs())

View File

@ -89,14 +89,27 @@ struct SimpleThreads {
}
};
static int last_compiled_op_num = 0;
static int not_compile_window = 0;
void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int64 tt) {
// jit_search_kernel require compile at runtime
if (jit_search_kernel || !use_parallel_op_compiler)
if (jit_search_kernel || !use_parallel_op_compiler || not_compile_window > 1000)
return;
// try not use parallel compile if no op needs compile
if (last_compiled_op_num != jit_key_mapper.size()) {
not_compile_window = 0;
last_compiled_op_num = jit_key_mapper.size();
} else {
not_compile_window += queue.size();
}
vector<int> op_needs_compile;
string_view_map<int> map;
vector<unique_ptr<FusedOp>> fop_needs_compile;
auto& jkl = jk;
for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[rid];
@ -111,7 +124,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
}
LOGvvv << "Check op needs compile:" << op;
op->do_prepare();
op->do_prepare(jkl);
if (jk.empty()) continue;
const char* jit_key = jk.to_cstring();
@ -133,8 +146,8 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
LOGvv << "Op needs compile:" << op;
} catch (const std::exception& e) {
// log jit_key and file location
op->do_prepare();
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
op->do_prepare(jkl);
string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
LOGe << "[Error] source file location:" << jit_src_path;
if (is_fused_op) {
LOGf << "Compile fused operator(" >> rid >> '/' >> queue.size() >> ")"
@ -173,6 +186,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
auto func = [&](int tid) {
auto& entrys = op_entrys.at(tid);
entrys.clear();
auto& jkl = jk;
while (!has_error && !segfault_happen) {
int i = ai++;
if (i >= n) break;
@ -184,9 +198,9 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
int root = queue[rid];
op = ops[root];
LOGvv << "Compile Op:" << op;
op->do_prepare();
op->do_prepare(jkl);
auto op_entry = OpCompiler::do_compile(op);
entrys.emplace_back(std::make_tuple(i, 0, (void*)op_entry, op->get_jit_key()));
entrys.emplace_back(std::make_tuple(i, 0, (void*)op_entry, op->get_jit_key(jkl)));
} else {
FusedOp& fused_op = *fop_needs_compile[-rid-1];
op = &fused_op;
@ -194,15 +208,15 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
LOGV(11) << "FusedOps:" << fused_op.ops;
fused_op.context = new FusedOpContext();
fused_op.context->setup(&fused_op);
fused_op.do_prepare();
fused_op.do_prepare(jkl);
auto op_entry = OpCompiler::do_compile(op);
fused_op.context->entry = op_entry;
entrys.emplace_back(std::make_tuple(i, 1, (void*)fused_op.context, op->get_jit_key()));
entrys.emplace_back(std::make_tuple(i, 1, (void*)fused_op.context, op->get_jit_key(jkl)));
// compile relay operators
for (auto& vrg : fused_op.context->vrm.relay_groups) {
for (auto& orc : vrg.oprcs) {
orc.op->do_prepare();
orc.op->do_prepare(jkl);
bool needs_compile;
{
std::lock_guard<std::mutex> lock(entry_lock);
@ -224,7 +238,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
}
} catch (const std::exception& e) {
// log jit_key and file location
op->do_prepare();
op->do_prepare(jkl);
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
LOGe << "[Error] source file location:" << jit_src_path;

View File

@ -0,0 +1,89 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include <chrono>
#include <iomanip>
#include "common.h"
namespace jittor {
// Count leading zero bits of a 64-bit value; returns 64 for v == 0.
// SimpleProfiler::add() uses this to bucket durations by magnitude
// (5 bits per bucket), so the v == 0 case must be well defined.
static inline int _lzcnt(int64 v) {
#ifdef __clang__
// NOTE(review): the original probe was __has_feature(__builtin_ia32_lzcnt_u64).
// __has_feature() reports language features, not builtins, and returns 0 for
// unknown names, so the intrinsic path was unreachable. __has_builtin() is the
// documented way to test for a builtin and enables the single-instruction path.
#if __has_builtin(__builtin_ia32_lzcnt_u64)
    // lzcnt is defined for 0 (returns 64), no guard needed.
    return __builtin_ia32_lzcnt_u64(v);
#else
    // __builtin_clzll(0) is undefined behavior; guard the zero case.
    return v ? __builtin_clzll(v) : 64;
#endif
#else
    // gcc: __builtin_clzll(0) is undefined behavior; guard the zero case
    // (the unguarded call here was reachable via add(0)).
    return v ? __builtin_clzll(v) : 64;
#endif
}
// Accumulates wall-time samples for one named code region and prints a
// histogram report from its destructor. Intended usage (see the example
// comment below) is a function-local static, so the report is emitted
// once, at program exit.
struct SimpleProfiler {
    string name;      // label shown in the report header
    int64 cnt;        // number of samples recorded via add()
    int64 total_ns;   // sum of all samples, in nanoseconds
    // 7 magnitude buckets, 5 bits (x32) each: <32ns <1us <32us <1ms <32ms <1s >1s
    int64 pcnt[7] = {0};   // sample count per bucket
    int64 pns[7] = {0};    // nanoseconds accumulated per bucket
    int64 last[7] = {0};   // value of cnt when this bucket was last hit
    inline SimpleProfiler(string&& name): name(move(name)), cnt(0), total_ns(0) {}
    // Dump the report to stderr: total count, total time with auto-scaled
    // unit, then per-bucket count / count-fraction / time-fraction / last-hit
    // rows aligned in setw(9) columns.
    // NOTE(review): if no sample was ever added (cnt == 0), the fraction rows
    // divide by zero and print nan/inf — harmless for a debug dump, but worth
    // confirming that every profiler is actually exercised.
    inline ~SimpleProfiler() {
        std::cerr << "=============================\nSimpleProfiler [" << name << "] cnt: " << cnt << " total: ";
        if (total_ns < 1.e3)
            std::cerr << total_ns << " ns" << std::endl;
        else if (total_ns < 1.e6)
            std::cerr << std::setprecision(3) << total_ns/1.e3 << " us" << std::endl;
        else if (total_ns < 1.e9)
            std::cerr << std::setprecision(3) << total_ns/1.e6 << " ms" << std::endl;
        else
            std::cerr << std::setprecision(3) << total_ns/1.e9 << " s" << std::endl;
        std::cerr << "         <32ns     <1us    <32us     <1ms    <32ms      <1s      >1s\n";
        std::cerr << "cnt: ";
        for (int i=0; i<7; i++) std::cerr << std::setw(9) << pcnt[i];
        std::cerr << "\n     ";
        for (int i=0; i<7; i++) std::cerr << std::setw(9) << std::setprecision(3) << pcnt[i]*1.0/cnt;
        std::cerr << "\ntime:";
        for (int i=0; i<7; i++) std::cerr << std::setw(9) << std::setprecision(3) << pns[i]*1.0/total_ns;
        std::cerr << "\nlast:";
        for (int i=0; i<7; i++) std::cerr << std::setw(9) << last[i];
        std::cerr << std::endl;
    }
    // Record one duration (in ns). Bucket index is floor((bit_width-1)/5),
    // i.e. one bucket per factor of 32, clamped to the last (>1s) bucket.
    inline void add(int64 time) {
        auto nbit = 64 - _lzcnt(time);
        auto i = (nbit-1) / 5;
        if (i>6) i=6;
        cnt ++;
        total_ns += time;
        pcnt[i] ++;
        pns[i] += time;
        last[i] = cnt;
    }
};
/*
example:
{
static SimpleProfiler _("array");
SimpleProfilerGuard __(_);
......
}
*/
// RAII timer: measures the wall-clock lifetime of this object and feeds
// the elapsed nanoseconds into the given SimpleProfiler on destruction.
// Construct one at the top of the scope you want to measure.
struct SimpleProfilerGuard {
    SimpleProfiler* p;   // destination accumulator (borrowed, not owned)
    std::chrono::high_resolution_clock::time_point start;   // scope-entry timestamp
    inline SimpleProfilerGuard(SimpleProfiler& p) : p(&p) {
        start = std::chrono::high_resolution_clock::now();
    }
    inline ~SimpleProfilerGuard() {
        // Elapsed time from construction to scope exit, in nanoseconds.
        auto finish = std::chrono::high_resolution_clock::now();
        auto total_ns = (int64_t)std::chrono::duration_cast<std::chrono::nanoseconds>(finish-start).count();
        p->add(total_ns);
    }
};
} // jittor

View File

@ -13,6 +13,7 @@
#include "misc/hash.h"
#include "misc/nano_string.h"
#include "misc/fast_shared_ptr.h"
#include "profiler/simple_profiler.h"
#ifdef HAS_CUDA
#include "misc/cuda_flags.h"
#endif

View File

@ -42,11 +42,11 @@ JIT_TEST(jit_key) {
}
jk.clear();
add_jit_define("f", 0.01);
add_jit_define("f", 0.5);
add_jit_define("f", 1.0/0);
add_jit_define("f", -1.0/0);
add_jit_define("f", 0.0/0);
add_jit_define(jk, "f", 0.01);
add_jit_define(jk, "f", 0.5);
add_jit_define(jk, "f", 1.0/0);
add_jit_define(jk, "f", -1.0/0);
add_jit_define(jk, "f", 0.0/0);
keys = parse_jit_keys(jk.to_string());
k2 = {{"f","0x1.47ae147ae147bp-7"},
{"f","0x1p-1"},

View File

@ -275,7 +275,7 @@ a[2]++;
ir.move_out_children();
ir.push_back("T x=1;");
ir.push_back("T y=n;");
int num;
int num=0;
CHECK(ir.get_number("x", num) && num==1);
CHECK(!ir.get_number("z", num) && num==-1);
CHECK(!ir.get_number("y", num) && num==-2);

View File

@ -45,7 +45,7 @@ JIT_TEST(fused_op_relay_matmul) {
});
CHECKop(q.size(),==,10);
CHECKop(ops.size(),==,4);
for (auto op : ops) op->do_jit_prepare();
for (auto op : ops) op->do_jit_prepare(jk);
FusedOp fop;
FusedOpContext context;
fop.context = &context;