polish loop var pickup method

This commit is contained in:
Dun Liang 2021-01-17 22:15:53 +08:00
parent 10359f02fa
commit 3a438e21c8
5 changed files with 42 additions and 6 deletions

View File

@ -0,0 +1,21 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
class TestFuser(unittest.TestCase):
def test_wrong_fuse(self):
a = jt.array([1])
b = jt.random([10,])
c = (a * b).sum() + (a + 1)
print(c)
if __name__ == "__main__":
unittest.main()

View File

@ -100,6 +100,15 @@ vector<pair<string,string>> Op::get_jit_define() {
return parse_jit_keys(get_jit_key(jk));
}
string Op::get_hash_name() {
string hash_name;
std::stringstream ss;
do_prepare(jk);
ss << std::hex << std::hash<string>()(jk.to_string());
hash_name = ss.str();
return hash_name;
}
void Op::do_jit_prepare(JK& jk) {
memcheck_all_exist();
jk << name();

View File

@ -55,6 +55,7 @@ struct Op : Node {
string name_ex() const;
string get_jit_key(JK& jk);
vector<pair<string,string>> get_jit_define();
string get_hash_name();
};
std::ostream& operator<<(std::ostream& os, const Op* var);

View File

@ -20,11 +20,7 @@ void LoopToFuncPass::run() {
if (cc_type=="clang") choice=1;
if (!choice) return;
int func_num=0;
string hash_name;
std::stringstream ss;
op->do_prepare(jk);
ss << std::hex << std::hash<string>()(jk.to_string());
hash_name = ss.str();
string hash_name = op->get_hash_name();
ir->push_back("using namespace jittor;", &ir->before);
if ((cc_type=="icc" || cc_type=="g++") && choice)

View File

@ -81,6 +81,10 @@ void LoopVarAnalyzePass::run() {
// 1. reduce input
// 2. element input
// 3. broadcast output
// ugly fix multi different dim element input
// (caused by force fused array op)
int max_elm_dim = 0;
for (uint i=0; i<vars.size(); i++) {
// output
if (vars[i].type == 2) {
@ -91,8 +95,10 @@ void LoopVarAnalyzePass::run() {
has_op = true;
if (op->type() == OpType::reduce)
has_reduce = true;
if (op->type() == OpType::element)
if (op->type() == OpType::element) {
has_element = true;
max_elm_dim = std::max(max_elm_dim, op->outputs().front()->shape.size());
}
}
}
for (uint i=0; i<vars.size(); i++) {
@ -109,6 +115,9 @@ void LoopVarAnalyzePass::run() {
continue;
if (has_element && !has_reduce && op->type() != OpType::element)
continue;
if (op->type() == OpType::element
&& op->outputs().front()->shape.size() != max_elm_dim)
continue;
Var* loop_var;
if (op->type() == OpType::broadcast || op->name_ex() == "index") {
loop_var = op->output(0);