Add initial support to lower ISD::SELECT into branch instructions in divergent execution path.

This commit is contained in:
Aries 2022-12-22 17:17:02 +08:00
parent b9da010dd5
commit cb6f30fbd7
6 changed files with 236 additions and 10 deletions

View File

@ -4771,6 +4771,7 @@ static bool isSelectPseudo(MachineInstr &MI) {
default:
return false;
case RISCV::Select_GPR_Using_CC_GPR:
case RISCV::Select_VGPR_Using_CC_VGPR:
case RISCV::Select_FPR16_Using_CC_GPR:
case RISCV::Select_FPR32_Using_CC_GPR:
case RISCV::Select_FPR64_Using_CC_GPR:
@ -5050,6 +5051,120 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
return TailMBB;
}
static MachineBasicBlock *emitVSelectPseudo(MachineInstr &MI,
MachineBasicBlock *BB,
const RISCVSubtarget &Subtarget) {
// vALU version of emitSelectPseudo, explicit join instruction should be
// generated for each branch.
//
// We produce the following control flow:
// HeadMBB
// / \
// TrueMBB FalseMBB
// \ /
// JoinMBB
Register LHS = MI.getOperand(1).getReg();
Register RHS = MI.getOperand(2).getReg();
auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
SmallVector<MachineInstr *, 4> SelectDebugValues;
SmallSet<Register, 4> SelectDests;
SelectDests.insert(MI.getOperand(0).getReg());
MachineInstr *LastSelectPseudo = &MI;
for (auto E = BB->end(), SequenceMBBI = MachineBasicBlock::iterator(MI);
SequenceMBBI != E; ++SequenceMBBI) {
if (SequenceMBBI->isDebugInstr())
continue;
if (isSelectPseudo(*SequenceMBBI)) {
if (SequenceMBBI->getOperand(1).getReg() != LHS ||
SequenceMBBI->getOperand(2).getReg() != RHS ||
SequenceMBBI->getOperand(3).getImm() != CC ||
SelectDests.count(SequenceMBBI->getOperand(4).getReg()) ||
SelectDests.count(SequenceMBBI->getOperand(5).getReg()))
break;
LastSelectPseudo = &*SequenceMBBI;
SequenceMBBI->collectDebugValues(SelectDebugValues);
SelectDests.insert(SequenceMBBI->getOperand(0).getReg());
continue;
}
if (SequenceMBBI->hasUnmodeledSideEffects() ||
SequenceMBBI->mayLoadOrStore())
break;
if (llvm::any_of(SequenceMBBI->operands(), [&](MachineOperand &MO) {
return MO.isReg() && MO.isUse() && SelectDests.count(MO.getReg());
}))
break;
}
const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
const BasicBlock *LLVM_BB = BB->getBasicBlock();
DebugLoc DL = MI.getDebugLoc();
MachineFunction::iterator I = ++BB->getIterator();
MachineBasicBlock *HeadMBB = BB;
MachineFunction *F = BB->getParent();
MachineBasicBlock *JoinMBB = F->CreateMachineBasicBlock(LLVM_BB);
MachineBasicBlock *IfTrueMBB = F->CreateMachineBasicBlock(LLVM_BB);
MachineBasicBlock *IfFalseMBB = F->CreateMachineBasicBlock(LLVM_BB);
F->insert(I, IfTrueMBB);
F->insert(I, IfFalseMBB);
F->insert(I, JoinMBB);
// Transfer debug instructions associated with the selects to JoinMBB.
for (MachineInstr *DebugInstr : SelectDebugValues) {
JoinMBB->push_back(DebugInstr->removeFromParent());
}
// Move all instructions after the sequence to JoinMBB.
JoinMBB->splice(JoinMBB->end(), HeadMBB,
std::next(LastSelectPseudo->getIterator()), HeadMBB->end());
// Update machine-CFG edges by transferring all successors of the current
// block to the new block which will contain the Phi nodes for the selects.
JoinMBB->transferSuccessorsAndUpdatePHIs(HeadMBB);
// Set the successors for HeadMBB.
HeadMBB->addSuccessor(IfTrueMBB);
HeadMBB->addSuccessor(IfFalseMBB);
// Insert appropriate branch.
BuildMI(HeadMBB, DL, TII.getVBrCond(CC))
.addReg(LHS)
.addReg(RHS)
.addMBB(IfTrueMBB);
// Insert appropriate join
BuildMI(IfTrueMBB, DL, TII.get(RISCV::JOIN)).addMBB(JoinMBB);
BuildMI(IfFalseMBB, DL, TII.get(RISCV::JOIN)).addMBB(JoinMBB);
IfTrueMBB->addSuccessor(JoinMBB);
IfFalseMBB->addSuccessor(JoinMBB);
// Create PHIs for all of the select pseudo-instructions.
auto SelectMBBI = MI.getIterator();
auto SelectEnd = std::next(LastSelectPseudo->getIterator());
auto InsertionPoint = JoinMBB->begin();
while (SelectMBBI != SelectEnd) {
auto Next = std::next(SelectMBBI);
if (isSelectPseudo(*SelectMBBI)) {
// %Result = phi [ %TrueValue, IfTrueMBB ], [ %FalseValue, IfFalseMBB ]
BuildMI(*JoinMBB, InsertionPoint, SelectMBBI->getDebugLoc(),
TII.get(RISCV::PHI), SelectMBBI->getOperand(0).getReg())
.addReg(SelectMBBI->getOperand(4).getReg())
.addMBB(IfTrueMBB)
.addReg(SelectMBBI->getOperand(5).getReg())
.addMBB(IfFalseMBB);
SelectMBBI->eraseFromParent();
}
SelectMBBI = Next;
}
F->getProperties().reset(MachineFunctionProperties::Property::NoPHIs);
return JoinMBB;
}
static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB,
const RISCVSubtarget &Subtarget) {
unsigned CmpOpc, F2IOpc, I2FOpc, FSGNJOpc, FSGNJXOpc;
@ -5172,6 +5287,8 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
case RISCV::Select_FPR32_Using_CC_GPR:
case RISCV::Select_FPR64_Using_CC_GPR:
return emitSelectPseudo(MI, BB, Subtarget);
case RISCV::Select_VGPR_Using_CC_VGPR:
return emitVSelectPseudo(MI, BB, Subtarget);
case RISCV::BuildPairF64Pseudo:
return emitBuildPairF64Pseudo(MI, BB);
case RISCV::SplitF64Pseudo:

View File

@ -385,20 +385,41 @@ static RISCVCC::CondCode getCondFromBranchOpc(unsigned Opc) {
default:
return RISCVCC::COND_INVALID;
case RISCV::BEQ:
case RISCV::VBEQ:
return RISCVCC::COND_EQ;
case RISCV::BNE:
case RISCV::VBNE:
return RISCVCC::COND_NE;
case RISCV::BLT:
case RISCV::VBLT:
return RISCVCC::COND_LT;
case RISCV::BGE:
case RISCV::VBGE:
return RISCVCC::COND_GE;
case RISCV::BLTU:
case RISCV::VBLTU:
return RISCVCC::COND_LTU;
case RISCV::BGEU:
case RISCV::VBGEU:
return RISCVCC::COND_GEU;
}
}
static bool isDivergentBranch(MachineInstr &I) {
switch (I.getOpcode()) {
default:
return false;
case RISCV::VBEQ:
case RISCV::VBNE:
case RISCV::VBLT:
case RISCV::VBGE:
case RISCV::VBLTU:
case RISCV::VBGEU:
return true;
}
}
// The contents of values added to Cond are not examined outside of
// RISCVInstrInfo, giving us flexibility in what to push to it. For RISCV, we
// push BranchOpcode, Reg1, Reg2.
@ -433,6 +454,25 @@ const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const {
}
}
const MCInstrDesc &RISCVInstrInfo::getVBrCond(RISCVCC::CondCode CC) const {
switch (CC) {
default:
llvm_unreachable("Unknown condition code!");
case RISCVCC::COND_EQ:
return get(RISCV::VBEQ);
case RISCVCC::COND_NE:
return get(RISCV::VBNE);
case RISCVCC::COND_LT:
return get(RISCV::VBLT);
case RISCVCC::COND_GE:
return get(RISCV::VBGE);
case RISCVCC::COND_LTU:
return get(RISCV::VBLTU);
case RISCVCC::COND_GEU:
return get(RISCV::VBGEU);
}
}
RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) {
switch (CC) {
default:
@ -535,6 +575,7 @@ unsigned RISCVInstrInfo::removeBranch(MachineBasicBlock &MBB,
// Remove the branch.
if (BytesRemoved)
*BytesRemoved += getInstSizeInBytes(*I);
IsDivergentBranch = isDivergentBranch(*I);
I->eraseFromParent();
I = MBB.end();
@ -548,6 +589,7 @@ unsigned RISCVInstrInfo::removeBranch(MachineBasicBlock &MBB,
// Remove the branch.
if (BytesRemoved)
*BytesRemoved += getInstSizeInBytes(*I);
IsDivergentBranch = isDivergentBranch(*I);
I->eraseFromParent();
return 2;
}
@ -565,9 +607,11 @@ unsigned RISCVInstrInfo::insertBranch(
assert((Cond.size() == 3 || Cond.size() == 0) &&
"RISCV branch conditions have two components!");
unsigned UncondBr = IsDivergentBranch ? RISCV::JOIN : RISCV::PseudoBR;
// Unconditional branch.
if (Cond.empty()) {
MachineInstr &MI = *BuildMI(&MBB, DL, get(RISCV::PseudoBR)).addMBB(TBB);
MachineInstr &MI = *BuildMI(&MBB, DL, get(UncondBr)).addMBB(TBB);
if (BytesAdded)
*BytesAdded += getInstSizeInBytes(MI);
return 1;
@ -575,8 +619,9 @@ unsigned RISCVInstrInfo::insertBranch(
// Either a one or two-way conditional branch.
auto CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
const MCInstrDesc& CondBr = IsDivergentBranch ? getVBrCond(CC) : getBrCond(CC);
MachineInstr &CondMI =
*BuildMI(&MBB, DL, getBrCond(CC)).add(Cond[1]).add(Cond[2]).addMBB(TBB);
*BuildMI(&MBB, DL, CondBr).add(Cond[1]).add(Cond[2]).addMBB(TBB);
if (BytesAdded)
*BytesAdded += getInstSizeInBytes(CondMI);
@ -585,7 +630,7 @@ unsigned RISCVInstrInfo::insertBranch(
return 1;
// Two-way conditional branch.
MachineInstr &MI = *BuildMI(&MBB, DL, get(RISCV::PseudoBR)).addMBB(FBB);
MachineInstr &MI = *BuildMI(&MBB, DL, get(UncondBr)).addMBB(FBB);
if (BytesAdded)
*BytesAdded += getInstSizeInBytes(MI);
return 2;
@ -596,6 +641,7 @@ void RISCVInstrInfo::insertIndirectBranch(MachineBasicBlock &MBB,
MachineBasicBlock &RestoreBB,
const DebugLoc &DL, int64_t BrOffset,
RegScavenger *RS) const {
assert(0 && "Add vALU support!");
assert(RS && "RegScavenger required for long branching");
assert(MBB.empty() &&
"new block should be inserted for expanding unconditional branch");
@ -682,11 +728,18 @@ bool RISCVInstrInfo::isBranchOffsetInRange(unsigned BranchOp,
default:
llvm_unreachable("Unexpected opcode!");
case RISCV::BEQ:
case RISCV::VBEQ:
case RISCV::BNE:
case RISCV::VBNE:
case RISCV::BLT:
case RISCV::VBLT:
case RISCV::BGE:
case RISCV::VBGE:
case RISCV::BLTU:
case RISCV::VBLTU:
case RISCV::BGEU:
case RISCV::VBGEU:
case RISCV::JOIN:
return isIntN(13, BrOffset);
case RISCV::JAL:
case RISCV::PseudoBR:

View File

@ -42,12 +42,16 @@ CondCode getOppositeBranchCondition(CondCode);
} // end of namespace RISCVCC
class RISCVInstrInfo : public RISCVGenInstrInfo {
// WORKAROUND: Indicate the branch remove is in a divergent execution env.
// so that a newly inserted branch should be it as well.
mutable bool IsDivergentBranch;
public:
explicit RISCVInstrInfo(RISCVSubtarget &STI);
MCInst getNop() const override;
const MCInstrDesc &getBrCond(RISCVCC::CondCode CC) const;
const MCInstrDesc &getVBrCond(RISCVCC::CondCode CC) const;
unsigned isLoadFromStackSlot(const MachineInstr &MI,
int &FrameIndex) const override;

View File

@ -1386,11 +1386,12 @@ def IntCCtoRISCVCC : SDNodeXForm<riscv_selectcc, [{
return CurDAG->getTargetConstant(BrCC, SDLoc(N), Subtarget->getXLenVT());
}]>;
def riscv_selectcc_frag : PatFrag<(ops node:$lhs, node:$rhs, node:$cc,
def UniformSelectCCFrag : PatFrag<(ops node:$lhs, node:$rhs, node:$cc,
node:$truev, node:$falsev),
(riscv_selectcc node:$lhs, node:$rhs,
node:$cc, node:$truev,
node:$falsev), [{}],
node:$falsev),
[{ return !N->isDivergent(); }],
IntCCtoRISCVCC>;
let Predicates = [HasShortForwardBranchOpt],
@ -1401,7 +1402,7 @@ def PseudoCCMOVGPR : Pseudo<(outs GPR:$dst),
(ins GPR:$lhs, GPR:$rhs, ixlenimm:$cc,
GPR:$falsev, GPR:$truev),
[(set GPR:$dst,
(riscv_selectcc_frag:$cc GPR:$lhs, GPR:$rhs,
(UniformSelectCCFrag:$cc GPR:$lhs, GPR:$rhs,
cond, GPR:$truev,
GPR:$falsev))]>,
Sched<[WriteSFB, ReadSFB, ReadSFB, ReadSFB, ReadSFB]>;
@ -1413,11 +1414,11 @@ multiclass SelectCC_GPR_rrirr<RegisterClass valty> {
(ins GPR:$lhs, GPR:$rhs, ixlenimm:$cc,
valty:$truev, valty:$falsev),
[(set valty:$dst,
(riscv_selectcc_frag:$cc GPR:$lhs, GPR:$rhs, cond,
(UniformSelectCCFrag:$cc GPR:$lhs, GPR:$rhs, cond,
valty:$truev, valty:$falsev))]>;
// Explicitly select 0 in the condition to X0. The register coalescer doesn't
// always do it.
def : Pat<(riscv_selectcc_frag:$cc GPR:$lhs, 0, cond, valty:$truev,
def : Pat<(UniformSelectCCFrag:$cc GPR:$lhs, 0, cond, valty:$truev,
valty:$falsev),
(!cast<Instruction>(NAME#"_Using_CC_GPR") GPR:$lhs, X0,
(IntCCtoRISCVCC $cc), valty:$truev, valty:$falsev)>;

View File

@ -753,7 +753,7 @@ def : InstAlias<"c.fsdsp $rs2, (${rs1})", (C_FSDSP FPR64C:$rs2, SPMem:$rs1, 0)>;
//===----------------------------------------------------------------------===//
class SelectCompressOpt<CondCode Cond>: Pat<
(riscv_selectcc_frag:$select GPR:$lhs, simm12_no6:$Constant, Cond,
(UniformSelectCCFrag:$select GPR:$lhs, simm12_no6:$Constant, Cond,
GPR:$truev, GPR:$falsev),
(Select_GPR_Using_CC_GPR (ADDI GPR:$lhs, (NegImm simm12:$Constant)), X0,
(IntCCtoRISCVCC $select), GPR:$truev, GPR:$falsev)>;

View File

@ -628,6 +628,15 @@ class BranchCC_vvi<bits<3> funct3, string opcodestr>
: RVInstVB<funct3, (outs), (ins VGPR:$vs1, VGPR:$vs2, simm13_lsb0:$imm12),
opcodestr, "$vs1, $vs2, $imm12">, Sched<[]>;
let hasSideEffects = 0, mayLoad = 0, mayStore = 0,
isBarrier = 1, isBranch = 1, isTerminator = 1 in
class Branch_i<bits<3> funct3, string opcodestr>
: RVInstVB<funct3, (outs), (ins simm13_lsb0:$imm12),
opcodestr, "$imm12">, Sched<[]> {
let vs1 = 0;
let vs2 = 0;
}
//===----------------------------------------------------------------------===//
// Instructions
//===----------------------------------------------------------------------===//
@ -638,7 +647,7 @@ def VBLT : BranchCC_vvi<0b100, "vblt">;
def VBGE : BranchCC_vvi<0b101, "vbge">;
def VBLTU : BranchCC_vvi<0b110, "vbltu">;
def VBGEU : BranchCC_vvi<0b111, "vbgeu">;
def JOIN : BranchCC_vvi<0b011, "join">;
def JOIN : Branch_i<0b011, "join">;
def VLUXEI8 : VectorLoad<LSWidth8, "vluxei8.v">;
def VLUXEI16 : VectorLoad<LSWidth16, "vluxei16.v">;
@ -1068,6 +1077,48 @@ defm : DivergentStPat<truncstorei8, VSUXEI8>;
defm : DivergentStPat<truncstorei16, VSUXEI16>;
defm : DivergentStPat<store, VSUXEI32>;
def DivergentSelectCCFrag : PatFrag<(ops node:$lhs, node:$rhs, node:$cc,
node:$truev, node:$falsev),
(riscv_selectcc node:$lhs, node:$rhs,
node:$cc, node:$truev,
node:$falsev),
[{ return N->isDivergent(); }],
IntCCtoRISCVCC>;
let Predicates = [HasShortForwardBranchOpt],
Constraints = "$dst = $falsev", isCommutable = 1, Size = 8 in {
// This instruction moves $truev to $dst when the condition is true. It will
// be expanded to control flow in RISCVExpandPseudoInsts.
def PseudoCCMOVVGPR : Pseudo<(outs VGPR:$dst),
(ins VGPR:$lhs, VGPR:$rhs, ixlenimm:$cc,
VGPR:$falsev, VGPR:$truev),
[(set VGPR:$dst,
(DivergentSelectCCFrag:$cc VGPR:$lhs, VGPR:$rhs,
cond, VGPR:$truev,
VGPR:$falsev))]>,
Sched<[WriteSFB, ReadSFB, ReadSFB, ReadSFB, ReadSFB]>;
}
multiclass SelectCC_VGPR_rrirr<RegisterClass valty> {
let usesCustomInserter = 1 in
def _Using_CC_VGPR : Pseudo<(outs valty:$dst),
(ins VGPR:$lhs, VGPR:$rhs, ixlenimm:$cc,
valty:$truev, valty:$falsev),
[(set valty:$dst,
(DivergentSelectCCFrag:$cc VGPR:$lhs, VGPR:$rhs, cond,
valty:$truev, valty:$falsev))]>;
// Explicitly select 0 in the condition to X0. The register coalescer doesn't
// always do it.
def : Pat<(DivergentSelectCCFrag:$cc VGPR:$lhs, 0, cond, valty:$truev,
valty:$falsev),
(!cast<Instruction>(NAME#"_Using_CC_VGPR") VGPR:$lhs, X0,
(IntCCtoRISCVCC $cc), valty:$truev, valty:$falsev)>;
}
let Predicates = [NoShortForwardBranchOpt] in
defm Select_VGPR : SelectCC_VGPR_rrirr<VGPR>;
// Match `riscv_brcc` and lower to the appropriate RISC-V branch instruction.
class DivergentBccPat<CondCode Cond, RVInstVB Inst>
: Pat<(DivergentTetradFrag<riscv_brcc> VGPR:$vs1, VGPR:$vs2, Cond, bb:$imm12),