polish ternary

This commit is contained in:
cxjyxx_me 2022-11-17 08:18:25 -05:00
parent ec2eef1fd9
commit fd5bd4aba9
2 changed files with 6 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.5.34'
__version__ = '1.3.5.35'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -33,7 +33,11 @@ TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) {
set_type(OpType::element);
flags.set(NodeFlags::_manual_set_vnbb);
cond->flags.set(NodeFlags::_needed_by_backward);
z = create_output(nullptr, dtype_infer(x->ns, y->ns));
if (x->dtype() == y->dtype()) {
z = create_output(nullptr, x->dtype());
} else {
z = create_output(nullptr, dtype_infer(x->ns, y->ns));
}
}
VarPtr TernaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {