mirror of https://github.com/Jittor/Jittor
add strcmp for jit macro
This commit is contained in:
parent
d2ae3c05ff
commit
273be0db93
|
@ -46,7 +46,7 @@ def check(jt_model, torch_model, shape, near_data):
|
|||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestArgPoolOp(unittest.TestCase):
|
||||
@unittest.skipIf(jt.compiler.has_cuda, "No cuda found")
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda(self):
|
||||
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
|
||||
|
@ -59,15 +59,18 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
def test_cpu_(self):
|
||||
x = jt.random([32, 128, 157, 300])
|
||||
# x = jt.random([32, 128, 157, 300])
|
||||
x = jt.random([4, 128, 157, 300])
|
||||
x = jt.nn.pool(x, 2, "maximum", 0, 2)
|
||||
|
||||
def test_cpu(self):
|
||||
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
|
||||
torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1))
|
||||
shape = [64, 64, 300, 300]
|
||||
# shape = [64, 64, 300, 300]
|
||||
shape = [4, 64, 300, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
shape = [32, 128, 157, 300]
|
||||
# shape = [32, 128, 157, 300]
|
||||
shape = [4, 128, 157, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
|
|
@ -133,5 +133,34 @@ class TestOpCompiler(unittest.TestCase):
|
|||
expect_error(lambda: jit_precompile(vars, "@if(1)"))
|
||||
expect_error(lambda: jit_precompile(vars, "#define OP1(a,b) a+b\n@expand_macro(OP1,1)"))
|
||||
|
||||
def test_strcmp(self):
|
||||
vars = {"Tx":"float"}
|
||||
check = lambda expr, result: \
|
||||
self.assertEqual(jit_precompile(vars, expr), result)
|
||||
check("@strcmp(aaa,aaa)", "0")
|
||||
check("@strcmp(aaa,bbb)", "-1")
|
||||
check("@strcmp(ccc,bbb)", "1")
|
||||
check("@{@strcmp(aaa,aaa)}", "0")
|
||||
check("@{@strcmp(aaa,bbb)}", "-1")
|
||||
check("@{@strcmp(ccc,bbb)}", "1")
|
||||
|
||||
code = \
|
||||
"""@define(T_NCCL,
|
||||
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
|
||||
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
|
||||
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
|
||||
@if(@strcmp(@Tx,int64)==0, ncclInt64)
|
||||
)
|
||||
ncclBcast(..., @T_NCCL, ...)
|
||||
"""
|
||||
assert "ncclFloat" in jit_precompile({"Tx":"float"}, code)
|
||||
assert "ncclFloat" in jit_precompile({"Tx":"float32"}, code)
|
||||
assert "ncclFloat64" in jit_precompile({"Tx":"float64"}, code)
|
||||
assert "ncclInt" in jit_precompile({"Tx":"int"}, code)
|
||||
assert "ncclInt" in jit_precompile({"Tx":"int32"}, code)
|
||||
assert "ncclInt64" in jit_precompile({"Tx":"int64"}, code)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -15,6 +15,7 @@
|
|||
#include "ops/op_register.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "lock.h"
|
||||
#include "opt/expr.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -104,48 +105,6 @@ int OpCompiler::total_member_count() {
|
|||
return member_count;
|
||||
}
|
||||
|
||||
#define FOR_ALL_UOPS(m) \
|
||||
m(!,3) m(~,3)
|
||||
#define FOR_ALL_BOPS(m) \
|
||||
m(*,5) m(/,5) m(%,5) \
|
||||
m(+,6) m(-,6) \
|
||||
m(<<,7) m(>>,7) \
|
||||
m(<,9) m(<=,9) m(>,9) m(>=,9) \
|
||||
m(!=,10) m(==,10) \
|
||||
m(&,11) \
|
||||
m(^,12) \
|
||||
m(|,13) \
|
||||
m(&&,14) \
|
||||
m(||,15)
|
||||
|
||||
#define FOR_ALL_OPS(m) FOR_ALL_UOPS(m) FOR_ALL_BOPS(m)
|
||||
|
||||
inline bool is_unary_op(const string& op) {
|
||||
#define _u(o, _) if (op == #o) return true;
|
||||
FOR_ALL_UOPS(_u);
|
||||
return false;
|
||||
}
|
||||
|
||||
inline int precedence(const string& op) {
|
||||
#define _prior(o, p) if (op == #o) return p;
|
||||
FOR_ALL_OPS(_prior);
|
||||
return 20;
|
||||
}
|
||||
|
||||
inline bool check_precedence(const string& op1, const string& op2) {
|
||||
if (op1 == op2 && is_unary_op(op1)) return false;
|
||||
return precedence(op1) <= precedence(op2);
|
||||
}
|
||||
|
||||
inline int64_t calc_op(int64_t a, int64_t b, const string& op) {
|
||||
#define _calc_b(o, _) if (op == #o) return a o b;
|
||||
FOR_ALL_BOPS(_calc_b);
|
||||
#define _calc_u(o, _) if (op == #o) return o b;
|
||||
FOR_ALL_UOPS(_calc_u);
|
||||
ASSERT(0) << "Unrecognized op " << op;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
|
||||
if (expr.find("@") != string::npos) {
|
||||
string new_expr;
|
||||
|
@ -175,6 +134,22 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
|
|||
ASSERT(isvar(expr[j]));
|
||||
size_t k=j+1;
|
||||
while (k<expr.size() && isvar(expr[k])) k++;
|
||||
if (k<expr.size() && expr[k]=='(') {
|
||||
// syntax @xx(...)
|
||||
// ij k l
|
||||
size_t l=k+1;
|
||||
int presum = 1;
|
||||
while (l<expr.size() && presum) {
|
||||
if (expr[l] == ')')
|
||||
presum--;
|
||||
else if (expr[l] == '(')
|
||||
presum++;
|
||||
l++;
|
||||
}
|
||||
new_expr += precompile(vars, expr.substr(i, l-i));
|
||||
i = l-1;
|
||||
continue;
|
||||
}
|
||||
string var = expr.substr(j, k-j);
|
||||
auto iter = vars.find(var);
|
||||
ASSERT(iter!=vars.end()) << "Jit var " << var << " not found." << vars;
|
||||
|
@ -185,68 +160,18 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
|
|||
}
|
||||
return eval(new_expr, vars);
|
||||
}
|
||||
vector<int64> values = {0};
|
||||
vector<string> ops;
|
||||
auto pop_values_and_calc_op = [&]() {
|
||||
CHECK(ops.size());
|
||||
auto op = ops.back();
|
||||
ops.pop_back();
|
||||
CHECK(values.size());
|
||||
auto val2 = values.back();
|
||||
values.pop_back();
|
||||
auto val1 = val2;
|
||||
if (!is_unary_op(op)) {
|
||||
CHECK(values.size());
|
||||
val1 = values.back();
|
||||
values.pop_back();
|
||||
auto e = expr::make(expr);
|
||||
e->dfs([&](expr::Expr* s) {
|
||||
if (s->is_sym()) {
|
||||
auto iter = vars.find(s->str);
|
||||
ASSERT(iter!=vars.end()) << "Jit var " << s->str << " not found.";
|
||||
auto e = expr::make(iter->second);
|
||||
s->swap(e.get());
|
||||
}
|
||||
values.push_back(calc_op(val1, val2, op));
|
||||
};
|
||||
for (size_t i=0; i<expr.size(); i++) {
|
||||
if (expr[i] == ' ')
|
||||
continue;
|
||||
if (expr[i] == '(')
|
||||
ops.push_back(string()+expr[i]);
|
||||
else if (isdigit(expr[i])) {
|
||||
int64_t val = 0;
|
||||
while (i<expr.length() && isdigit(expr[i])) {
|
||||
val = val*10 + (expr[i]-'0');
|
||||
i++;
|
||||
}
|
||||
i--;
|
||||
values.push_back(val);
|
||||
} else if (isvar(expr[i])) {
|
||||
auto j=i+1;
|
||||
while (j<expr.size() && isvar(expr[j])) j++;
|
||||
auto var_name = expr.substr(i,j-i);
|
||||
auto iter = vars.find(var_name);
|
||||
ASSERT(iter!=vars.end()) << "Jit var " << var_name << " not found.";
|
||||
try {
|
||||
values.push_back(std::stoll(iter->second));
|
||||
} catch (...) {
|
||||
ASSERT(0) << "'" << iter->second << "' is not integer, expr " << expr;
|
||||
}
|
||||
i = j-1;
|
||||
} else if (expr[i] == ')') {
|
||||
while (ops.size() && ops.back() != "(")
|
||||
pop_values_and_calc_op();
|
||||
ops.pop_back();
|
||||
} else {
|
||||
auto j=i+1;
|
||||
while (j<expr.size() && expr[j] != ' ' &&
|
||||
expr[j] != '!' && expr[j] != '~' &&
|
||||
!isdigit(expr[j]) && !isvar(expr[j]) &&
|
||||
expr[j] != '(' && expr[j] != ')') j++;
|
||||
auto op = expr.substr(i, j-i);
|
||||
while (ops.size() && check_precedence(ops.back(), op))
|
||||
pop_values_and_calc_op();
|
||||
ops.push_back(op);
|
||||
i = j-1;
|
||||
}
|
||||
}
|
||||
while (ops.size())
|
||||
pop_values_and_calc_op();
|
||||
return values.back();
|
||||
});
|
||||
e = e->eval();
|
||||
ASSERT(e->is(expr::_int));
|
||||
return e->as_int();
|
||||
}
|
||||
|
||||
void load_macros(const string& src, unordered_map<string,string>& macros) {
|
||||
|
@ -588,6 +513,19 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (expr == "strcmp") {
|
||||
// syntax: @strcmp(s1,s2)
|
||||
// ij k l
|
||||
ASSERT(args.size()==2u)
|
||||
<< "Jit error: strcmp wrong arguments.";
|
||||
auto s1 = precompile(defs, args[0], macros);
|
||||
auto s2 = precompile(defs, args[1], macros);
|
||||
if (s1<s2) new_src += "-1"; else
|
||||
if (s1==s2) new_src += "0"; else
|
||||
new_src += "1";
|
||||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (args.size()) {
|
||||
// syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
|
||||
int nid=(int)expr.size();
|
||||
|
|
Loading…
Reference in New Issue