support conv

This commit is contained in:
Dun Liang 2022-10-13 01:20:14 +08:00
parent 042c3610a3
commit 335a2e5c1d
3 changed files with 22 additions and 7 deletions

View File

@ -201,10 +201,16 @@ void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string&
src = new_src;
new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
new_src = token_replace_all(new_src, "::max($1,$2);", "($1)>($2)?($1):($2);");
// new_src = replace(new_src, "::max", "fmax");
// TODO: support max
// new_src = token_replace_all(new_src, "::max($1,$2);", "($1)>($2)?($1):($2);");
auto ss = split(new_src, ";");
for (auto &s : ss) {
if (s.find("?") != string::npos) {
s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
}
}
new_src = join(ss, ";");
src = new_src;
// auto tokens = token_split(new_src);
}
}

View File

@ -125,6 +125,11 @@ static void parse_reg(const string& src,
}
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst, bool match_whitespace) {
if (!(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
src.at(src.size()-2) != '$')) {
LOGe << "illegal src:" << src;
LOGf << "illegal src:" << src;
}
ASSERT(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
src.at(src.size()-2) != '$') << "illegal src:" << src;
vector<string> patterns;

View File

@ -59,12 +59,16 @@ class TestACL(unittest.TestCase):
@jt.flag_scope(use_acl=1)
def test_conv(self):
# x = jt.rand(10, 3, 50, 50)
# w = jt.rand(4,3,3,3)
x = jt.rand(2, 2, 1, 1)
w = jt.rand(2,2,1,1)
x = jt.rand(10, 3, 50, 50)
w = jt.rand(4,3,3,3)
# x = jt.rand(2, 2, 1, 1)
# w = jt.rand(2,2,1,1)
y = jt.nn.conv2d(x, w)
y.sync(True)
y1 = y.data
with jt.flag_scope(use_acl=0):
y2 = jt.nn.conv2d(x, w).data
np.testing.assert_allclose(y1, y2)
@jt.flag_scope(use_acl=1)
def test_matmul(self):