diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 4d23ce9ac021..9a3a2e432725 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -984,75 +984,76 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Result = visitSelectInstWithICmp(SI, ICI)) return Result; - if (Instruction *TI = dyn_cast(TrueVal)) - if (Instruction *FI = dyn_cast(FalseVal)) - if (TI->hasOneUse() && FI->hasOneUse()) { - Instruction *AddOp = nullptr, *SubOp = nullptr; + auto *TI = dyn_cast(TrueVal); + auto *FI = dyn_cast(FalseVal); + if (TI && FI && TI->hasOneUse() && FI->hasOneUse()) { + Instruction *AddOp = nullptr, *SubOp = nullptr; - // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) - if (TI->getOpcode() == FI->getOpcode()) - if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) - return IV; + // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) + if (TI->getOpcode() == FI->getOpcode()) + if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) + return IV; - // Turn select C, (X+Y), (X-Y) --> (X+(select C, Y, (-Y))). This is - // even legal for FP. - if ((TI->getOpcode() == Instruction::Sub && - FI->getOpcode() == Instruction::Add) || - (TI->getOpcode() == Instruction::FSub && - FI->getOpcode() == Instruction::FAdd)) { - AddOp = FI; SubOp = TI; - } else if ((FI->getOpcode() == Instruction::Sub && - TI->getOpcode() == Instruction::Add) || - (FI->getOpcode() == Instruction::FSub && - TI->getOpcode() == Instruction::FAdd)) { - AddOp = TI; SubOp = FI; - } + // Turn select C, (X+Y), (X-Y) --> (X+(select C, Y, (-Y))). This is + // even legal for FP. + if ((TI->getOpcode() == Instruction::Sub && + FI->getOpcode() == Instruction::Add) || + (TI->getOpcode() == Instruction::FSub && + FI->getOpcode() == Instruction::FAdd)) { + AddOp = FI; + SubOp = TI; + } else if ((FI->getOpcode() == Instruction::Sub && + TI->getOpcode() == Instruction::Add) || + (FI->getOpcode() == Instruction::FSub && + TI->getOpcode() == Instruction::FAdd)) { + AddOp = TI; + SubOp = FI; + } - if (AddOp) { - Value *OtherAddOp = nullptr; - if (SubOp->getOperand(0) == AddOp->getOperand(0)) { - OtherAddOp = AddOp->getOperand(1); - } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { - OtherAddOp = AddOp->getOperand(0); - } - - if (OtherAddOp) { - // So at this point we know we have (Y -> OtherAddOp): - // select C, (add X, Y), (sub X, Z) - Value *NegVal; // Compute -Z - if (SI.getType()->isFPOrFPVectorTy()) { - NegVal = Builder->CreateFNeg(SubOp->getOperand(1)); - if (Instruction *NegInst = dyn_cast(NegVal)) { - FastMathFlags Flags = AddOp->getFastMathFlags(); - Flags &= SubOp->getFastMathFlags(); - NegInst->setFastMathFlags(Flags); - } - } else { - NegVal = Builder->CreateNeg(SubOp->getOperand(1)); - } - - Value *NewTrueOp = OtherAddOp; - Value *NewFalseOp = NegVal; - if (AddOp != TI) - std::swap(NewTrueOp, NewFalseOp); - Value *NewSel = - Builder->CreateSelect(CondVal, NewTrueOp, - NewFalseOp, SI.getName() + ".p"); - - if (SI.getType()->isFPOrFPVectorTy()) { - Instruction *RI = - BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); - - FastMathFlags Flags = AddOp->getFastMathFlags(); - Flags &= SubOp->getFastMathFlags(); - RI->setFastMathFlags(Flags); - return RI; - } else - return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); - } - } + if (AddOp) { + Value *OtherAddOp = nullptr; + if (SubOp->getOperand(0) == AddOp->getOperand(0)) { + OtherAddOp = AddOp->getOperand(1); + } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { + OtherAddOp = AddOp->getOperand(0); } + if (OtherAddOp) { + // So at this point we know we have (Y -> OtherAddOp): + // select C, (add X, Y), (sub X, Z) + Value *NegVal; // Compute -Z + if (SI.getType()->isFPOrFPVectorTy()) { + NegVal = Builder->CreateFNeg(SubOp->getOperand(1)); + if (Instruction *NegInst = dyn_cast(NegVal)) { + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + NegInst->setFastMathFlags(Flags); + } + } else { + NegVal = Builder->CreateNeg(SubOp->getOperand(1)); + } + + Value *NewTrueOp = OtherAddOp; + Value *NewFalseOp = NegVal; + if (AddOp != TI) + std::swap(NewTrueOp, NewFalseOp); + Value *NewSel = Builder->CreateSelect(CondVal, NewTrueOp, NewFalseOp, + SI.getName() + ".p"); + + if (SI.getType()->isFPOrFPVectorTy()) { + Instruction *RI = + BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); + + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + RI->setFastMathFlags(Flags); + return RI; + } else + return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); + } + } + } + // See if we can fold the select into one of our operands. if (SI.getType()->isIntOrIntVectorTy() || SI.getType()->isFPOrFPVectorTy()) { if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal))