mirror of https://github.com/Jittor/Jittor
polish loop var pickup method
This commit is contained in:
parent
10359f02fa
commit
3a438e21c8
|
@ -0,0 +1,21 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestFuser(unittest.TestCase):
|
||||
def test_wrong_fuse(self):
|
||||
a = jt.array([1])
|
||||
b = jt.random([10,])
|
||||
c = (a * b).sum() + (a + 1)
|
||||
print(c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -100,6 +100,15 @@ vector<pair<string,string>> Op::get_jit_define() {
|
|||
return parse_jit_keys(get_jit_key(jk));
|
||||
}
|
||||
|
||||
string Op::get_hash_name() {
|
||||
string hash_name;
|
||||
std::stringstream ss;
|
||||
do_prepare(jk);
|
||||
ss << std::hex << std::hash<string>()(jk.to_string());
|
||||
hash_name = ss.str();
|
||||
return hash_name;
|
||||
}
|
||||
|
||||
void Op::do_jit_prepare(JK& jk) {
|
||||
memcheck_all_exist();
|
||||
jk << name();
|
||||
|
|
1
src/op.h
1
src/op.h
|
@ -55,6 +55,7 @@ struct Op : Node {
|
|||
string name_ex() const;
|
||||
string get_jit_key(JK& jk);
|
||||
vector<pair<string,string>> get_jit_define();
|
||||
string get_hash_name();
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Op* var);
|
||||
|
|
|
@ -20,11 +20,7 @@ void LoopToFuncPass::run() {
|
|||
if (cc_type=="clang") choice=1;
|
||||
if (!choice) return;
|
||||
int func_num=0;
|
||||
string hash_name;
|
||||
std::stringstream ss;
|
||||
op->do_prepare(jk);
|
||||
ss << std::hex << std::hash<string>()(jk.to_string());
|
||||
hash_name = ss.str();
|
||||
string hash_name = op->get_hash_name();
|
||||
|
||||
ir->push_back("using namespace jittor;", &ir->before);
|
||||
if ((cc_type=="icc" || cc_type=="g++") && choice)
|
||||
|
|
|
@ -81,6 +81,10 @@ void LoopVarAnalyzePass::run() {
|
|||
// 1. reduce input
|
||||
// 2. element input
|
||||
// 3. broadcast output
|
||||
|
||||
// ugly fix multi different dim element input
|
||||
// (caused by force fused array op)
|
||||
int max_elm_dim = 0;
|
||||
for (uint i=0; i<vars.size(); i++) {
|
||||
// output
|
||||
if (vars[i].type == 2) {
|
||||
|
@ -91,8 +95,10 @@ void LoopVarAnalyzePass::run() {
|
|||
has_op = true;
|
||||
if (op->type() == OpType::reduce)
|
||||
has_reduce = true;
|
||||
if (op->type() == OpType::element)
|
||||
if (op->type() == OpType::element) {
|
||||
has_element = true;
|
||||
max_elm_dim = std::max(max_elm_dim, op->outputs().front()->shape.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
for (uint i=0; i<vars.size(); i++) {
|
||||
|
@ -109,6 +115,9 @@ void LoopVarAnalyzePass::run() {
|
|||
continue;
|
||||
if (has_element && !has_reduce && op->type() != OpType::element)
|
||||
continue;
|
||||
if (op->type() == OpType::element
|
||||
&& op->outputs().front()->shape.size() != max_elm_dim)
|
||||
continue;
|
||||
Var* loop_var;
|
||||
if (op->type() == OpType::broadcast || op->name_ex() == "index") {
|
||||
loop_var = op->output(0);
|
||||
|
|
Loading…
Reference in New Issue