From c5ccdaf3305e10b9d83b8124f951b88378f1eaf6 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Mon, 23 May 2022 14:23:02 +0800 Subject: [PATCH] polish jit key --- python/jittor/__init__.py | 2 +- .../extern/cuda/cub/ops/cub_arg_reduce_op.cc | 10 +- .../extern/cuda/cub/ops/cub_argsort_op.cc | 14 +- .../extern/cuda/cub/ops/cub_cumsum_op.cc | 7 +- .../jittor/extern/cuda/cub/ops/cub_test_op.cc | 2 +- .../extern/cuda/cub/ops/cub_where_op.cc | 7 +- .../cuda/cublas/ops/cublas_acc_matmul_op.cc | 9 +- .../cublas/ops/cublas_batched_matmul_op.cc | 9 +- .../cuda/cublas/ops/cublas_matmul_op.cc | 9 +- .../extern/cuda/cublas/ops/cublas_test_op.cc | 2 +- .../cudnn/ops/cudnn_conv3d_backward_w_op.cc | 7 +- .../cudnn/ops/cudnn_conv3d_backward_x_op.cc | 7 +- .../extern/cuda/cudnn/ops/cudnn_conv3d_op.cc | 7 +- .../cudnn/ops/cudnn_conv_backward_w_op.cc | 13 +- .../cudnn/ops/cudnn_conv_backward_x_op.cc | 13 +- .../extern/cuda/cudnn/ops/cudnn_conv_op.cc | 13 +- .../cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc | 7 +- .../extern/cuda/cudnn/ops/cudnn_rnn_op.cc | 7 +- .../extern/cuda/cudnn/ops/cudnn_test_op.cc | 2 +- .../extern/cuda/cufft/ops/cufft_fft_op.cc | 6 +- .../cuda/curand/ops/curand_random_op.cc | 4 +- .../extern/cuda/cutt/ops/cutt_test_op.cc | 2 +- .../extern/cuda/cutt/ops/cutt_transpose_op.cc | 2 +- .../cuda/nccl/ops/nccl_all_reduce_op.cc | 2 +- .../extern/cuda/nccl/ops/nccl_broadcast_op.cc | 2 +- .../extern/cuda/nccl/ops/nccl_reduce_op.cc | 2 +- .../extern/cuda/nccl/ops/nccl_test_op.cc | 2 +- .../extern/mkl/ops/mkl_conv_backward_w_op.cc | 19 +- .../extern/mkl/ops/mkl_conv_backward_x_op.cc | 19 +- python/jittor/extern/mkl/ops/mkl_conv_op.cc | 19 +- python/jittor/extern/mkl/ops/mkl_matmul_op.cc | 6 +- python/jittor/extern/mkl/ops/mkl_test_op.cc | 2 +- .../extern/mpi/ops/mpi_all_reduce_op.cc | 4 +- .../jittor/extern/mpi/ops/mpi_broadcast_op.cc | 2 +- python/jittor/extern/mpi/ops/mpi_reduce_op.cc | 4 +- python/jittor/extern/mpi/ops/mpi_test_op.cc | 2 +- python/jittor/src/executor.cc | 4 + python/jittor/src/fused_op.cc | 26 +-- python/jittor/src/jit_compiler.cc | 2 +- python/jittor/src/jit_key.cc | 46 ++-- python/jittor/src/jit_key.h | 211 ++++++++++-------- python/jittor/src/op.cc | 14 +- python/jittor/src/ops/arg_reduce_op.cc | 15 +- python/jittor/src/ops/argsort_op.cc | 11 +- python/jittor/src/ops/array_op.cc | 4 +- python/jittor/src/ops/binary_op.cc | 8 +- python/jittor/src/ops/broadcast_to_op.cc | 6 +- python/jittor/src/ops/candidate_op.cc | 8 +- python/jittor/src/ops/code_op.cc | 19 +- python/jittor/src/ops/fuse_transpose_op.cc | 9 +- python/jittor/src/ops/getitem_op.cc | 29 ++- python/jittor/src/ops/random_op.cc | 4 +- python/jittor/src/ops/reduce_op.cc | 12 +- python/jittor/src/ops/reindex_op.cc | 21 +- python/jittor/src/ops/reindex_reduce_op.cc | 21 +- python/jittor/src/ops/safe_clip_op.cc | 2 +- python/jittor/src/ops/setitem_op.cc | 35 ++- python/jittor/src/ops/ternary_op.cc | 8 +- python/jittor/src/ops/transpose_op.cc | 7 +- python/jittor/src/ops/unary_op.cc | 6 +- python/jittor/src/ops/where_op.cc | 7 +- python/jittor/src/test/test_jit_key.cc | 18 +- python/jittor/src/test/test_op_relay.cc | 2 +- .../jittor/test/test_merge_single_array_op.py | 2 +- python/jittor/test/test_resize_and_crop.py | 2 +- python/jittor/test/test_setitem.py | 9 + 66 files changed, 407 insertions(+), 407 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 2dd8435c..9b9dd992 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part 
of this source code package. # *************************************************************** -__version__ = '1.3.4.7' +__version__ = '1.3.4.8' from jittor_utils import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.cc b/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.cc index 9290fc42..60c4d581 100644 --- a/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.cc +++ b/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.cc @@ -59,13 +59,13 @@ void CubArgReduceOp::infer_shape() { } void CubArgReduceOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Toffsets:") << offsets->dtype(); - jk << _CS("][FUNC:"); + jk << "«Tx:" << x->dtype(); + jk << "«Toffsets:" << offsets->dtype(); + jk << "«FUNC:"; if (op==ns_minimum) - jk << _CS("ArgMin]"); + jk << "ArgMin"; else - jk << _CS("ArgMax]"); + jk << "ArgMax"; } #else // JIT diff --git a/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc b/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc index 549229a3..c53f9f87 100644 --- a/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc +++ b/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc @@ -51,15 +51,15 @@ void CubArgsortOp::infer_shape() { } void CubArgsortOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Tindexes:") << indexes->dtype(); - jk << _CS("][Toffsets:") << offsets->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][FUNC:"); + jk << "«Tx:" << x->dtype(); + jk << "«Tindexes:" << indexes->dtype(); + jk << "«Toffsets:" << offsets->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«FUNC:"; if (descending) - jk << _CS("SortPairsDescending]"); + jk << "SortPairsDescending"; else - jk << _CS("SortPairs]"); + jk << "SortPairs"; } #else // JIT diff --git a/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.cc b/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.cc index 0900cbc6..7b08cec7 100644 --- a/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.cc +++ b/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.cc @@ -38,10 +38,9 @@ void CubCumsumOp::infer_shape() { } void CubCumsumOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][reverse:") << reverse; - jk << _CS("]"); + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«reverse:" << reverse; } VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) { diff --git a/python/jittor/extern/cuda/cub/ops/cub_test_op.cc b/python/jittor/extern/cuda/cub/ops/cub_test_op.cc index 400b0f31..a689b130 100644 --- a/python/jittor/extern/cuda/cub/ops/cub_test_op.cc +++ b/python/jittor/extern/cuda/cub/ops/cub_test_op.cc @@ -24,7 +24,7 @@ CubTestOp::CubTestOp(string cmd) : cmd(cmd) { } void CubTestOp::jit_prepare(JK& jk) { - jk << _CS("[T:float32]"); + jk << "«T:float32"; } #else // JIT diff --git a/python/jittor/extern/cuda/cub/ops/cub_where_op.cc b/python/jittor/extern/cuda/cub/ops/cub_where_op.cc index 9b8ef953..6656fe11 100644 --- a/python/jittor/extern/cuda/cub/ops/cub_where_op.cc +++ b/python/jittor/extern/cuda/cub/ops/cub_where_op.cc @@ -40,10 +40,9 @@ void CubWhereOp::infer_shape() { } void CubWhereOp::jit_prepare(JK& jk) { - jk << _CS("[Ti:") << cond->dtype(); - jk << _CS("][To:") << outs[0]->dtype(); - jk << _CS("][NDIM=") << JK::hex1(cond->shape.size()); - jk << ']'; + jk << "«Ti:" << cond->dtype(); + jk << "«To:" << outs[0]->dtype(); + jk << "«NDIM=" << JK::hex1(cond->shape.size()); } #else // JIT diff --git 
a/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.cc index a6b280e0..41da0b3b 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.cc @@ -54,11 +54,10 @@ void CublasAccMatmulOp::infer_shape() { } void CublasAccMatmulOp::jit_prepare(JK& jk) { - jk << _CS("[T:") << a->dtype(); - jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); - jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); - jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); - jk << ']'; + jk << "«T:" << a->dtype(); + jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); + jk << "«Trans_b:" << (trans_b ? 'T' : 'N'); + jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); } #else // JIT diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc index efccfb5e..95797efc 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc @@ -89,11 +89,10 @@ void CublasBatchedMatmulOp::infer_shape(){ } void CublasBatchedMatmulOp::jit_prepare(JK& jk) { - jk << _CS("[T:") << a->dtype(); - jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); - jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); - jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); - jk << ']'; + jk << "«T:" << a->dtype(); + jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); + jk << "«Trans_b:" << (trans_b ? 'T' : 'N'); + jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); } #else // JIT diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc index b6dfde3c..46d9fc23 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc @@ -50,11 +50,10 @@ void CublasMatmulOp::infer_shape() { } void CublasMatmulOp::jit_prepare(JK& jk) { - jk << _CS("[T:") << a->dtype(); - jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); - jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); - jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); - jk << ']'; + jk << "«T:" << a->dtype(); + jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); + jk << "«Trans_b:" << (trans_b ? 'T' : 'N'); + jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 
'S' : 'D')); } #else // JIT diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_test_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_test_op.cc index 2a0021ba..cf85acff 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_test_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_test_op.cc @@ -19,7 +19,7 @@ CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) { } void CublasTestOp::jit_prepare(JK& jk) { - jk << _CS("[T:float32]"); + jk << "«T:float32"; } #else // JIT diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc index 58306510..bd81438f 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc @@ -55,10 +55,9 @@ void CudnnConv3dBackwardWOp::infer_shape() { } void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << dy->dtype(); - jk << _CS("][Tw:") << dw->dtype(); - jk << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << dy->dtype(); + jk << "«Tw:" << dw->dtype(); } static auto make_conv3d = get_op_info("cudnn_conv3d") diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc index ab0bd7ff..d1832662 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc @@ -52,10 +52,9 @@ void CudnnConv3dBackwardXOp::infer_shape() { } void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << dx->dtype(); - jk << _CS("][Ty:") << dy->dtype(); - jk << _CS("][Tw:") << w->dtype(); - jk << ']'; + jk << "«Tx:" << dx->dtype(); + jk << "«Ty:" << dy->dtype(); + jk << "«Tw:" << w->dtype(); } diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc index 05bc068d..4b375ade 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc @@ -53,10 +53,9 @@ void CudnnConv3dOp::infer_shape() { } void CudnnConv3dOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][Tw:") << w->dtype(); - jk << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tw:" << w->dtype(); } static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x") diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc index 36ee2a86..d52b0bb3 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc @@ -71,13 +71,12 @@ void CudnnConvBackwardWOp::infer_shape() { } void CudnnConvBackwardWOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << dy->dtype(); - jk << _CS("][Tw:") << dw->dtype(); - jk << _CS("][XFORMAT:") << xformat; - jk << _CS("][WFORMAT:") << wformat; - jk << _CS("][YFORMAT:") << yformat; - jk << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << dy->dtype(); + jk << "«Tw:" << dw->dtype(); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; } static auto make_conv = get_op_info("cudnn_conv") diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc 
b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc index 9d4a3a2c..5ee6aa39 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc @@ -70,13 +70,12 @@ void CudnnConvBackwardXOp::infer_shape() { } void CudnnConvBackwardXOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << dx->dtype(); - jk << _CS("][Ty:") << dy->dtype(); - jk << _CS("][Tw:") << w->dtype(); - jk << _CS("][XFORMAT:") << xformat; - jk << _CS("][WFORMAT:") << wformat; - jk << _CS("][YFORMAT:") << yformat; - jk << ']'; + jk << "«Tx:" << dx->dtype(); + jk << "«Ty:" << dy->dtype(); + jk << "«Tw:" << w->dtype(); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; } static auto make_conv = get_op_info("cudnn_conv") diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc index 89dc3615..7174f4c9 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc @@ -72,13 +72,12 @@ void CudnnConvOp::infer_shape() { } void CudnnConvOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][Tw:") << w->dtype(); - jk << _CS("][XFORMAT:") << xformat; - jk << _CS("][WFORMAT:") << wformat; - jk << _CS("][YFORMAT:") << yformat; - jk << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tw:" << w->dtype(); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; } static auto make_backwardx = get_op_info("cudnn_conv_backward_x") .get_constructor(); diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc index 674f283a..e51335c2 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc @@ -79,10 +79,9 @@ void CudnnRnnBackwardXOp::infer_shape() { } void CudnnRnnBackwardXOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << hx->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][Tw:") << w->dtype(); - jk << ']'; + jk << "«Tx:" << hx->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tw:" << w->dtype(); } #else // JIT diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc index d2ca5e80..53bf57dd 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc @@ -98,10 +98,9 @@ void CudnnRnnOp::infer_shape() { } void CudnnRnnOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][Tw:") << w->dtype(); - jk << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tw:" << w->dtype(); } static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x") diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.cc index dfdd1b98..88d2dc7f 100644 --- a/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.cc +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.cc @@ -20,7 +20,7 @@ CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) { } void CudnnTestOp::jit_prepare(JK& jk) { - jk << _CS("[T:float32]"); + jk << "«T:float32"; } #else // JIT diff --git a/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc 
b/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc index b3277a96..99a0375b 100644 --- a/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc +++ b/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc @@ -44,9 +44,9 @@ void CufftFftOp::jit_prepare(JK& jk) { printf("not supported fft dtype: %s\n", y->dtype().to_cstring()); ASSERT(false); } - jk << _CS("[T:") << y->dtype(); - jk << _CS("][I:")<dtype()<<"\"]"; + jk << "«T:" << y->dtype(); + jk << "«I:"<dtype()<<"\"]"; } #else // JIT diff --git a/python/jittor/extern/cuda/curand/ops/curand_random_op.cc b/python/jittor/extern/cuda/curand/ops/curand_random_op.cc index 1baad3a6..6be4d081 100644 --- a/python/jittor/extern/cuda/curand/ops/curand_random_op.cc +++ b/python/jittor/extern/cuda/curand/ops/curand_random_op.cc @@ -25,8 +25,8 @@ CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype, NanoString ty } void CurandRandomOp::jit_prepare(JK& jk) { - jk << _CS("[T:") << output->dtype(); - jk << _CS("][R:") << type << ']'; + jk << "«T:" << output->dtype(); + jk << "«R:" << type; } #else // JIT diff --git a/python/jittor/extern/cuda/cutt/ops/cutt_test_op.cc b/python/jittor/extern/cuda/cutt/ops/cutt_test_op.cc index 5cf9ac59..490d7e7c 100644 --- a/python/jittor/extern/cuda/cutt/ops/cutt_test_op.cc +++ b/python/jittor/extern/cuda/cutt/ops/cutt_test_op.cc @@ -21,7 +21,7 @@ CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) { } void CuttTestOp::jit_prepare(JK& jk) { - jk << _CS("[T:float32]"); + jk << "«T:float32"; } #else // JIT diff --git a/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.cc b/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.cc index fc3da8c9..c6447945 100644 --- a/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.cc +++ b/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.cc @@ -59,7 +59,7 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) { void CuttTransposeOp::jit_prepare(JK& jk) { // do nothing - jk << _CS("[T:1]"); + jk << "«T:1"; } unordered_map cutt_plan_cache; diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.cc index b2d49dee..6ff2ba5a 100644 --- a/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.cc +++ b/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.cc @@ -37,7 +37,7 @@ VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { } void NcclAllReduceOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() << ']'; + jk << "«Tx:" << x->dtype(); } #else // JIT diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.cc index acf0fafc..f0073a05 100644 --- a/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.cc +++ b/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.cc @@ -35,7 +35,7 @@ VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) { } void NcclBroadcastOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() << ']'; + jk << "«Tx:" << x->dtype(); } #else // JIT diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.cc index e13bba48..4ff431e3 100644 --- a/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.cc +++ b/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.cc @@ -35,7 +35,7 @@ VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { } void NcclReduceOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() << ']'; + jk << "«Tx:" << x->dtype(); } #else // JIT diff --git 
a/python/jittor/extern/cuda/nccl/ops/nccl_test_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_test_op.cc index 0bd2ed29..fbb49daa 100644 --- a/python/jittor/extern/cuda/nccl/ops/nccl_test_op.cc +++ b/python/jittor/extern/cuda/nccl/ops/nccl_test_op.cc @@ -20,7 +20,7 @@ NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) { } void NcclTestOp::jit_prepare(JK& jk) { - jk << _CS("[T:float32]"); + jk << "«T:float32"; } #else // JIT diff --git a/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.cc b/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.cc index 24d92d93..f218469b 100644 --- a/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.cc +++ b/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.cc @@ -79,16 +79,15 @@ static const char* short_type(Var* x) { } void MklConvBackwardWOp::jit_prepare(JK& jk) { - jk << _CS("[Txd:") << x->dtype(); - jk << _CS("][Tyd:") << dy->dtype(); - jk << _CS("][Twd:") << dw->dtype(); - jk << _CS("][Tx:") << short_type(x); - jk << _CS("][Tw:") << short_type(dw); - jk << _CS("][Ty:") << short_type(dy); - jk << _CS("][XFORMAT:") << xformat; - jk << _CS("][WFORMAT:") << wformat; - jk << _CS("][YFORMAT:") << yformat; - jk << ']'; + jk << "«Txd:" << x->dtype(); + jk << "«Tyd:" << dy->dtype(); + jk << "«Twd:" << dw->dtype(); + jk << "«Tx:" << short_type(x); + jk << "«Tw:" << short_type(dw); + jk << "«Ty:" << short_type(dy); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; } #else // JIT diff --git a/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.cc b/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.cc index ce407c01..2bdf4d0d 100644 --- a/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.cc +++ b/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.cc @@ -77,16 +77,15 @@ static const char* short_type(Var* x) { } void MklConvBackwardXOp::jit_prepare(JK& jk) { - jk << _CS("[Tyd:") << dy->dtype(); - jk << _CS("][Twd:") << w->dtype(); - jk << _CS("][Txd:") << dx->dtype(); - jk << _CS("][Tx:") << short_type(dx); - jk << _CS("][Tw:") << short_type(w); - jk << _CS("][Ty:") << short_type(dy); - jk << _CS("][XFORMAT:") << xformat; - jk << _CS("][WFORMAT:") << wformat; - jk << _CS("][YFORMAT:") << yformat; - jk << ']'; + jk << "«Tyd:" << dy->dtype(); + jk << "«Twd:" << w->dtype(); + jk << "«Txd:" << dx->dtype(); + jk << "«Tx:" << short_type(dx); + jk << "«Tw:" << short_type(w); + jk << "«Ty:" << short_type(dy); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; } #else // JIT diff --git a/python/jittor/extern/mkl/ops/mkl_conv_op.cc b/python/jittor/extern/mkl/ops/mkl_conv_op.cc index 91d8a5d1..a32892d7 100644 --- a/python/jittor/extern/mkl/ops/mkl_conv_op.cc +++ b/python/jittor/extern/mkl/ops/mkl_conv_op.cc @@ -81,16 +81,15 @@ static const char* short_type(Var* x) { } void MklConvOp::jit_prepare(JK& jk) { - jk << _CS("[Txd:") << x->dtype(); - jk << _CS("][Tyd:") << y->dtype(); - jk << _CS("][Twd:") << w->dtype(); - jk << _CS("][Tx:") << short_type(x); - jk << _CS("][Tw:") << short_type(w); - jk << _CS("][Ty:") << short_type(y); - jk << _CS("][XFORMAT:") << xformat; - jk << _CS("][WFORMAT:") << wformat; - jk << _CS("][YFORMAT:") << yformat; - jk << ']'; + jk << "«Txd:" << x->dtype(); + jk << "«Tyd:" << y->dtype(); + jk << "«Twd:" << w->dtype(); + jk << "«Tx:" << short_type(x); + jk << "«Tw:" << short_type(w); + jk << "«Ty:" << short_type(y); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; } #else // JIT diff --git 
a/python/jittor/extern/mkl/ops/mkl_matmul_op.cc b/python/jittor/extern/mkl/ops/mkl_matmul_op.cc index d9df1417..6e6b539d 100644 --- a/python/jittor/extern/mkl/ops/mkl_matmul_op.cc +++ b/python/jittor/extern/mkl/ops/mkl_matmul_op.cc @@ -44,9 +44,9 @@ void MklMatmulOp::infer_shape() { } void MklMatmulOp::jit_prepare(JK& jk) { - jk << _CS("[T:") << a->dtype(); - jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); - jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N') << ']'; + jk << "«T:" << a->dtype(); + jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); + jk << "«Trans_b:" << (trans_b ? 'T' : 'N'); } #else // JIT diff --git a/python/jittor/extern/mkl/ops/mkl_test_op.cc b/python/jittor/extern/mkl/ops/mkl_test_op.cc index 4b6de8c8..539e8327 100644 --- a/python/jittor/extern/mkl/ops/mkl_test_op.cc +++ b/python/jittor/extern/mkl/ops/mkl_test_op.cc @@ -19,7 +19,7 @@ MklTestOp::MklTestOp() { } void MklTestOp::jit_prepare(JK& jk) { - jk << _CS("[T:float32]"); + jk << "«T:float32"; } #else // JIT diff --git a/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc b/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc index 8a567515..4da3971d 100644 --- a/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc +++ b/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc @@ -62,8 +62,8 @@ VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { } void MpiAllReduceOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][OP:") << op << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«OP:" << op; } #else // JIT diff --git a/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc b/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc index 43618ca5..609dabcf 100644 --- a/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc +++ b/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc @@ -49,7 +49,7 @@ VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) { } void MpiBroadcastOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() << ']'; + jk << "«Tx:" << x->dtype(); } #else // JIT diff --git a/python/jittor/extern/mpi/ops/mpi_reduce_op.cc b/python/jittor/extern/mpi/ops/mpi_reduce_op.cc index e68234a0..5ed3b99b 100644 --- a/python/jittor/extern/mpi/ops/mpi_reduce_op.cc +++ b/python/jittor/extern/mpi/ops/mpi_reduce_op.cc @@ -62,8 +62,8 @@ VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { } void MpiReduceOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][OP:") << op << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«OP:" << op; } #else // JIT diff --git a/python/jittor/extern/mpi/ops/mpi_test_op.cc b/python/jittor/extern/mpi/ops/mpi_test_op.cc index 3fdb34a2..54e5ecb5 100644 --- a/python/jittor/extern/mpi/ops/mpi_test_op.cc +++ b/python/jittor/extern/mpi/ops/mpi_test_op.cc @@ -17,7 +17,7 @@ MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) { } void MpiTestOp::jit_prepare(JK& jk) { - jk << _CS("[T:float32]"); + jk << "«T:float32"; } #else // JIT diff --git a/python/jittor/src/executor.cc b/python/jittor/src/executor.cc index 46ac1b42..f59f692c 100644 --- a/python/jittor/src/executor.cc +++ b/python/jittor/src/executor.cc @@ -585,6 +585,10 @@ void Executor::run_sync(vector vars, bool device_sync, bool weak_sync) { if (!v->allocator->is_cuda()) migrate_to_gpu(v, allocator); } + for (Var* v : op->outputs()) { + if (!v->allocator->is_cuda()) + migrate_to_gpu(v, allocator); + } } #endif #ifdef NODE_MEMCHECK diff --git a/python/jittor/src/fused_op.cc b/python/jittor/src/fused_op.cc index 55606e0d..1ab49167 100644 --- a/python/jittor/src/fused_op.cc +++ 
b/python/jittor/src/fused_op.cc @@ -159,55 +159,51 @@ void FusedOp::do_jit_prepare(JK& jk) { jk.clear(); for (uint i=0; iname(); op->jit_prepare(jk); - jk << JK::end; } - jk << _CS("[JIT:1]"); + jk << "«JIT:1"; if (!use_cuda) { // only cpu - jk << _CS("[JIT_cpu:1]"); + jk << "«JIT_cpu:1"; this->flags.set(NodeFlags::_cuda, 0); this->flags.set(NodeFlags::_cpu, 1); } else { - jk << _CS("[JIT_cuda:1]"); + jk << "«JIT_cuda:1"; this->flags.set(NodeFlags::_cpu, 0); this->flags.set(NodeFlags::_cuda, 1); } - jk << _CS("[graph:"); + jk << "«graph:"; for (auto& t : edges) { uint i,j,k,l; std::tie(i,j,k,l) = t; jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ','; } - jk << _CS("][var_info:") << JK::val; + jk << "«var_info:" << JK::val; bool use_int64_t = false; for (auto& vi : vars) { jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size()); if (vi.type != 1 && vi.var->num >= std::numeric_limits::max()) use_int64_t = true; } - jk << JK::end; if (use_int64_t) - jk << _CS("[index_t:int64]"); + jk << "«index_t:int64"; else - jk << _CS("[index_t:int32]"); + jk << "«index_t:int32"; if (loop_options->size()) { if (get_loop_option("compile_shapes")) { - jk << _CS("[shapes:"); + jk << "«shapes:"; for (auto& vi : vars) { jk << '['; for (auto a : vi.var->shape) jk << a << ','; - jk << _CS("],"); + jk << "],"; } - jk << JK::end; } - jk << _CS("[choices:"); + jk << "«choices:"; for (auto& kv : *loop_options) jk << kv.first << ':' << kv.second << ','; - jk << JK::end; } jk.finilize(); } diff --git a/python/jittor/src/jit_compiler.cc b/python/jittor/src/jit_compiler.cc index 57e143c0..9b7d75ce 100644 --- a/python/jittor/src/jit_compiler.cc +++ b/python/jittor/src/jit_compiler.cc @@ -194,7 +194,7 @@ void run_cmd(string cmd, string cwd="") { static string get_symbol_name(const string& jit_key) { int i=0; - while (i=0) i++; string op_name = i ? 
jit_key.substr(0, i) : "fused"; op_name = Op::file_name_to_class_name(op_name); // _ZN7jittorXyyyyyy7jit_runEv diff --git a/python/jittor/src/jit_key.cc b/python/jittor/src/jit_key.cc index 56725c85..36fb482b 100644 --- a/python/jittor/src/jit_key.cc +++ b/python/jittor/src/jit_key.cc @@ -80,42 +80,26 @@ static void convert_itof(string& s) { vector> parse_jit_keys(const string& s) { vector> jit_keys; - int presum = 0; - char state=0; - string key, val; - for (char c : s) { - if (c==JK::key) { - presum++; - if (presum==1) { + auto sp = split(s, JitKey::key); + for (auto& ss : sp) { + if (!ss.size()) continue; + string key, val; + char state=0; + for (auto c : ss) { + if (state == 0 && + (c==JK::val || c==JK::hex_val)) { state = c; continue; } - } else - if (c==JK::val || c==JK::hex_val) { - if (presum==1 && state==JK::key) { - state = c; - continue; - } - } else - if (c==JK::end) { - presum--; - if (presum==0) { - if (state == JK::hex_val) - hex_to_dec(val); - if (startswith(val, "itof")) - convert_itof(val); - jit_keys.emplace_back(move(key), move(val)); - continue; - } - } - if (presum) { - if (state==JK::key) - key += c; - if (state==JK::val || state==JK::hex_val) - val += c; + if (state == 0) key += c; + else val += c; } + if (state == JK::hex_val) + hex_to_dec(val); + if (startswith(val, "itof")) + convert_itof(val); + jit_keys.emplace_back(move(key), move(val)); } - ASSERT(presum==0) << s; return jit_keys; } diff --git a/python/jittor/src/jit_key.h b/python/jittor/src/jit_key.h index 58204182..2320a92b 100644 --- a/python/jittor/src/jit_key.h +++ b/python/jittor/src/jit_key.h @@ -5,6 +5,7 @@ // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** #pragma once +#include #include "common.h" #include "misc/nano_string.h" #include "misc/nano_vector.h" @@ -13,11 +14,10 @@ namespace jittor { struct JitKey { static constexpr size_t buffer_size = 2*1024*1024; - static constexpr char - key = '[', + static constexpr const char + *key = "«", val = ':', - hex_val = '=', - end = ']'; + hex_val = '='; int64 size=0; uint64 flags=0; char buffer[buffer_size]; @@ -27,7 +27,7 @@ struct JitKey { inline void clear() {size = flags = 0;} inline void finilize() { buffer[size] = 0; } - inline bool empty() { return buffer[size-1] != end; } + inline bool empty() { return !size; } inline const char* to_cstring() { return &buffer[0]; } @@ -81,11 +81,38 @@ struct __jk_int256 { typedef JitKey JK; EXTERN_LIB JK& get_jk(); +inline void jk_put_str_with_len(JK& jk, const char* a, int n) { + char* xx = &jk.buffer[jk.size]; + int i=0; + while (i+32<=n) { + ((__jk_int256*)(xx+i))[0] = ((const __jk_int256*)(a+i))[0]; + i+=32; + } + while (i+16<=n) { + ((__jk_int128*)(xx+i))[0] = ((const __jk_int128*)(a+i))[0]; + i+=16; + } + while (i+8<=n) { + ((long long*)(xx+i))[0] = ((const long long*)(a+i))[0]; + i+=8; + } + while (i+4<=n) { + ((int*)(xx+i))[0] = ((const int*)(a+i))[0]; + i+=4; + } + while (i+2<=n) { + ((int16_t*)(xx+i))[0] = ((const int16_t*)(a+i))[0]; + i+=2; + } + while (i+1<=n) { + ((char*)(xx+i))[0] = ((const char*)(a+i))[0]; + i+=1; + } + jk.size += n; +} + inline JK& operator<<(JK& jk, const char* s) { - int i; - for (i=0; s[i]; i++) - jk.buffer[jk.size+i] = s[i]; - jk.size += i; + jk_put_str_with_len(jk, s, strlen(s)); return jk; } @@ -199,129 +226,133 @@ vector> parse_jit_keys(const string& s); template void add_jit_define(JK& jk, const Ta& key, const Tb& val) { - jk << JK::key << key << JK::val << val << JK::end; + jk << JK::key << 
key << JK::val << val; } template void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) { - jk << JK::key << key << i << JK::val << val << JK::end; + jk << JK::key << key << i << JK::val << val; } template void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) { - jk << JK::key << key << JK::hex_val << val << JK::end; + jk << JK::key << key << JK::hex_val << val; } template void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) { - jk << JK::key << key << i << JK::hex_val << val << JK::end; + jk << JK::key << key << i << JK::hex_val << val; } template void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) { - jk << JK::key << key << JK::hex_val << val << JK::end; + jk << JK::key << key << JK::hex_val << val; } template void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) { - jk << JK::key << key << i << JK::hex_val << val << JK::end; + jk << JK::key << key << i << JK::hex_val << val; } template void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) { - jk << JK::key << key << JK::hex_val << val << JK::end; + jk << JK::key << key << JK::hex_val << val; } template void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) { - jk << JK::key << key << i << JK::hex_val << val << JK::end; + jk << JK::key << key << i << JK::hex_val << val; } +#define _CS(x) x +// // begin of const string +// #define MAX_CONST_CHAR 32 -// begin of const string -#define MAX_CONST_CHAR 32 +// #define _CS_MIN(a,b) (a)<(b)?(a):(b) -#define _CS_MIN(a,b) (a)<(b)?(a):(b) +// #define _CS_T(s)\ +// getChr(s,0),\ +// getChr(s,1),\ +// getChr(s,2),\ +// getChr(s,3),\ +// getChr(s,4),\ +// getChr(s,5),\ +// getChr(s,6),\ +// getChr(s,7),\ +// getChr(s,8),\ +// getChr(s,9),\ +// getChr(s,10),\ +// getChr(s,11),\ +// getChr(s,12),\ +// getChr(s,13),\ +// getChr(s,14),\ +// getChr(s,15),\ +// getChr(s,16),\ +// getChr(s,17),\ +// getChr(s,18),\ +// getChr(s,19),\ +// getChr(s,20),\ +// getChr(s,21),\ +// getChr(s,22),\ +// getChr(s,23),\ +// getChr(s,24),\ +// getChr(s,25),\ +// getChr(s,26),\ +// getChr(s,27),\ +// getChr(s,28),\ +// getChr(s,29),\ +// getChr(s,30),\ +// getChr(s,31),\ +// getChr(s,32),\ +// getChr(s,33),\ +// getChr(s,34),\ +// getChr(s,35) -#define _CS_T(s)\ -getChr(s,0),\ -getChr(s,1),\ -getChr(s,2),\ -getChr(s,3),\ -getChr(s,4),\ -getChr(s,5),\ -getChr(s,6),\ -getChr(s,7),\ -getChr(s,8),\ -getChr(s,9),\ -getChr(s,10),\ -getChr(s,11),\ -getChr(s,12),\ -getChr(s,13),\ -getChr(s,14),\ -getChr(s,15),\ -getChr(s,16),\ -getChr(s,17),\ -getChr(s,18),\ -getChr(s,19),\ -getChr(s,20),\ -getChr(s,21),\ -getChr(s,22),\ -getChr(s,23),\ -getChr(s,24),\ -getChr(s,25),\ -getChr(s,26),\ -getChr(s,27),\ -getChr(s,28),\ -getChr(s,29),\ -getChr(s,30),\ -getChr(s,31),\ -getChr(s,32),\ -getChr(s,33),\ -getChr(s,34),\ -getChr(s,35) +// #define getChr(name, ii) ((_CS_MIN(ii,MAX_CONST_CHAR))() +// #endif -#ifdef _MSC_VER -#define _CS(str) str -#else -#define _CS(str) _CS_G<_CS_T(str)>() -#endif +// template struct _CS_G { +// }; -template struct _CS_G { - }; +// template<> struct _CS_G<0,0,0,0> {}; -template<> struct _CS_G<0,0,0,0> {}; +// template +// inline JK& operator<<(JK& jk, const _CS_G& _) { +// ((uint32*)(jk.buffer+jk.size))[0] = +// (uint32((uint8)(c4))<<24)+ +// (uint32((uint8)(c3))<<16)+ +// (uint32((uint8)(c2))<<8)+ +// uint32((uint8)(c1)); +// if (c4) { +// jk.size += 4; +// jk << _CS_G(); +// } else +// if (c3) { +// jk.size += 3; +// } else +// if (c2) { +// jk.size += 2; +// } else +// if (c1) { +// jk.size += 
1; +// } +// return jk; +// } -template -inline JK& operator<<(JK& jk, const _CS_G& _) { - ((int*)(jk.buffer+jk.size))[0] = c4*(1<<24)+c3*(1<<16)+c2*(1<<8)+c1; - if (c4) { - jk.size += 4; - jk << _CS_G(); - } else - if (c3) { - jk.size += 3; - } else - if (c2) { - jk.size += 2; - } else - if (c1) { - jk.size += 1; - } - return jk; -} - -template <> -inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) { - return jk; -} +// template <> +// inline JK& operator<<(JK& jk, const _CS_G<0,0,0,0>& _) { +// return jk; +// } inline JK& operator<<(JK& jk, float64 f) { - return jk << _CS("itof(0x") << JK::hex(ftoi(f)) << ')'; + return jk << "itof(0x" << JK::hex(ftoi(f)) << ')'; } } // jittor \ No newline at end of file diff --git a/python/jittor/src/op.cc b/python/jittor/src/op.cc index 9c05cf64..810abb41 100644 --- a/python/jittor/src/op.cc +++ b/python/jittor/src/op.cc @@ -128,14 +128,16 @@ string Op::get_hash_name() { void Op::do_jit_prepare(JK& jk) { memcheck_all_exist(); jk << name(); + auto pre_size = jk.size; jit_prepare(jk); - if (jk.empty()) { + if (jk.size == pre_size) { // not a jit op bool has_cuda = flags.get(NodeFlags::_cuda); bool has_cpu = flags.get(NodeFlags::_cpu); CHECK(has_cuda || has_cpu); if (has_cuda && has_cpu && !use_cuda) flags.set(NodeFlags::_cuda, 0); + jk.clear(); } else { bool use_int64_t = false; // TODO: fused op do not have inputs, @@ -149,9 +151,9 @@ void Op::do_jit_prepare(JK& jk) { if (var->num >= std::numeric_limits::max()) use_int64_t = true; } - jk << _CS("[JIT:1]"); + jk << "«JIT:1"; if (use_cuda_op && flags.get(NodeFlags::_cuda)) { - jk << _CS("[JIT_cuda:1]"); + jk << "«JIT_cuda:1"; flags.set(NodeFlags::_cpu, 0); // TODO: 64bit index in CUDA // use_int64_t = false; @@ -164,14 +166,14 @@ void Op::do_jit_prepare(JK& jk) { } ASSERT(flags.get(NodeFlags::_cpu)) << "Op" << name() << "doesn't have cpu version"; - jk << _CS("[JIT_cpu:1]"); + jk << "«JIT_cpu:1"; flags.set(NodeFlags::_cuda, 0); } if (try_use_32bit_index) use_int64_t = false; if (use_int64_t) - jk << _CS("[index_t:int64]"); + jk << "«index_t:int64"; else - jk << _CS("[index_t:int32]"); + jk << "«index_t:int32"; } jk.finilize(); } diff --git a/python/jittor/src/ops/arg_reduce_op.cc b/python/jittor/src/ops/arg_reduce_op.cc index 36add873..5e743e6c 100644 --- a/python/jittor/src/ops/arg_reduce_op.cc +++ b/python/jittor/src/ops/arg_reduce_op.cc @@ -158,14 +158,13 @@ void ArgReduceOp::infer_shape() { } void ArgReduceOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][XDIM=") << JK::hex1(x->shape.size()); - jk << _CS("][YDIM=") << JK::hex1(y->shape.size()); - jk << _CS("][KEEPDIMS:") << (keepdims ? '1' : '0'); - jk << _CS("][DIM=") << JK::hex1(dim); - jk << _CS("][CMP:") << (op==ns_minimum ? "<" : ">"); - jk << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«XDIM=" << JK::hex1(x->shape.size()); + jk << "«YDIM=" << JK::hex1(y->shape.size()); + jk << "«KEEPDIMS:" << (keepdims ? '1' : '0'); + jk << "«DIM=" << JK::hex1(dim); + jk << "«CMP:" << (op==ns_minimum ? 
"<" : ">"); } #else // JIT diff --git a/python/jittor/src/ops/argsort_op.cc b/python/jittor/src/ops/argsort_op.cc index 6c37521c..67434e12 100644 --- a/python/jittor/src/ops/argsort_op.cc +++ b/python/jittor/src/ops/argsort_op.cc @@ -130,12 +130,11 @@ void ArgsortOp::infer_shape() { } void ArgsortOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][XDIM=") << JK::hex1(x->shape.size()); - jk << _CS("][DIM=") << JK::hex1(dim); - jk << _CS("][CMP:") << (descending ? '>' : '<'); - jk << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«XDIM=" << JK::hex1(x->shape.size()); + jk << "«DIM=" << JK::hex1(dim); + jk << "«CMP:" << (descending ? '>' : '<'); } #else // JIT diff --git a/python/jittor/src/ops/array_op.cc b/python/jittor/src/ops/array_op.cc index b1bab8ae..3a2230ac 100644 --- a/python/jittor/src/ops/array_op.cc +++ b/python/jittor/src/ops/array_op.cc @@ -78,7 +78,7 @@ ArrayOp::ArrayOp(ArrayArgs&& args) { void ArrayOp::jit_prepare(JK& jk) { if (output->flags.get(NodeFlags::_force_fuse)) { - jk << _CS("[T:") << output->dtype() << ']'; + jk << "«T:" << output->dtype(); // fill or find cbuffer for const var pass if (output->dtype().dsize() == 4) { @@ -86,7 +86,7 @@ void ArrayOp::jit_prepare(JK& jk) { auto y = std::abs(ptr()[0]); auto z = ptr()[0]; if ((x<=2) || (y==1.0f || y==2.0f)) - jk << _CS("[o:") << z << ']'; + jk << "«o:" << z; } // end of fill cbuffer } diff --git a/python/jittor/src/ops/binary_op.cc b/python/jittor/src/ops/binary_op.cc index 23b286f4..70354547 100644 --- a/python/jittor/src/ops/binary_op.cc +++ b/python/jittor/src/ops/binary_op.cc @@ -540,10 +540,10 @@ void BinaryOp::infer_shape() { } void BinaryOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() - << _CS("][Ty:") << y->dtype() - << _CS("][Tz:") << z->dtype() - << _CS("][OP:") << ns << ']'; + jk << "«Tx:" << x->dtype() + << "«Ty:" << y->dtype() + << "«Tz:" << z->dtype() + << "«OP:" << ns; } #else // JIT diff --git a/python/jittor/src/ops/broadcast_to_op.cc b/python/jittor/src/ops/broadcast_to_op.cc index e5551ea6..ed489e66 100644 --- a/python/jittor/src/ops/broadcast_to_op.cc +++ b/python/jittor/src/ops/broadcast_to_op.cc @@ -167,9 +167,9 @@ void BroadcastToOp::infer_shape() { } void BroadcastToOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() - << _CS("][DIM=") << JK::hex1(z->shape.size()) - << _CS("][BCAST=") << JK::hex(bcast_mask) << ']'; + jk << "«Tx:" << x->dtype() + << "«DIM=" << JK::hex1(z->shape.size()) + << "«BCAST=" << JK::hex(bcast_mask); } #else // JIT diff --git a/python/jittor/src/ops/candidate_op.cc b/python/jittor/src/ops/candidate_op.cc index 7c30468f..3c64c824 100644 --- a/python/jittor/src/ops/candidate_op.cc +++ b/python/jittor/src/ops/candidate_op.cc @@ -25,10 +25,10 @@ void CandidateOp::infer_shape() { } void CandidateOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][FUNC:") << fail_cond; - jk << _CS("][XDIM=") << JK::hex1(x->shape.size()) << ']'; + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«FUNC:" << fail_cond; + jk << "«XDIM=" << JK::hex1(x->shape.size()); } #else // JIT diff --git a/python/jittor/src/ops/code_op.cc b/python/jittor/src/ops/code_op.cc index 79f91cee..92860650 100644 --- a/python/jittor/src/ops/code_op.cc +++ b/python/jittor/src/ops/code_op.cc @@ -123,18 +123,18 @@ void CodeOp::jit_prepare(JK& jk) { // forward: in0 in1 in2 -> out0 out1 // backward: in0 in1 in2 in3(pout0) in4(pout1) - jk << 
_CS("[IN_SIZE=") << JK::hex(_inputs.size()); + jk << "«IN_SIZE=" << JK::hex(_inputs.size()); for (uint i=0; i<_inputs.size(); i++) { - jk << _CS("][in") << JK::hex(i) << _CS("_dim=") + jk << "«in" << JK::hex(i) << "_dim=" << JK::hex1(_inputs[i]->shape.size()); - jk << _CS("][in") << JK::hex(i) << _CS("_type:") + jk << "«in" << JK::hex(i) << "_type:" << _inputs[i]->dtype(); } - jk << _CS("][OUT_SIZE=") << JK::hex(_outputs.size()); + jk << "«OUT_SIZE=" << JK::hex(_outputs.size()); for (uint i=0; i<_outputs.size(); i++) { - jk << _CS("][out") << JK::hex(i) << _CS("_dim=") + jk << "«out" << JK::hex(i) << "_dim=" << JK::hex1(_outputs[i]->shape.size()); - jk << _CS("][out") << JK::hex(i) << _CS("_type:") + jk << "«out" << JK::hex(i) << "_type:" << _outputs[i]->dtype(); } string& header = flags.get(NodeFlags::_cuda) ? @@ -142,9 +142,9 @@ void CodeOp::jit_prepare(JK& jk) { string& src = flags.get(NodeFlags::_cuda) ? cuda_src : cpu_src; - jk << _CS("][HEADER:") << header; + jk << "«HEADER:" << header; CHECK(src.size()); - jk << _CS("\nnamespace jittor {\n"); + jk << "\nnamespace jittor {\n"; int i=0; // move cuda kernel function into header for (; idtype(); - jk << _CS("][DIM=") << JK::hex1(axes.size()); - jk << _CS("][BC:") << JK::hex1(bc); + jk << "«Tx:" << x->dtype(); + jk << "«DIM=" << JK::hex1(axes.size()); + jk << "«BC:" << JK::hex1(bc); for (uint i=0; i=0 && io==-1) { if (v.is_int()) { - jk << _CS("][VS") << JK::hex1(i) << _CS(":-1"); + jk << "«VS" << JK::hex1(i) << ":-1"; } else if (v.is_str()) { - jk << _CS("][VS") << JK::hex1(i) << _CS(":-5"); - jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str(); + jk << "«VS" << JK::hex1(i) << ":-5"; + jk << "«VSS" << JK::hex1(i) << ":" << v.get_str(); } else { ASSERT(v.is_var()); auto var = v.var; @@ -475,13 +475,13 @@ void GetitemOp::jit_prepare(JK& jk) { if (vshape[j] == o_shape[k]) vsmask |= 1<<(j+var_dim-vdim); } - jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask); - jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype(); + jk << "«VS" << JK::hex1(i) << '=' << JK::hex(vsmask); + jk << "«VST" << JK::hex1(i) << ':' << var->dtype(); } } else if (iv>=0 && io>=0) { ASSERT(v.is_slice()); - jk << _CS("][VS") << JK::hex1(i) << ':'; + jk << "«VS" << JK::hex1(i) << ':'; if (std::abs(v.slice.step) <= 1) jk << JK::shex1(v.slice.step); else @@ -495,11 +495,10 @@ void GetitemOp::jit_prepare(JK& jk) { int tdims[6]; cuda_loop_schedule(o_shape, masks, tdims); for (int i=0; idtype(); - jk << _CS("][R:") << type << ']'; + jk << "«T:" << output->dtype(); + jk << "«R:" << type; } #else // JIT diff --git a/python/jittor/src/ops/reduce_op.cc b/python/jittor/src/ops/reduce_op.cc index cf3938f3..7376a348 100644 --- a/python/jittor/src/ops/reduce_op.cc +++ b/python/jittor/src/ops/reduce_op.cc @@ -338,12 +338,12 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { } void ReduceOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() - << _CS("][Ty:") << y->dtype() - << _CS("][Tz:") << y->dtype() - << _CS("][OP:") << ns - << _CS("][DIM=") << JK::hex1(x->shape.size()) - << _CS("][REDUCE=") << JK::hex(reduce_mask) << ']'; + jk << "«Tx:" << x->dtype() + << "«Ty:" << y->dtype() + << "«Tz:" << y->dtype() + << "«OP:" << ns + << "«DIM=" << JK::hex1(x->shape.size()) + << "«REDUCE=" << JK::hex(reduce_mask); } #else // JIT diff --git a/python/jittor/src/ops/reindex_op.cc b/python/jittor/src/ops/reindex_op.cc index aea08213..04817a9d 100644 --- a/python/jittor/src/ops/reindex_op.cc +++ b/python/jittor/src/ops/reindex_op.cc @@ -91,21 +91,20 @@ void 
ReindexOp::infer_shape() { } void ReindexOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() - << _CS("][XDIM=") << JK::hex1(x->shape.size()) - << _CS("][YDIM=") << JK::hex1(y->shape.size()) - << _CS("][OVERFLOW:") << overflow_value; + jk << "«Tx:" << x->dtype() + << "«XDIM=" << JK::hex1(x->shape.size()) + << "«YDIM=" << JK::hex1(y->shape.size()) + << "«OVERFLOW:" << overflow_value; for (uint i=0; idtype(); + jk << "«EDIM" << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size()); + jk << "«Te" << JK::hex1(i) << ':' << extras[i]->dtype(); } - jk << ']'; } #else // JIT diff --git a/python/jittor/src/ops/reindex_reduce_op.cc b/python/jittor/src/ops/reindex_reduce_op.cc index 2f6fa0ab..cd5fa4dc 100644 --- a/python/jittor/src/ops/reindex_reduce_op.cc +++ b/python/jittor/src/ops/reindex_reduce_op.cc @@ -73,21 +73,20 @@ void ReindexReduceOp::infer_shape() { } void ReindexReduceOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() - << _CS("][OP:") << ns - << _CS("][YDIM=") << JK::hex1(y->shape.size()) - << _CS("][XDIM=") << JK::hex1(x->shape.size()); + jk << "«Tx:" << x->dtype() + << "«OP:" << ns + << "«YDIM=" << JK::hex1(y->shape.size()) + << "«XDIM=" << JK::hex1(x->shape.size()); for (uint i=0; idtype(); + jk << "«EDIM" << JK::hex1(i) << '=' << JK::hex1(extras[i]->shape.size()); + jk << "«Te" << JK::hex1(i) << ':' << extras[i]->dtype(); } - jk << ']'; } #else // JIT diff --git a/python/jittor/src/ops/safe_clip_op.cc b/python/jittor/src/ops/safe_clip_op.cc index 1b02fedb..4787f058 100644 --- a/python/jittor/src/ops/safe_clip_op.cc +++ b/python/jittor/src/ops/safe_clip_op.cc @@ -30,7 +30,7 @@ void SafeClipOp::infer_shape() { } void SafeClipOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype() <<']'; + jk << "«Tx:" << x->dtype() <<"«"; } #else // JIT diff --git a/python/jittor/src/ops/setitem_op.cc b/python/jittor/src/ops/setitem_op.cc index fed160d1..fde1fe1d 100644 --- a/python/jittor/src/ops/setitem_op.cc +++ b/python/jittor/src/ops/setitem_op.cc @@ -201,32 +201,32 @@ void SetitemOp::jit_prepare(JK& jk) { break; } auto data = input(1); - jk << _CS("[OP:") << op - << _CS("][Td:") << data->dtype() - << _CS("][BMASK=") << JK::hex(bmask); + jk << "«OP:" << op + << "«Td:" << data->dtype() + << "«BMASK=" << JK::hex(bmask); // TODO: merge code auto in = inputs().front(); int idim = i_to_vs.size(); - jk << _CS("][Ti:") << in->dtype(); - jk << _CS("][IDIM=") << JK::hex1(i_to_vs.size()); - jk << _CS("][ODIM=") << JK::hex1(o_shape.size()); + jk << "«Ti:" << in->dtype(); + jk << "«IDIM=" << JK::hex1(i_to_vs.size()); + jk << "«ODIM=" << JK::hex1(o_shape.size()); if (first_oid_of_var>=0) { - jk << _CS("][FOV=") << JK::hex1(first_oid_of_var); - jk << _CS("][VD=") << JK::hex1(var_dim); + jk << "«FOV=" << JK::hex1(first_oid_of_var); + jk << "«VD=" << JK::hex1(var_dim); } for (int i=0; i=0 && io==-1) { if (v.is_int()) { - jk << _CS("][VS") << JK::hex1(i) << _CS(":-1"); + jk << "«VS" << JK::hex1(i) << ":-1"; } else if (v.is_str()) { - jk << _CS("][VS") << JK::hex1(i) << _CS(":-5"); - jk << _CS("][VSS") << JK::hex1(i) << _CS(":") << v.get_str(); + jk << "«VS" << JK::hex1(i) << ":-5"; + jk << "«VSS" << JK::hex1(i) << ":" << v.get_str(); } else { ASSERT(v.is_var()); auto var = v.var; @@ -238,13 +238,13 @@ void SetitemOp::jit_prepare(JK& jk) { if (vshape[j] == o_shape[k]) vsmask |= 1<<(j+var_dim-vdim); } - jk << _CS("][VS") << JK::hex1(i) << '=' << JK::hex(vsmask); - jk << _CS("][VST") << JK::hex1(i) << ':' << var->dtype(); + jk << "«VS" << JK::hex1(i) << '=' << JK::hex(vsmask); + jk << "«VST" 
<< JK::hex1(i) << ':' << var->dtype(); } } else if (iv>=0 && io>=0) { ASSERT(v.is_slice()); - jk << _CS("][VS") << JK::hex1(i) << ':'; + jk << "«VS" << JK::hex1(i) << ':'; if (std::abs(v.slice.step) <= 1) jk << JK::shex1(v.slice.step); else @@ -258,11 +258,10 @@ void SetitemOp::jit_prepare(JK& jk) { int tdims[6]; cuda_loop_schedule(o_shape, masks, tdims); for (int i=0; idtype(); - jk << _CS("][Tx:") << x->dtype(); - jk << _CS("][Ty:") << y->dtype(); - jk << _CS("][Tz:") << z->dtype() << ']'; + jk << "«Tc:" << cond->dtype(); + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tz:" << z->dtype(); } #else // JIT diff --git a/python/jittor/src/ops/transpose_op.cc b/python/jittor/src/ops/transpose_op.cc index 5fd3569e..2ea131d1 100644 --- a/python/jittor/src/ops/transpose_op.cc +++ b/python/jittor/src/ops/transpose_op.cc @@ -79,11 +79,10 @@ VarPtr TransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) { } void TransposeOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][DIM=") << JK::hex1(axes.size()); + jk << "«Tx:" << x->dtype(); + jk << "«DIM=" << JK::hex1(axes.size()); for (uint i=0; idtype() + << "«Ty:" << y->dtype() + << "«OP:" << ns; } #else // JIT diff --git a/python/jittor/src/ops/where_op.cc b/python/jittor/src/ops/where_op.cc index f9528eea..9f2b536e 100644 --- a/python/jittor/src/ops/where_op.cc +++ b/python/jittor/src/ops/where_op.cc @@ -54,10 +54,9 @@ void WhereOp::infer_shape() { } void WhereOp::jit_prepare(JK& jk) { - jk << _CS("[Ti:") << cond->dtype(); - jk << _CS("][To:") << outs[0]->dtype(); - jk << _CS("][NDIM=") << JK::hex1(cond->shape.size()); - jk << ']'; + jk << "«Ti:" << cond->dtype(); + jk << "«To:" << outs[0]->dtype(); + jk << "«NDIM=" << JK::hex1(cond->shape.size()); } #else // JIT diff --git a/python/jittor/src/test/test_jit_key.cc b/python/jittor/src/test/test_jit_key.cc index 31895000..ba047fd7 100644 --- a/python/jittor/src/test/test_jit_key.cc +++ b/python/jittor/src/test/test_jit_key.cc @@ -19,16 +19,16 @@ JIT_TEST(jit_key) { }); std::cerr << "get segfault, ok" << std::endl; - jk << JK::key << "key" << JK::val << "value" << JK::end; - jk << JK::key << "key" << JK::val << JK::hex(0x123123) << JK::end; - jk << JK::key << "key" << JK::val << JK::hex1(0x123123) << JK::end; - jk << JK::key << "key" << JK::val << JK::hex2(0x123123) << JK::end; - jk << JK::key << "key" << JK::val << JK::Oxhex(0x123123) << JK::end; - jk << JK::key << "key" << JK::val << JK::Oxhex1(0x123123) << JK::end; - jk << JK::key << "key" << JK::val << JK::Oxhex2(0x123123) << JK::end; - string key = "[key:value][key:123123][key:3][key:23][key:0x123123][key:0x3][key:0x23]"; + jk << JK::key << "key" << JK::val << "value"; + jk << JK::key << "key" << JK::val << JK::hex(0x123123); + jk << JK::key << "key" << JK::val << JK::hex1(0x123123); + jk << JK::key << "key" << JK::val << JK::hex2(0x123123); + jk << JK::key << "key" << JK::val << JK::Oxhex(0x123123); + jk << JK::key << "key" << JK::val << JK::Oxhex1(0x123123); + jk << JK::key << "key" << JK::val << JK::Oxhex2(0x123123); + string key = "«key:value«key:123123«key:3«key:23«key:0x123123«key:0x3«key:0x23"; ASSERTop(jk.to_string(),==,key); - auto keys = parse_jit_keys("[a:11][b:22][a[3]:b::[x]][x=11][f=itof(0x0)]"); + auto keys = parse_jit_keys("«a:11«b:22«a[3]:b::[x]«x=11«f=itof(0x0)"); vector> k2 = {{"a","11"},{"b","22"},{"a[3]","b::[x]"},{"x","17"},{"f","0"}}; ASSERTop(keys,==,k2); diff --git a/python/jittor/src/test/test_op_relay.cc b/python/jittor/src/test/test_op_relay.cc index 92a70915..8814c1a0 
100644 --- a/python/jittor/src/test/test_op_relay.cc +++ b/python/jittor/src/test/test_op_relay.cc @@ -94,7 +94,7 @@ JIT_TEST(fused_op_relay_matmul) { auto allocator = get_allocator(); for (auto& v : fop.vars) if (v.type!=1) v.var->alloc(allocator); - auto entry = oc.compile("[OP:_fused_op_relay_matmul]", oc.src); + auto entry = oc.compile("«OP:_fused_op_relay_matmul", oc.src); for (uint i=0; inum; i++) a->ptr()[i] = b->ptr()[i] = 1; entry(&fop); diff --git a/python/jittor/test/test_merge_single_array_op.py b/python/jittor/test/test_merge_single_array_op.py index f07dab76..0c99cf1b 100644 --- a/python/jittor/test/test_merge_single_array_op.py +++ b/python/jittor/test/test_merge_single_array_op.py @@ -46,7 +46,7 @@ def test(shape, op1, op2): with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs: d__ = d.data logs = find_log_with_re(logs, - "Jit (fused )?op key (not )?found: \[opkey0:array\[T:float32") + "Jit (fused )?op key (not )?found: «opkey0:array«T:float32") assert(len(logs)==1), logs a_ = a.data diff --git a/python/jittor/test/test_resize_and_crop.py b/python/jittor/test/test_resize_and_crop.py index 3e7f699e..ac44d152 100644 --- a/python/jittor/test/test_resize_and_crop.py +++ b/python/jittor/test/test_resize_and_crop.py @@ -81,7 +81,7 @@ def test_case(box_num, out_size, time_limit): for i in range(1, len(rep)): t += float(rep[i][3]) / 1e9 name = rep[i][0] - if name.startswith('[') and (not '[graph:]' in name): + if name.startswith('«') and (not '«graph:«' in name): fused_op_num += 1 assert fused_op_num == 1, fused_op_num assert t <= time_limit, t diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py index d92d6ea3..192e43a0 100644 --- a/python/jittor/test/test_setitem.py +++ b/python/jittor/test/test_setitem.py @@ -428,6 +428,15 @@ class TestSetitem(unittest.TestCase): np.arange(4)[None,::-1]] np.testing.assert_allclose(nb, b.data) + def test_cuda_slice_migrate_bug(self): + a = jt.array([1,2,3,4,5]) + jt.sync_all() + if not jt.has_cuda: return + with jt.flag_scope(use_cuda=1): + b = a[0] + b.sync(True) + assert b.item() == 1 + if __name__ == "__main__":