mirror of https://github.com/Jittor/Jittor
support conv
This commit is contained in:
parent
042c3610a3
commit
335a2e5c1d
|
@ -201,10 +201,16 @@ void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string&
|
|||
src = new_src;
|
||||
|
||||
new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
|
||||
new_src = token_replace_all(new_src, "::max($1,$2);", "($1)>($2)?($1):($2);");
|
||||
// new_src = replace(new_src, "::max", "fmax");
|
||||
// TODO: support max
|
||||
// new_src = token_replace_all(new_src, "::max($1,$2);", "($1)>($2)?($1):($2);");
|
||||
auto ss = split(new_src, ";");
|
||||
for (auto &s : ss) {
|
||||
if (s.find("?") != string::npos) {
|
||||
s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
|
||||
}
|
||||
}
|
||||
new_src = join(ss, ";");
|
||||
src = new_src;
|
||||
// auto tokens = token_split(new_src);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -125,6 +125,11 @@ static void parse_reg(const string& src,
|
|||
}
|
||||
|
||||
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst, bool match_whitespace) {
|
||||
if (!(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
|
||||
src.at(src.size()-2) != '$')) {
|
||||
LOGe << "illegal src:" << src;
|
||||
LOGf << "illegal src:" << src;
|
||||
}
|
||||
ASSERT(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
|
||||
src.at(src.size()-2) != '$') << "illegal src:" << src;
|
||||
vector<string> patterns;
|
||||
|
|
|
@ -59,12 +59,16 @@ class TestACL(unittest.TestCase):
|
|||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_conv(self):
|
||||
# x = jt.rand(10, 3, 50, 50)
|
||||
# w = jt.rand(4,3,3,3)
|
||||
x = jt.rand(2, 2, 1, 1)
|
||||
w = jt.rand(2,2,1,1)
|
||||
x = jt.rand(10, 3, 50, 50)
|
||||
w = jt.rand(4,3,3,3)
|
||||
# x = jt.rand(2, 2, 1, 1)
|
||||
# w = jt.rand(2,2,1,1)
|
||||
y = jt.nn.conv2d(x, w)
|
||||
y.sync(True)
|
||||
y1 = y.data
|
||||
with jt.flag_scope(use_acl=0):
|
||||
y2 = jt.nn.conv2d(x, w).data
|
||||
np.testing.assert_allclose(y1, y2)
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_matmul(self):
|
||||
|
|
Loading…
Reference in New Issue