mirror of https://github.com/Jittor/Jittor
polish corex compatible
This commit is contained in:
parent
20357caf42
commit
89998ebc60
|
@ -47,6 +47,10 @@ string process_acl(const string& src, const string& name, const map<string,strin
|
|||
"if (is_cuda_op && $1 != string::npos)",
|
||||
"if (is_cuda_op)");
|
||||
}
|
||||
if (name == "where_op.cc") {
|
||||
// default where kernel cannot handle 64 warp size, use cub_where instead
|
||||
new_src = token_replace_all(new_src, "if (cub_where$1) {", "if (cub_where) {");
|
||||
}
|
||||
return new_src;
|
||||
}
|
||||
}''', compiler.cc_flags + " " + " ".join(cc_files) + cc_flags)
|
||||
|
|
|
@ -138,7 +138,7 @@ struct float16 {
|
|||
this->x = (sign | (exponent << 10) | mantissa);
|
||||
}
|
||||
|
||||
inline operator float() {
|
||||
inline operator float() const {
|
||||
|
||||
unsigned sign = ((x >> 15) & 1);
|
||||
unsigned exponent = ((x >> 10) & 0x1f);
|
||||
|
|
|
@ -95,13 +95,13 @@ struct FP16OpType : OpByType {
|
|||
{"atan", "(($1) std::atan(($2)))"},
|
||||
{"tanh", "(($1) std::tanh(($2)))"},
|
||||
{"atanh", "(($1) std::atanh(($2)))"},
|
||||
{"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
|
||||
{"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min<float>($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
|
||||
{"erf", "(($1) std::erf(($2)))"},
|
||||
{"erfinv", "(jittor::_erfinv($2))"},
|
||||
{"cast", "(($1)($2))"},
|
||||
{"pow", "std::pow(($2),($4))"},
|
||||
{"maximum", "std::max($1($2), $1($4))"},
|
||||
{"minimum", "std::min($1($2), $1($4))"},
|
||||
{"maximum", "std::max<float>($1($2), $1($4))"},
|
||||
{"minimum", "std::min<float>($1($2), $1($4))"},
|
||||
{"mod", "$1(($2)-std::floor(($2)/($4))*($4))"},
|
||||
{"init_maximum", "-32768.0f"},
|
||||
{"init_minimum", "32768.0f"},
|
||||
|
|
|
@ -27,7 +27,7 @@ class TestUnaryOp(unittest.TestCase):
|
|||
assert jt.float64(1).data.dtype == "float64"
|
||||
assert (jt.abs(-1) == 1).data.all()
|
||||
assert (abs(-jt.float64(1)) == 1).data.all()
|
||||
a = np.array([-1,2,3,0])
|
||||
a = np.array([-1,2,3,0], dtype="int32")
|
||||
check("abs", a)
|
||||
check("negative", a)
|
||||
check("logical_not", a)
|
||||
|
|
Loading…
Reference in New Issue