Merge branch 'master' of https://github.com/Jittor/jittor into gword

This commit is contained in:
guowei yang 2020-04-16 13:25:58 +08:00
commit f4130cab28
6 changed files with 92 additions and 109 deletions

View File

@@ -361,7 +361,8 @@ def setup_mpi():
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
if get_version(mpicc_path).startswith("(1."):
mpi_version = get_version(mpicc_path)
if mpi_version.startswith("(1.") or mpi_version.startswith("(2."):
# mpi version 1.x need to link like this
manual_link(mpi_flags)
# mpi(4.x) cannot use deepbind, it need to

View File

@@ -46,7 +46,7 @@ def check(jt_model, torch_model, shape, near_data):
@unittest.skipIf(skip_this_test, "No Torch found")
class TestArgPoolOp(unittest.TestCase):
@unittest.skipIf(jt.compiler.has_cuda, "No cuda found")
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)
def test_cuda(self):
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
@@ -59,15 +59,18 @@ class TestArgPoolOp(unittest.TestCase):
check(jt_model, torch_model, [1,1,300,300], True)
def test_cpu_(self):
# Smoke test: run max-pooling over a random tensor on CPU.
# No numeric check — this only exercises jt.nn.pool without crashing.
x = jt.random([32, 128, 157, 300])
# x = jt.random([32, 128, 157, 300])
# NOTE(review): batch size reduced from 32 to 4, presumably to cut
# test memory/time — confirm with the commit author.
x = jt.random([4, 128, 157, 300])
x = jt.nn.pool(x, 2, "maximum", 0, 2)
def test_cpu(self):
# Compare a stack of Jittor Pool layers against the equivalent
# torch MaxPool2d stack on CPU, via the shared check() helper.
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1))
shape = [64, 64, 300, 300]
# shape = [64, 64, 300, 300]
# NOTE(review): batch 64 -> 4, presumably to cut test memory/time.
shape = [4, 64, 300, 300]
check(jt_model, torch_model, shape, False)
shape = [32, 128, 157, 300]
# shape = [32, 128, 157, 300]
# NOTE(review): batch 32 -> 4, same reduction as above.
shape = [4, 128, 157, 300]
check(jt_model, torch_model, shape, False)
# Repeat with near-identical data (near_data=True) to stress
# tie-breaking among equal pooling candidates.
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)

View File

@@ -129,5 +129,13 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
assert a.min().data == a.data.min(), (a.min(), a.data.min())
assert a.max().data == a.data.max(), (a.max(), a.data.max())
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cuda_pow_grad_nan(self):
# Regression test: the gradient of a**2 must be NaN-free even for
# negative bases (d/da a^2 = 2*a is finite for every input here).
a = jt.float32([1,-1, -1000.1])
da = jt.grad(a**2, a)
assert np.isnan(da.data).sum()==0, da.data
if __name__ == "__main__":
unittest.main()

View File

@@ -133,5 +133,34 @@ class TestOpCompiler(unittest.TestCase):
expect_error(lambda: jit_precompile(vars, "@if(1)"))
expect_error(lambda: jit_precompile(vars, "#define OP1(a,b) a+b\n@expand_macro(OP1,1)"))
def test_strcmp(self):
# @strcmp in JIT-precompiled source should behave like C strcmp,
# expanding to the literal string "-1", "0" or "1".
vars = {"Tx":"float"}
check = lambda expr, result: \
self.assertEqual(jit_precompile(vars, expr), result)
check("@strcmp(aaa,aaa)", "0")
check("@strcmp(aaa,bbb)", "-1")
check("@strcmp(ccc,bbb)", "1")
# Same comparisons wrapped in an @{...} evaluated-expression block.
check("@{@strcmp(aaa,aaa)}", "0")
check("@{@strcmp(aaa,bbb)}", "-1")
check("@{@strcmp(ccc,bbb)}", "1")
# Intended use case: map the JIT type variable @Tx to an NCCL dtype
# name through @define / @if / @strcmp chains.
code = \
"""@define(T_NCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
@if(@strcmp(@Tx,int64)==0, ncclInt64)
)
ncclBcast(..., @T_NCCL, ...)
"""
# Each Tx value must select exactly the matching branch's token.
assert "ncclFloat" in jit_precompile({"Tx":"float"}, code)
assert "ncclFloat" in jit_precompile({"Tx":"float32"}, code)
assert "ncclFloat64" in jit_precompile({"Tx":"float64"}, code)
assert "ncclInt" in jit_precompile({"Tx":"int"}, code)
assert "ncclInt" in jit_precompile({"Tx":"int32"}, code)
assert "ncclInt64" in jit_precompile({"Tx":"int64"}, code)
if __name__ == "__main__":
unittest.main()

View File

@@ -15,6 +15,7 @@
#include "ops/op_register.h"
#include "ops/array_op.h"
#include "lock.h"
#include "opt/expr.h"
namespace jittor {
@@ -104,48 +104,6 @@ int OpCompiler::total_member_count() {
return member_count;
}
// Operator tables for the hand-rolled integer-expression evaluator
// (removed by this diff in favor of opt/expr). Each m(op, p) entry
// pairs an operator with its C-style precedence; lower p binds tighter.
#define FOR_ALL_UOPS(m) \
m(!,3) m(~,3)
#define FOR_ALL_BOPS(m) \
m(*,5) m(/,5) m(%,5) \
m(+,6) m(-,6) \
m(<<,7) m(>>,7) \
m(<,9) m(<=,9) m(>,9) m(>=,9) \
m(!=,10) m(==,10) \
m(&,11) \
m(^,12) \
m(|,13) \
m(&&,14) \
m(||,15)
#define FOR_ALL_OPS(m) FOR_ALL_UOPS(m) FOR_ALL_BOPS(m)
// True iff op is one of the unary operators (! or ~).
inline bool is_unary_op(const string& op) {
#define _u(o, _) if (op == #o) return true;
FOR_ALL_UOPS(_u);
return false;
}
// Precedence of op; tokens not in the tables (e.g. "(") get 20, the
// loosest binding, so they never force other operators to be applied.
inline int precedence(const string& op) {
#define _prior(o, p) if (op == #o) return p;
FOR_ALL_OPS(_prior);
return 20;
}
// Should the stacked operator op1 be applied before pushing op2?
// Identical unary ops are treated as right-associative (early false).
inline bool check_precedence(const string& op1, const string& op2) {
if (op1 == op2 && is_unary_op(op1)) return false;
return precedence(op1) <= precedence(op2);
}
// Apply op: binary ops compute (a op b); unary ops apply to b only
// (the caller passes a duplicate of b as a in that case, and a is
// ignored here). Unknown operators trigger ASSERT.
inline int64_t calc_op(int64_t a, int64_t b, const string& op) {
#define _calc_b(o, _) if (op == #o) return a o b;
FOR_ALL_BOPS(_calc_b);
#define _calc_u(o, _) if (op == #o) return o b;
FOR_ALL_UOPS(_calc_u);
ASSERT(0) << "Unrecognized op " << op;
return 0;
}
int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
if (expr.find("@") != string::npos) {
string new_expr;
@@ -175,6 +134,22 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
ASSERT(isvar(expr[j]));
size_t k=j+1;
while (k<expr.size() && isvar(expr[k])) k++;
if (k<expr.size() && expr[k]=='(') {
// syntax @xx(...)
// ij k l
size_t l=k+1;
int presum = 1;
while (l<expr.size() && presum) {
if (expr[l] == ')')
presum--;
else if (expr[l] == '(')
presum++;
l++;
}
new_expr += precompile(vars, expr.substr(i, l-i));
i = l-1;
continue;
}
string var = expr.substr(j, k-j);
auto iter = vars.find(var);
ASSERT(iter!=vars.end()) << "Jit var " << var << " not found." << vars;
@@ -185,68 +160,18 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
}
return eval(new_expr, vars);
}
vector<int64> values = {0};
vector<string> ops;
auto pop_values_and_calc_op = [&]() {
CHECK(ops.size());
auto op = ops.back();
ops.pop_back();
CHECK(values.size());
auto val2 = values.back();
values.pop_back();
auto val1 = val2;
if (!is_unary_op(op)) {
CHECK(values.size());
val1 = values.back();
values.pop_back();
auto e = expr::make(expr);
e->dfs([&](expr::Expr* s) {
if (s->is_sym()) {
auto iter = vars.find(s->str);
ASSERT(iter!=vars.end()) << "Jit var " << s->str << " not found.";
auto e = expr::make(iter->second);
s->swap(e.get());
}
values.push_back(calc_op(val1, val2, op));
};
for (size_t i=0; i<expr.size(); i++) {
if (expr[i] == ' ')
continue;
if (expr[i] == '(')
ops.push_back(string()+expr[i]);
else if (isdigit(expr[i])) {
int64_t val = 0;
while (i<expr.length() && isdigit(expr[i])) {
val = val*10 + (expr[i]-'0');
i++;
}
i--;
values.push_back(val);
} else if (isvar(expr[i])) {
auto j=i+1;
while (j<expr.size() && isvar(expr[j])) j++;
auto var_name = expr.substr(i,j-i);
auto iter = vars.find(var_name);
ASSERT(iter!=vars.end()) << "Jit var " << var_name << " not found.";
try {
values.push_back(std::stoll(iter->second));
} catch (...) {
ASSERT(0) << "'" << iter->second << "' is not integer, expr " << expr;
}
i = j-1;
} else if (expr[i] == ')') {
while (ops.size() && ops.back() != "(")
pop_values_and_calc_op();
ops.pop_back();
} else {
auto j=i+1;
while (j<expr.size() && expr[j] != ' ' &&
expr[j] != '!' && expr[j] != '~' &&
!isdigit(expr[j]) && !isvar(expr[j]) &&
expr[j] != '(' && expr[j] != ')') j++;
auto op = expr.substr(i, j-i);
while (ops.size() && check_precedence(ops.back(), op))
pop_values_and_calc_op();
ops.push_back(op);
i = j-1;
}
}
while (ops.size())
pop_values_and_calc_op();
return values.back();
});
e = e->eval();
ASSERT(e->is(expr::_int));
return e->as_int();
}
void load_macros(const string& src, unordered_map<string,string>& macros) {
@@ -588,6 +513,19 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
i = l-1;
continue;
} else
if (expr == "strcmp") {
// syntax: @strcmp(s1,s2)
// ij k l
ASSERT(args.size()==2u)
<< "Jit error: strcmp wrong arguments.";
auto s1 = precompile(defs, args[0], macros);
auto s2 = precompile(defs, args[1], macros);
if (s1<s2) new_src += "-1"; else
if (s1==s2) new_src += "0"; else
new_src += "1";
i = l-1;
continue;
} else
if (args.size()) {
// syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
int nid=(int)expr.size();

View File

@@ -13,6 +13,8 @@
namespace jittor {
#ifndef JIT
static auto make_array = get_op_info("array")
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
static auto make_broadcast_to = get_op_info("broadcast_to")
.get_constructor<VarPtr, Var*, Var*, NanoVector>();
static auto make_binary = get_op_info("binary")
@@ -122,7 +124,9 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index == 0) {
// dout * y * x^(y-1)
auto d = make_binary(dout, y, ns_multiply);
auto ones = make_number(1, dout);
// auto ones = make_number(1, dout);
int number = 1;
auto ones = make_array(&number, 1, ns_int32);
auto y_1 = make_binary(y, ones, ns_subtract);
auto x_y_1 = make_binary(x, y_1, ns_pow);
return make_binary(d, x_y_1, ns_multiply);