From cb6f30fbd7fc9f1c59b3621a88c041a79a08ce5a Mon Sep 17 00:00:00 2001 From: Aries Date: Thu, 22 Dec 2022 17:17:02 +0800 Subject: [PATCH] Add initial support to lower ISD::SELECT into branch instructions in divergent execution path. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 117 ++++++++++++++++++++ llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 59 +++++++++- llvm/lib/Target/RISCV/RISCVInstrInfo.h | 4 + llvm/lib/Target/RISCV/VentusInstrInfo.td | 11 +- llvm/lib/Target/RISCV/VentusInstrInfoC.td | 2 +- llvm/lib/Target/RISCV/VentusInstrInfoV.td | 53 ++++++++- 6 files changed, 236 insertions(+), 10 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 819168a21397..b5d4a2efd170 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -4771,6 +4771,7 @@ static bool isSelectPseudo(MachineInstr &MI) { default: return false; case RISCV::Select_GPR_Using_CC_GPR: + case RISCV::Select_VGPR_Using_CC_VGPR: case RISCV::Select_FPR16_Using_CC_GPR: case RISCV::Select_FPR32_Using_CC_GPR: case RISCV::Select_FPR64_Using_CC_GPR: @@ -5050,6 +5051,120 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI, return TailMBB; } + +static MachineBasicBlock *emitVSelectPseudo(MachineInstr &MI, + MachineBasicBlock *BB, + const RISCVSubtarget &Subtarget) { + // vALU version of emitSelectPseudo, explicit join instruction should be + // generated for each branch. + // + // We produce the following control flow: + // HeadMBB + // / \ + // TrueMBB FalseMBB + // \ / + // JoinMBB + Register LHS = MI.getOperand(1).getReg(); + Register RHS = MI.getOperand(2).getReg(); + auto CC = static_cast(MI.getOperand(3).getImm()); + + SmallVector SelectDebugValues; + SmallSet SelectDests; + SelectDests.insert(MI.getOperand(0).getReg()); + + MachineInstr *LastSelectPseudo = &MI; + + for (auto E = BB->end(), SequenceMBBI = MachineBasicBlock::iterator(MI); + SequenceMBBI != E; ++SequenceMBBI) { + if (SequenceMBBI->isDebugInstr()) + continue; + if (isSelectPseudo(*SequenceMBBI)) { + if (SequenceMBBI->getOperand(1).getReg() != LHS || + SequenceMBBI->getOperand(2).getReg() != RHS || + SequenceMBBI->getOperand(3).getImm() != CC || + SelectDests.count(SequenceMBBI->getOperand(4).getReg()) || + SelectDests.count(SequenceMBBI->getOperand(5).getReg())) + break; + LastSelectPseudo = &*SequenceMBBI; + SequenceMBBI->collectDebugValues(SelectDebugValues); + SelectDests.insert(SequenceMBBI->getOperand(0).getReg()); + continue; + } + if (SequenceMBBI->hasUnmodeledSideEffects() || + SequenceMBBI->mayLoadOrStore()) + break; + if (llvm::any_of(SequenceMBBI->operands(), [&](MachineOperand &MO) { + return MO.isReg() && MO.isUse() && SelectDests.count(MO.getReg()); + })) + break; + } + + const RISCVInstrInfo &TII = *Subtarget.getInstrInfo(); + const BasicBlock *LLVM_BB = BB->getBasicBlock(); + DebugLoc DL = MI.getDebugLoc(); + MachineFunction::iterator I = ++BB->getIterator(); + + MachineBasicBlock *HeadMBB = BB; + MachineFunction *F = BB->getParent(); + MachineBasicBlock *JoinMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineBasicBlock *IfTrueMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineBasicBlock *IfFalseMBB = F->CreateMachineBasicBlock(LLVM_BB); + + F->insert(I, IfTrueMBB); + F->insert(I, IfFalseMBB); + F->insert(I, JoinMBB); + + // Transfer debug instructions associated with the selects to JoinMBB. + for (MachineInstr *DebugInstr : SelectDebugValues) { + JoinMBB->push_back(DebugInstr->removeFromParent()); + } + + // Move all instructions after the sequence to JoinMBB. + JoinMBB->splice(JoinMBB->end(), HeadMBB, + std::next(LastSelectPseudo->getIterator()), HeadMBB->end()); + // Update machine-CFG edges by transferring all successors of the current + // block to the new block which will contain the Phi nodes for the selects. + JoinMBB->transferSuccessorsAndUpdatePHIs(HeadMBB); + // Set the successors for HeadMBB. + HeadMBB->addSuccessor(IfTrueMBB); + HeadMBB->addSuccessor(IfFalseMBB); + + // Insert appropriate branch. + BuildMI(HeadMBB, DL, TII.getVBrCond(CC)) + .addReg(LHS) + .addReg(RHS) + .addMBB(IfTrueMBB); + + // Insert appropriate join + BuildMI(IfTrueMBB, DL, TII.get(RISCV::JOIN)).addMBB(JoinMBB); + BuildMI(IfFalseMBB, DL, TII.get(RISCV::JOIN)).addMBB(JoinMBB); + + IfTrueMBB->addSuccessor(JoinMBB); + IfFalseMBB->addSuccessor(JoinMBB); + + // Create PHIs for all of the select pseudo-instructions. + auto SelectMBBI = MI.getIterator(); + auto SelectEnd = std::next(LastSelectPseudo->getIterator()); + auto InsertionPoint = JoinMBB->begin(); + while (SelectMBBI != SelectEnd) { + auto Next = std::next(SelectMBBI); + if (isSelectPseudo(*SelectMBBI)) { + // %Result = phi [ %TrueValue, IfTrueMBB ], [ %FalseValue, IfFalseMBB ] + BuildMI(*JoinMBB, InsertionPoint, SelectMBBI->getDebugLoc(), + TII.get(RISCV::PHI), SelectMBBI->getOperand(0).getReg()) + .addReg(SelectMBBI->getOperand(4).getReg()) + .addMBB(IfTrueMBB) + .addReg(SelectMBBI->getOperand(5).getReg()) + .addMBB(IfFalseMBB); + SelectMBBI->eraseFromParent(); + } + SelectMBBI = Next; + } + + F->getProperties().reset(MachineFunctionProperties::Property::NoPHIs); + return JoinMBB; +} + static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB, const RISCVSubtarget &Subtarget) { unsigned CmpOpc, F2IOpc, I2FOpc, FSGNJOpc, FSGNJXOpc; @@ -5172,6 +5287,8 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case RISCV::Select_FPR32_Using_CC_GPR: case RISCV::Select_FPR64_Using_CC_GPR: return emitSelectPseudo(MI, BB, Subtarget); + case RISCV::Select_VGPR_Using_CC_VGPR: + return emitVSelectPseudo(MI, BB, Subtarget); case RISCV::BuildPairF64Pseudo: return emitBuildPairF64Pseudo(MI, BB); case RISCV::SplitF64Pseudo: diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 72c78e573a23..bb10f83ccd66 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -385,20 +385,41 @@ static RISCVCC::CondCode getCondFromBranchOpc(unsigned Opc) { default: return RISCVCC::COND_INVALID; case RISCV::BEQ: + case RISCV::VBEQ: return RISCVCC::COND_EQ; case RISCV::BNE: + case RISCV::VBNE: return RISCVCC::COND_NE; case RISCV::BLT: + case RISCV::VBLT: return RISCVCC::COND_LT; case RISCV::BGE: + case RISCV::VBGE: return RISCVCC::COND_GE; case RISCV::BLTU: + case RISCV::VBLTU: return RISCVCC::COND_LTU; case RISCV::BGEU: + case RISCV::VBGEU: return RISCVCC::COND_GEU; } } +static bool isDivergentBranch(MachineInstr &I) { + switch (I.getOpcode()) { + default: + return false; + case RISCV::VBEQ: + case RISCV::VBNE: + case RISCV::VBLT: + case RISCV::VBGE: + case RISCV::VBLTU: + case RISCV::VBGEU: + return true; + } +} + + // The contents of values added to Cond are not examined outside of // RISCVInstrInfo, giving us flexibility in what to push to it. For RISCV, we // push BranchOpcode, Reg1, Reg2. @@ -433,6 +454,25 @@ const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const { } } +const MCInstrDesc &RISCVInstrInfo::getVBrCond(RISCVCC::CondCode CC) const { + switch (CC) { + default: + llvm_unreachable("Unknown condition code!"); + case RISCVCC::COND_EQ: + return get(RISCV::VBEQ); + case RISCVCC::COND_NE: + return get(RISCV::VBNE); + case RISCVCC::COND_LT: + return get(RISCV::VBLT); + case RISCVCC::COND_GE: + return get(RISCV::VBGE); + case RISCVCC::COND_LTU: + return get(RISCV::VBLTU); + case RISCVCC::COND_GEU: + return get(RISCV::VBGEU); + } +} + RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) { switch (CC) { default: @@ -535,6 +575,7 @@ unsigned RISCVInstrInfo::removeBranch(MachineBasicBlock &MBB, // Remove the branch. if (BytesRemoved) *BytesRemoved += getInstSizeInBytes(*I); + IsDivergentBranch = isDivergentBranch(*I); I->eraseFromParent(); I = MBB.end(); @@ -548,6 +589,7 @@ unsigned RISCVInstrInfo::removeBranch(MachineBasicBlock &MBB, // Remove the branch. if (BytesRemoved) *BytesRemoved += getInstSizeInBytes(*I); + IsDivergentBranch = isDivergentBranch(*I); I->eraseFromParent(); return 2; } @@ -565,9 +607,11 @@ unsigned RISCVInstrInfo::insertBranch( assert((Cond.size() == 3 || Cond.size() == 0) && "RISCV branch conditions have two components!"); + unsigned UncondBr = IsDivergentBranch ? RISCV::JOIN : RISCV::PseudoBR; + // Unconditional branch. if (Cond.empty()) { - MachineInstr &MI = *BuildMI(&MBB, DL, get(RISCV::PseudoBR)).addMBB(TBB); + MachineInstr &MI = *BuildMI(&MBB, DL, get(UncondBr)).addMBB(TBB); if (BytesAdded) *BytesAdded += getInstSizeInBytes(MI); return 1; @@ -575,8 +619,9 @@ unsigned RISCVInstrInfo::insertBranch( // Either a one or two-way conditional branch. auto CC = static_cast(Cond[0].getImm()); + const MCInstrDesc& CondBr = IsDivergentBranch ? getVBrCond(CC) : getBrCond(CC); MachineInstr &CondMI = - *BuildMI(&MBB, DL, getBrCond(CC)).add(Cond[1]).add(Cond[2]).addMBB(TBB); + *BuildMI(&MBB, DL, CondBr).add(Cond[1]).add(Cond[2]).addMBB(TBB); if (BytesAdded) *BytesAdded += getInstSizeInBytes(CondMI); @@ -585,7 +630,7 @@ unsigned RISCVInstrInfo::insertBranch( return 1; // Two-way conditional branch. - MachineInstr &MI = *BuildMI(&MBB, DL, get(RISCV::PseudoBR)).addMBB(FBB); + MachineInstr &MI = *BuildMI(&MBB, DL, get(UncondBr)).addMBB(FBB); if (BytesAdded) *BytesAdded += getInstSizeInBytes(MI); return 2; @@ -596,6 +641,7 @@ void RISCVInstrInfo::insertIndirectBranch(MachineBasicBlock &MBB, MachineBasicBlock &RestoreBB, const DebugLoc &DL, int64_t BrOffset, RegScavenger *RS) const { + assert(0 && "Add vALU support!"); assert(RS && "RegScavenger required for long branching"); assert(MBB.empty() && "new block should be inserted for expanding unconditional branch"); @@ -682,11 +728,18 @@ bool RISCVInstrInfo::isBranchOffsetInRange(unsigned BranchOp, default: llvm_unreachable("Unexpected opcode!"); case RISCV::BEQ: + case RISCV::VBEQ: case RISCV::BNE: + case RISCV::VBNE: case RISCV::BLT: + case RISCV::VBLT: case RISCV::BGE: + case RISCV::VBGE: case RISCV::BLTU: + case RISCV::VBLTU: case RISCV::BGEU: + case RISCV::VBGEU: + case RISCV::JOIN: return isIntN(13, BrOffset); case RISCV::JAL: case RISCV::PseudoBR: diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index a0c537c51089..f4d29b62bef6 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -42,12 +42,16 @@ CondCode getOppositeBranchCondition(CondCode); } // end of namespace RISCVCC class RISCVInstrInfo : public RISCVGenInstrInfo { + // WORKAROUND: Indicate the branch remove is in a divergent execution env. + // so that a newly inserted branch should be it as well. + mutable bool IsDivergentBranch; public: explicit RISCVInstrInfo(RISCVSubtarget &STI); MCInst getNop() const override; const MCInstrDesc &getBrCond(RISCVCC::CondCode CC) const; + const MCInstrDesc &getVBrCond(RISCVCC::CondCode CC) const; unsigned isLoadFromStackSlot(const MachineInstr &MI, int &FrameIndex) const override; diff --git a/llvm/lib/Target/RISCV/VentusInstrInfo.td b/llvm/lib/Target/RISCV/VentusInstrInfo.td index 658d854e6a16..754f626f654d 100644 --- a/llvm/lib/Target/RISCV/VentusInstrInfo.td +++ b/llvm/lib/Target/RISCV/VentusInstrInfo.td @@ -1386,11 +1386,12 @@ def IntCCtoRISCVCC : SDNodeXFormgetTargetConstant(BrCC, SDLoc(N), Subtarget->getXLenVT()); }]>; -def riscv_selectcc_frag : PatFrag<(ops node:$lhs, node:$rhs, node:$cc, +def UniformSelectCCFrag : PatFrag<(ops node:$lhs, node:$rhs, node:$cc, node:$truev, node:$falsev), (riscv_selectcc node:$lhs, node:$rhs, node:$cc, node:$truev, - node:$falsev), [{}], + node:$falsev), + [{ return !N->isDivergent(); }], IntCCtoRISCVCC>; let Predicates = [HasShortForwardBranchOpt], @@ -1401,7 +1402,7 @@ def PseudoCCMOVGPR : Pseudo<(outs GPR:$dst), (ins GPR:$lhs, GPR:$rhs, ixlenimm:$cc, GPR:$falsev, GPR:$truev), [(set GPR:$dst, - (riscv_selectcc_frag:$cc GPR:$lhs, GPR:$rhs, + (UniformSelectCCFrag:$cc GPR:$lhs, GPR:$rhs, cond, GPR:$truev, GPR:$falsev))]>, Sched<[WriteSFB, ReadSFB, ReadSFB, ReadSFB, ReadSFB]>; @@ -1413,11 +1414,11 @@ multiclass SelectCC_GPR_rrirr { (ins GPR:$lhs, GPR:$rhs, ixlenimm:$cc, valty:$truev, valty:$falsev), [(set valty:$dst, - (riscv_selectcc_frag:$cc GPR:$lhs, GPR:$rhs, cond, + (UniformSelectCCFrag:$cc GPR:$lhs, GPR:$rhs, cond, valty:$truev, valty:$falsev))]>; // Explicitly select 0 in the condition to X0. The register coalescer doesn't // always do it. - def : Pat<(riscv_selectcc_frag:$cc GPR:$lhs, 0, cond, valty:$truev, + def : Pat<(UniformSelectCCFrag:$cc GPR:$lhs, 0, cond, valty:$truev, valty:$falsev), (!cast(NAME#"_Using_CC_GPR") GPR:$lhs, X0, (IntCCtoRISCVCC $cc), valty:$truev, valty:$falsev)>; diff --git a/llvm/lib/Target/RISCV/VentusInstrInfoC.td b/llvm/lib/Target/RISCV/VentusInstrInfoC.td index 9f59e8c13dd1..e81d4937088b 100644 --- a/llvm/lib/Target/RISCV/VentusInstrInfoC.td +++ b/llvm/lib/Target/RISCV/VentusInstrInfoC.td @@ -753,7 +753,7 @@ def : InstAlias<"c.fsdsp $rs2, (${rs1})", (C_FSDSP FPR64C:$rs2, SPMem:$rs1, 0)>; //===----------------------------------------------------------------------===// class SelectCompressOpt: Pat< - (riscv_selectcc_frag:$select GPR:$lhs, simm12_no6:$Constant, Cond, + (UniformSelectCCFrag:$select GPR:$lhs, simm12_no6:$Constant, Cond, GPR:$truev, GPR:$falsev), (Select_GPR_Using_CC_GPR (ADDI GPR:$lhs, (NegImm simm12:$Constant)), X0, (IntCCtoRISCVCC $select), GPR:$truev, GPR:$falsev)>; diff --git a/llvm/lib/Target/RISCV/VentusInstrInfoV.td b/llvm/lib/Target/RISCV/VentusInstrInfoV.td index 9a119a6e6f30..7a7a3532b6c8 100644 --- a/llvm/lib/Target/RISCV/VentusInstrInfoV.td +++ b/llvm/lib/Target/RISCV/VentusInstrInfoV.td @@ -628,6 +628,15 @@ class BranchCC_vvi funct3, string opcodestr> : RVInstVB, Sched<[]>; +let hasSideEffects = 0, mayLoad = 0, mayStore = 0, + isBarrier = 1, isBranch = 1, isTerminator = 1 in +class Branch_i funct3, string opcodestr> + : RVInstVB, Sched<[]> { + let vs1 = 0; + let vs2 = 0; +} + //===----------------------------------------------------------------------===// // Instructions //===----------------------------------------------------------------------===// @@ -638,7 +647,7 @@ def VBLT : BranchCC_vvi<0b100, "vblt">; def VBGE : BranchCC_vvi<0b101, "vbge">; def VBLTU : BranchCC_vvi<0b110, "vbltu">; def VBGEU : BranchCC_vvi<0b111, "vbgeu">; -def JOIN : BranchCC_vvi<0b011, "join">; +def JOIN : Branch_i<0b011, "join">; def VLUXEI8 : VectorLoad; def VLUXEI16 : VectorLoad; @@ -1068,6 +1077,48 @@ defm : DivergentStPat; defm : DivergentStPat; defm : DivergentStPat; + +def DivergentSelectCCFrag : PatFrag<(ops node:$lhs, node:$rhs, node:$cc, + node:$truev, node:$falsev), + (riscv_selectcc node:$lhs, node:$rhs, + node:$cc, node:$truev, + node:$falsev), + [{ return N->isDivergent(); }], + IntCCtoRISCVCC>; + +let Predicates = [HasShortForwardBranchOpt], + Constraints = "$dst = $falsev", isCommutable = 1, Size = 8 in { +// This instruction moves $truev to $dst when the condition is true. It will +// be expanded to control flow in RISCVExpandPseudoInsts. +def PseudoCCMOVVGPR : Pseudo<(outs VGPR:$dst), + (ins VGPR:$lhs, VGPR:$rhs, ixlenimm:$cc, + VGPR:$falsev, VGPR:$truev), + [(set VGPR:$dst, + (DivergentSelectCCFrag:$cc VGPR:$lhs, VGPR:$rhs, + cond, VGPR:$truev, + VGPR:$falsev))]>, + Sched<[WriteSFB, ReadSFB, ReadSFB, ReadSFB, ReadSFB]>; +} + +multiclass SelectCC_VGPR_rrirr { + let usesCustomInserter = 1 in + def _Using_CC_VGPR : Pseudo<(outs valty:$dst), + (ins VGPR:$lhs, VGPR:$rhs, ixlenimm:$cc, + valty:$truev, valty:$falsev), + [(set valty:$dst, + (DivergentSelectCCFrag:$cc VGPR:$lhs, VGPR:$rhs, cond, + valty:$truev, valty:$falsev))]>; + // Explicitly select 0 in the condition to X0. The register coalescer doesn't + // always do it. + def : Pat<(DivergentSelectCCFrag:$cc VGPR:$lhs, 0, cond, valty:$truev, + valty:$falsev), + (!cast(NAME#"_Using_CC_VGPR") VGPR:$lhs, X0, + (IntCCtoRISCVCC $cc), valty:$truev, valty:$falsev)>; +} + +let Predicates = [NoShortForwardBranchOpt] in +defm Select_VGPR : SelectCC_VGPR_rrirr; + // Match `riscv_brcc` and lower to the appropriate RISC-V branch instruction. class DivergentBccPat : Pat<(DivergentTetradFrag VGPR:$vs1, VGPR:$vs2, Cond, bb:$imm12),