diff --git a/bolt/src/BinaryPassManager.cpp b/bolt/src/BinaryPassManager.cpp index 9aacff6d793e..18d3281fe6a0 100644 --- a/bolt/src/BinaryPassManager.cpp +++ b/bolt/src/BinaryPassManager.cpp @@ -29,6 +29,7 @@ #include "Passes/SplitFunctions.h" #include "Passes/StokeInfo.h" #include "Passes/TailDuplication.h" +#include "Passes/ThreeWayBranch.h" #include "Passes/ValidateInternalCalls.h" #include "Passes/VeneerElimination.h" #include "llvm/Support/FormatVariadic.h" @@ -83,6 +84,11 @@ static cl::opt TailDuplicationFlag( cl::desc("duplicate unconditional branches that cross a cache line"), cl::ZeroOrMore, cl::ReallyHidden, cl::cat(BoltOptCategory)); +static cl::opt ThreeWayBranchFlag("three-way-branch", + cl::desc("reorder three way branches"), + cl::ZeroOrMore, cl::ReallyHidden, + cl::cat(BoltOptCategory)); + static cl::opt PrintJTFootprintReduction("print-after-jt-footprint-reduction", cl::desc("print function after jt-footprint-reduction pass"), @@ -446,6 +452,9 @@ void BinaryFunctionPassManager::runAllPasses(BinaryContext &BC) { Manager.registerPass(std::make_unique(PrintPLT)); + Manager.registerPass(std::make_unique(), + opts::ThreeWayBranchFlag); + Manager.registerPass(std::make_unique(PrintReordered)); Manager.registerPass( diff --git a/bolt/src/MCPlusBuilder.h b/bolt/src/MCPlusBuilder.h index 0da004000627..833999c6ed5d 100644 --- a/bolt/src/MCPlusBuilder.h +++ b/bolt/src/MCPlusBuilder.h @@ -583,6 +583,11 @@ public: return false; } + virtual bool isPacked(const MCInst &Inst) const { + llvm_unreachable("not implemented"); + return false; + } + /// If non-zero, this is used to fill the executable space with instructions /// that will trap. Defaults to 0. virtual unsigned getTrapFillValue() const { return 0; } @@ -1572,6 +1577,27 @@ public: return false; } + virtual bool replaceBranchCondition(MCInst &Inst, const MCSymbol *TBB, + MCContext *Ctx, unsigned CC) const { + llvm_unreachable("not implemented"); + return false; + } + + virtual unsigned getInvertedCondCode(unsigned CC) const { + llvm_unreachable("not implemented"); + return false; + } + + virtual unsigned getCondCodesLogicalOr(unsigned CC1, unsigned CC2) const { + llvm_unreachable("not implemented"); + return false; + } + + virtual bool isValidCondCode(unsigned CC) const { + llvm_unreachable("not implemented"); + return false; + } + /// Return the conditional code used in a conditional jump instruction. /// Returns invalid code if not conditional jump. virtual unsigned getCondCode(const MCInst &Inst) const { diff --git a/bolt/src/Passes/CMakeLists.txt b/bolt/src/Passes/CMakeLists.txt index 19203acc137b..b03f03d7dfbf 100644 --- a/bolt/src/Passes/CMakeLists.txt +++ b/bolt/src/Passes/CMakeLists.txt @@ -38,6 +38,7 @@ add_llvm_library(LLVMBOLTPasses StackReachingUses.cpp StokeInfo.cpp TailDuplication.cpp + ThreeWayBranch.cpp ValidateInternalCalls.cpp VeneerElimination.cpp RetpolineInsertion.cpp diff --git a/bolt/src/Passes/ThreeWayBranch.cpp b/bolt/src/Passes/ThreeWayBranch.cpp new file mode 100644 index 000000000000..56cf2ff87e20 --- /dev/null +++ b/bolt/src/Passes/ThreeWayBranch.cpp @@ -0,0 +1,168 @@ +//===--------- Passes/ThreeWayBranch.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "ThreeWayBranch.h" + +#include + +using namespace llvm; + +namespace llvm { +namespace bolt { + +bool ThreeWayBranch::shouldRunOnFunction(BinaryFunction &Function) { + BinaryContext &BC = Function.getBinaryContext(); + BinaryFunction::BasicBlockOrderType BlockLayout = Function.getLayout(); + for (BinaryBasicBlock *BB : BlockLayout) { + for (MCInst &Inst : *BB) { + if (BC.MIB->isPacked(Inst)) + return false; + } + } + return true; +} + +void ThreeWayBranch::runOnFunction(BinaryFunction &Function) { + BinaryContext &BC = Function.getBinaryContext(); + MCContext *Ctx = BC.Ctx.get(); + // New blocks will be added and layout will change, + // so make a copy here to iterate over the original layout + BinaryFunction::BasicBlockOrderType BlockLayout = Function.getLayout(); + for (BinaryBasicBlock *BB : BlockLayout) { + // The block must be hot + if (BB->getExecutionCount() == 0 || + BB->getExecutionCount() == BinaryBasicBlock::COUNT_NO_PROFILE) + continue; + // with two successors + if (BB->succ_size() != 2) + continue; + // no jump table + if (BB->hasJumpTable()) + continue; + + BinaryBasicBlock *FalseSucc = BB->getConditionalSuccessor(false); + BinaryBasicBlock *TrueSucc = BB->getConditionalSuccessor(true); + + // One of BB's successors must have only one instruction that is a + // conditional jump + if ((FalseSucc->succ_size() != 2 || FalseSucc->size() != 1) && + (TrueSucc->succ_size() != 2 || TrueSucc->size() != 1)) + continue; + + // SecondBranch has the second conditional jump + BinaryBasicBlock *SecondBranch = FalseSucc; + BinaryBasicBlock *FirstEndpoint = TrueSucc; + if (FalseSucc->succ_size() != 2) { + SecondBranch = TrueSucc; + FirstEndpoint = FalseSucc; + } + + BinaryBasicBlock *SecondEndpoint = + SecondBranch->getConditionalSuccessor(false); + BinaryBasicBlock *ThirdEndpoint = + SecondBranch->getConditionalSuccessor(true); + + // Make sure we can modify the jump in SecondBranch without disturbing any + // other paths + if (SecondBranch->pred_size() != 1) + continue; + + // Get Jump Instructions + MCInst *FirstJump = BB->getLastNonPseudoInstr(); + MCInst *SecondJump = SecondBranch->getLastNonPseudoInstr(); + + // Get condition codes + unsigned FirstCC = BC.MIB->getCondCode(*FirstJump); + if (SecondBranch != FalseSucc) + FirstCC = BC.MIB->getInvertedCondCode(FirstCC); + // ThirdCC = ThirdCond && !FirstCC = !(!ThirdCond || + // !(!FirstCC)) = !(!ThirdCond || FirstCC) + unsigned ThirdCC = + BC.MIB->getInvertedCondCode(BC.MIB->getCondCodesLogicalOr( + BC.MIB->getInvertedCondCode(BC.MIB->getCondCode(*SecondJump)), + FirstCC)); + // SecondCC = !ThirdCond && !FirstCC = !(!(!ThirdCond) || + // !(!FirstCC)) = !(ThirdCond || FirstCC) + unsigned SecondCC = + BC.MIB->getInvertedCondCode(BC.MIB->getCondCodesLogicalOr( + BC.MIB->getCondCode(*SecondJump), FirstCC)); + + if (!BC.MIB->isValidCondCode(FirstCC) || + !BC.MIB->isValidCondCode(ThirdCC) || !BC.MIB->isValidCondCode(SecondCC)) + continue; + + std::vector> Blocks; + Blocks.push_back(std::make_pair(FirstEndpoint, FirstCC)); + Blocks.push_back(std::make_pair(SecondEndpoint, SecondCC)); + Blocks.push_back(std::make_pair(ThirdEndpoint, ThirdCC)); + + std::sort(Blocks.begin(), Blocks.end(), + [&](const std::pair A, + const std::pair B) { + return A.first->getExecutionCount() < + B.first->getExecutionCount(); + }); + + uint64_t NewSecondBranchCount = Blocks[1].first->getExecutionCount() + + Blocks[0].first->getExecutionCount(); + bool SecondBranchBigger = + NewSecondBranchCount > Blocks[2].first->getExecutionCount(); + + BB->removeAllSuccessors(); + if (SecondBranchBigger) { + BB->addSuccessor(Blocks[2].first, Blocks[2].first->getExecutionCount()); + BB->addSuccessor(SecondBranch, NewSecondBranchCount); + } else { + BB->addSuccessor(SecondBranch, NewSecondBranchCount); + BB->addSuccessor(Blocks[2].first, Blocks[2].first->getExecutionCount()); + } + + // Remove and add so there is no duplicate successors + SecondBranch->removeAllSuccessors(); + SecondBranch->addSuccessor(Blocks[0].first, + Blocks[0].first->getExecutionCount()); + SecondBranch->addSuccessor(Blocks[1].first, + Blocks[1].first->getExecutionCount()); + + SecondBranch->setExecutionCount(NewSecondBranchCount); + + // Replace the branch condition to fallthrough for the most common block + if (SecondBranchBigger) { + BC.MIB->replaceBranchCondition(*FirstJump, Blocks[2].first->getLabel(), + Ctx, Blocks[2].second); + } else { + BC.MIB->replaceBranchCondition( + *FirstJump, SecondBranch->getLabel(), Ctx, + BC.MIB->getInvertedCondCode(Blocks[2].second)); + } + + // Replace the branch condition to fallthrough for the second most common + // block + BC.MIB->replaceBranchCondition(*SecondJump, Blocks[0].first->getLabel(), + Ctx, Blocks[0].second); + + ++BranchesAltered; + } +} + +void ThreeWayBranch::runOnFunctions(BinaryContext &BC) { + for (auto &It : BC.getBinaryFunctions()) { + BinaryFunction &Function = It.second; + if (!shouldRunOnFunction(Function)) + continue; + runOnFunction(Function); + } + + outs() << "BOLT-INFO: number of three way branches order changed: " + << BranchesAltered << "\n"; +} + +} // end namespace bolt +} // end namespace llvm diff --git a/bolt/src/Passes/ThreeWayBranch.h b/bolt/src/Passes/ThreeWayBranch.h new file mode 100644 index 000000000000..3732735f1d35 --- /dev/null +++ b/bolt/src/Passes/ThreeWayBranch.h @@ -0,0 +1,43 @@ +//===--------- Passes/ThreeWayBranch.h ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TOOLS_LLVM_BOLT_PASSES_THREEWAYBRANCH_H +#define LLVM_TOOLS_LLVM_BOLT_PASSES_THREEWAYBRANCH_H + +#include "BinaryPasses.h" + +namespace llvm { +namespace bolt { + +/// Pass for optimizing a three way branch namely a single comparison and 2 +/// conditional jumps by reordering blocks, replacing successors, and replacing +/// jump conditions and destinations +class ThreeWayBranch : public BinaryFunctionPass { + /// Record how many 3 way branches were adjusted + uint64_t BranchesAltered = 0; + + /// Returns true if this pass should run on Function + bool shouldRunOnFunction(BinaryFunction &Function); + + /// Runs pass on Function + void runOnFunction(BinaryFunction &Function); + +public: + explicit ThreeWayBranch() : BinaryFunctionPass(false) {} + + const char *getName() const override { return "three way branch"; } + + void runOnFunctions(BinaryContext &BC) override; +}; + +} // namespace bolt +} // namespace llvm + +#endif diff --git a/bolt/src/Target/X86/X86MCPlusBuilder.cpp b/bolt/src/Target/X86/X86MCPlusBuilder.cpp index df73a37a4cba..b49c562da429 100644 --- a/bolt/src/Target/X86/X86MCPlusBuilder.cpp +++ b/bolt/src/Target/X86/X86MCPlusBuilder.cpp @@ -111,28 +111,6 @@ unsigned getShortArithOpcode(unsigned Opcode) { } } -unsigned getInvertedCondCode(unsigned CC) { - switch (CC) { - default: return X86::COND_INVALID; - case X86::COND_E: return X86::COND_NE; - case X86::COND_NE: return X86::COND_E; - case X86::COND_L: return X86::COND_GE; - case X86::COND_LE: return X86::COND_G; - case X86::COND_G: return X86::COND_LE; - case X86::COND_GE: return X86::COND_L; - case X86::COND_B: return X86::COND_AE; - case X86::COND_BE: return X86::COND_A; - case X86::COND_A: return X86::COND_BE; - case X86::COND_AE: return X86::COND_B; - case X86::COND_S: return X86::COND_NS; - case X86::COND_NS: return X86::COND_S; - case X86::COND_P: return X86::COND_NP; - case X86::COND_NP: return X86::COND_P; - case X86::COND_O: return X86::COND_NO; - case X86::COND_NO: return X86::COND_O; - } -} - bool isADD(unsigned Opcode) { switch (Opcode) { default: @@ -422,6 +400,86 @@ public: } } + unsigned getInvertedCondCode(unsigned CC) const override { + switch (CC) { + default: return X86::COND_INVALID; + case X86::COND_E: return X86::COND_NE; + case X86::COND_NE: return X86::COND_E; + case X86::COND_L: return X86::COND_GE; + case X86::COND_LE: return X86::COND_G; + case X86::COND_G: return X86::COND_LE; + case X86::COND_GE: return X86::COND_L; + case X86::COND_B: return X86::COND_AE; + case X86::COND_BE: return X86::COND_A; + case X86::COND_A: return X86::COND_BE; + case X86::COND_AE: return X86::COND_B; + case X86::COND_S: return X86::COND_NS; + case X86::COND_NS: return X86::COND_S; + case X86::COND_P: return X86::COND_NP; + case X86::COND_NP: return X86::COND_P; + case X86::COND_O: return X86::COND_NO; + case X86::COND_NO: return X86::COND_O; + } + } + + unsigned getCondCodesLogicalOr(unsigned CC1, unsigned CC2) const override { + enum DecodedCondCode : uint8_t { + DCC_EQUAL = 0x1, + DCC_GREATER = 0x2, + DCC_LESSER = 0x4, + DCC_GREATER_OR_LESSER = 0x6, + DCC_UNSIGNED = 0x8, + DCC_SIGNED = 0x10, + DCC_INVALID = 0x20, + }; + + auto decodeCondCode = [&](unsigned CC) -> uint8_t { + switch (CC) { + default: return DCC_INVALID; + case X86::COND_E: return DCC_EQUAL; + case X86::COND_NE: return DCC_GREATER | DCC_LESSER; + case X86::COND_L: return DCC_LESSER | DCC_SIGNED; + case X86::COND_LE: return DCC_EQUAL | DCC_LESSER | DCC_SIGNED; + case X86::COND_G: return DCC_GREATER | DCC_SIGNED; + case X86::COND_GE: return DCC_GREATER | DCC_EQUAL | DCC_SIGNED; + case X86::COND_B: return DCC_LESSER | DCC_UNSIGNED; + case X86::COND_BE: return DCC_EQUAL | DCC_LESSER | DCC_UNSIGNED; + case X86::COND_A: return DCC_GREATER | DCC_UNSIGNED; + case X86::COND_AE: return DCC_GREATER | DCC_EQUAL | DCC_UNSIGNED; + } + }; + + uint8_t DCC = decodeCondCode(CC1) | decodeCondCode(CC2); + + if (DCC & DCC_INVALID) + return X86::COND_INVALID; + + if (DCC & DCC_SIGNED && DCC & DCC_UNSIGNED) + return X86::COND_INVALID; + + switch (DCC) { + default: return X86::COND_INVALID; + case DCC_EQUAL | DCC_LESSER | DCC_SIGNED: return X86::COND_LE; + case DCC_EQUAL | DCC_LESSER | DCC_UNSIGNED: return X86::COND_BE; + case DCC_EQUAL | DCC_GREATER | DCC_SIGNED: return X86::COND_GE; + case DCC_EQUAL | DCC_GREATER | DCC_UNSIGNED: return X86::COND_AE; + case DCC_GREATER | DCC_LESSER | DCC_SIGNED: return X86::COND_NE; + case DCC_GREATER | DCC_LESSER | DCC_UNSIGNED: return X86::COND_NE; + case DCC_GREATER | DCC_LESSER: return X86::COND_NE; + case DCC_EQUAL | DCC_SIGNED: return X86::COND_E; + case DCC_EQUAL | DCC_UNSIGNED: return X86::COND_E; + case DCC_EQUAL: return X86::COND_E; + case DCC_LESSER | DCC_SIGNED: return X86::COND_L; + case DCC_LESSER | DCC_UNSIGNED: return X86::COND_B; + case DCC_GREATER | DCC_SIGNED: return X86::COND_G; + case DCC_GREATER | DCC_UNSIGNED: return X86::COND_A; + } + } + + bool isValidCondCode(unsigned CC) const override { + return (CC != X86::COND_INVALID); + } + bool isBreakpoint(const MCInst &Inst) const override { return Inst.getOpcode() == X86::INT3; } @@ -654,6 +712,11 @@ public: Inst.getOperand(2).getReg()); } + bool isPacked(const MCInst &Inst) const override { + const MCInstrDesc &Desc = Info->get(Inst.getOpcode()); + return (Desc.TSFlags & X86II::OpPrefixMask) == X86II::PD; + } + unsigned getTrapFillValue() const override { return 0xCC; } struct IndJmpMatcherFrag1 : MCInstMatcher { @@ -3252,6 +3315,16 @@ public: return true; } + bool replaceBranchCondition(MCInst &Inst, const MCSymbol *TBB, MCContext *Ctx, + unsigned CC) const override { + if (CC == X86::COND_INVALID) + return false; + Inst.getOperand(Info->get(Inst.getOpcode()).NumOperands - 1).setImm(CC); + Inst.getOperand(0) = MCOperand::createExpr( + MCSymbolRefExpr::create(TBB, MCSymbolRefExpr::VK_None, *Ctx)); + return true; + } + unsigned getCanonicalBranchCondCode(unsigned CC) const override { switch (CC) { default: return X86::COND_INVALID; diff --git a/bolt/test/X86/three-way-branch-pass.s b/bolt/test/X86/three-way-branch-pass.s new file mode 100644 index 000000000000..2a56f83de99f --- /dev/null +++ b/bolt/test/X86/three-way-branch-pass.s @@ -0,0 +1,36 @@ +# REQUIRES: system-linux + +# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown \ +# RUN: %s -o %t.o +# RUN: link_fdata %s %t.o %t.fdata +# RUN: %host_cc %cflags %t.o -o %t.exe -Wl,-q +# RUN: llvm-bolt %t.exe -data %t.fdata -print-finalized \ +# RUN: -o %t.out -three-way-branch | FileCheck %s +# RUN: %t.exe +# RUN: %t.out + +# FDATA: 1 main 8 1 main a 0 22 +# FDATA: 1 main 8 1 main #.BB1# 0 50 +# FDATA: 1 main 12 1 main 14 0 30 +# FDATA: 1 main 12 1 main #.BB2# 0 40 +# CHECK: Successors: .Ltmp1 (mispreds: 0, count: 40), .Ltmp0 (mispreds: 0, count: 52) +# CHECK: Successors: .LFT0 (mispreds: 0, count: 22), .LFT1 (mispreds: 0, count: 30) + + .text + .globl main + .type main, %function + .size main, .Lend-main +main: + mov $0x0, %eax + cmp $0x1, %eax + jge .BB1 + mov $0xf, %eax + xor %eax, %eax + retq +.BB1: + jg .BB2 + retq +.BB2: + mov $0x7, %eax + retq +.Lend: