polish reindex reduce fuse op

This commit is contained in:
Dun Liang 2022-01-12 16:43:39 +08:00
parent 70ccdb1f17
commit 6a372f5f4f
5 changed files with 27 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.36'
__version__ = '1.3.1.37'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -32,6 +32,11 @@ ReindexReduceOp::ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector
ns = op;
ASSERT(ns.is_binary() && ns!=ns_mean);
x = create_output(nullptr, y->dtype());
for (auto e : extras) {
if (e->shape != y->shape) {
e->flags.set(NodeFlags::_stop_fuse);
}
}
}
VarPtr ReindexReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {

View File

@ -37,7 +37,9 @@
namespace jittor {
DECLARE_FLAG(string, cc_type);
DEFINE_FLAG(string, exclude_pass, "", "Don't run certian pass.");
DEFINE_FLAG(string, exclude_pass, "", "Don't run certain pass.");
DEFINE_FLAG(string, log_op_hash, "", "Output compiler pass result of certain hash of op.");
PassManager::PassManager(OpCompiler* oc) : oc(oc), all(oc->get_src()) {
main_ir = nullptr;

View File

@ -14,6 +14,7 @@
namespace jittor {
DECLARE_FLAG(string, exclude_pass);
DECLARE_FLAG(string, log_op_hash);
struct PassManager {
OpCompiler* oc;
@ -46,6 +47,10 @@ void PassManager::run_pass() {
pass->run();
LOGvvvv << "Kernel IR after pass" << pass->name << ":\n"
<< main_ir->to_string(0, true);
if (log_op_hash.size() && log_op_hash == oc->op->get_hash_name())
LOGi << "hash mach:" << log_op_hash << "pass:" << pass->name
<< main_ir->to_string(0, true);
pass_map.emplace(pass->name, pass.get());
finished_passes.push_back(move(pass));
}

View File

@ -75,6 +75,19 @@ class TestReindexReduceOp(unittest.TestCase):
nmask = mask.data
_, (ndx,) = ngrad(lambda args: (pool_naive(args[0], size, op)*nmask).sum(), [nx], 1e-6)
assert np.allclose(jdx, ndx), (op, jdx[0,:,:,0], ndx[0,:,:,0])
def test_fuse_error(self):
a = jt.array([1,2,3,4])
b = jt.zeros((3,3))
jt.sync_all()
c = b.reindex_reduce("add", [4,4], ["@e0(i0)", "@e0(i1)"], extras=[-a])
c.sync()
a = jt.zeros((3,3))
b = jt.zeros((3,3))
jt.sync_all()
c = b.reindex_reduce("add", [4,4], ["@e0(i0,i1)", "@e0(i1,i0)"], extras=[-a])
c.sync()
def test_error(self):
jt.random([3]).reindex_reduce("add", [3], ["i0"])