polish corex compatible

This commit is contained in:
Dun Liang 2022-11-04 23:55:55 +08:00
parent 20357caf42
commit 89998ebc60
4 changed files with 9 additions and 5 deletions

View File

@ -47,6 +47,10 @@ string process_acl(const string& src, const string& name, const map<string,strin
"if (is_cuda_op && $1 != string::npos)",
"if (is_cuda_op)");
}
if (name == "where_op.cc") {
// default where kernel cannot handle 64 warp size, use cub_where instead
new_src = token_replace_all(new_src, "if (cub_where$1) {", "if (cub_where) {");
}
return new_src;
}
}''', compiler.cc_flags + " " + " ".join(cc_files) + cc_flags)

View File

@ -138,7 +138,7 @@ struct float16 {
this->x = (sign | (exponent << 10) | mantissa);
}
inline operator float() {
inline operator float() const {
unsigned sign = ((x >> 15) & 1);
unsigned exponent = ((x >> 10) & 0x1f);

View File

@ -95,13 +95,13 @@ struct FP16OpType : OpByType {
{"atan", "(($1) std::atan(($2)))"},
{"tanh", "(($1) std::tanh(($2)))"},
{"atanh", "(($1) std::atanh(($2)))"},
{"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
{"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min<float>($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
{"erf", "(($1) std::erf(($2)))"},
{"erfinv", "(jittor::_erfinv($2))"},
{"cast", "(($1)($2))"},
{"pow", "std::pow(($2),($4))"},
{"maximum", "std::max($1($2), $1($4))"},
{"minimum", "std::min($1($2), $1($4))"},
{"maximum", "std::max<float>($1($2), $1($4))"},
{"minimum", "std::min<float>($1($2), $1($4))"},
{"mod", "$1(($2)-std::floor(($2)/($4))*($4))"},
{"init_maximum", "-32768.0f"},
{"init_minimum", "32768.0f"},

View File

@ -27,7 +27,7 @@ class TestUnaryOp(unittest.TestCase):
assert jt.float64(1).data.dtype == "float64"
assert (jt.abs(-1) == 1).data.all()
assert (abs(-jt.float64(1)) == 1).data.all()
a = np.array([-1,2,3,0])
a = np.array([-1,2,3,0], dtype="int32")
check("abs", a)
check("negative", a)
check("logical_not", a)