mirror of https://github.com/Jittor/Jittor
polish reindex reduce fuse op
This commit is contained in:
parent
70ccdb1f17
commit
6a372f5f4f
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.36'
|
||||
__version__ = '1.3.1.37'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -32,6 +32,11 @@ ReindexReduceOp::ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector
|
|||
ns = op;
|
||||
ASSERT(ns.is_binary() && ns!=ns_mean);
|
||||
x = create_output(nullptr, y->dtype());
|
||||
for (auto e : extras) {
|
||||
if (e->shape != y->shape) {
|
||||
e->flags.set(NodeFlags::_stop_fuse);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VarPtr ReindexReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
|
|
|
@ -37,7 +37,9 @@
|
|||
namespace jittor {
|
||||
|
||||
DECLARE_FLAG(string, cc_type);
|
||||
DEFINE_FLAG(string, exclude_pass, "", "Don't run certian pass.");
|
||||
DEFINE_FLAG(string, exclude_pass, "", "Don't run certain pass.");
|
||||
DEFINE_FLAG(string, log_op_hash, "", "Output compiler pass result of certain hash of op.");
|
||||
|
||||
|
||||
PassManager::PassManager(OpCompiler* oc) : oc(oc), all(oc->get_src()) {
|
||||
main_ir = nullptr;
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
namespace jittor {
|
||||
|
||||
DECLARE_FLAG(string, exclude_pass);
|
||||
DECLARE_FLAG(string, log_op_hash);
|
||||
|
||||
struct PassManager {
|
||||
OpCompiler* oc;
|
||||
|
@ -46,6 +47,10 @@ void PassManager::run_pass() {
|
|||
pass->run();
|
||||
LOGvvvv << "Kernel IR after pass" << pass->name << ":\n"
|
||||
<< main_ir->to_string(0, true);
|
||||
|
||||
if (log_op_hash.size() && log_op_hash == oc->op->get_hash_name())
|
||||
LOGi << "hash mach:" << log_op_hash << "pass:" << pass->name
|
||||
<< main_ir->to_string(0, true);
|
||||
pass_map.emplace(pass->name, pass.get());
|
||||
finished_passes.push_back(move(pass));
|
||||
}
|
||||
|
|
|
@ -75,6 +75,19 @@ class TestReindexReduceOp(unittest.TestCase):
|
|||
nmask = mask.data
|
||||
_, (ndx,) = ngrad(lambda args: (pool_naive(args[0], size, op)*nmask).sum(), [nx], 1e-6)
|
||||
assert np.allclose(jdx, ndx), (op, jdx[0,:,:,0], ndx[0,:,:,0])
|
||||
|
||||
def test_fuse_error(self):
|
||||
a = jt.array([1,2,3,4])
|
||||
b = jt.zeros((3,3))
|
||||
jt.sync_all()
|
||||
c = b.reindex_reduce("add", [4,4], ["@e0(i0)", "@e0(i1)"], extras=[-a])
|
||||
c.sync()
|
||||
|
||||
a = jt.zeros((3,3))
|
||||
b = jt.zeros((3,3))
|
||||
jt.sync_all()
|
||||
c = b.reindex_reduce("add", [4,4], ["@e0(i0,i1)", "@e0(i1,i0)"], extras=[-a])
|
||||
c.sync()
|
||||
|
||||
def test_error(self):
|
||||
jt.random([3]).reindex_reduce("add", [3], ["i0"])
|
||||
|
|
Loading…
Reference in New Issue