mirror of https://github.com/Jittor/Jittor
polish ternary
This commit is contained in:
parent
ec2eef1fd9
commit
fd5bd4aba9
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.34'
|
||||
__version__ = '1.3.5.35'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -33,7 +33,11 @@ TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) {
|
|||
set_type(OpType::element);
|
||||
flags.set(NodeFlags::_manual_set_vnbb);
|
||||
cond->flags.set(NodeFlags::_needed_by_backward);
|
||||
z = create_output(nullptr, dtype_infer(x->ns, y->ns));
|
||||
if (x->dtype() == y->dtype()) {
|
||||
z = create_output(nullptr, x->dtype());
|
||||
} else {
|
||||
z = create_output(nullptr, dtype_infer(x->ns, y->ns));
|
||||
}
|
||||
}
|
||||
|
||||
VarPtr TernaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
|
|
Loading…
Reference in New Issue