477 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			477 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
| //===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
 | ||
| //
 | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | ||
| // See https://llvm.org/LICENSE.txt for license information.
 | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | ||
| //
 | ||
| //===----------------------------------------------------------------------===//
 | ||
| //
 | ||
| // This file implements an algorithm that returns for a divergent branch
 | ||
| // the set of basic blocks whose phi nodes become divergent due to divergent
 | ||
| // control. These are the blocks that are reachable by two disjoint paths from
 | ||
| // the branch or loop exits that have a reaching path that is disjoint from a
 | ||
| // path to the loop latch.
 | ||
| //
 | ||
| // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
 | ||
| // control-induced divergence in phi nodes.
 | ||
| //
 | ||
| //
 | ||
| // -- Reference --
 | ||
| // The algorithm is presented in Section 5 of 
 | ||
| //
 | ||
| //   An abstract interpretation for SPMD divergence
 | ||
| //       on reducible control flow graphs.
 | ||
| //   Julian Rosemann, Simon Moll and Sebastian Hack
 | ||
| //   POPL '21
 | ||
| //
 | ||
| //
 | ||
| // -- Sync dependence --
 | ||
| // Sync dependence characterizes the control flow aspect of the
 | ||
| // propagation of branch divergence. For example,
 | ||
| //
 | ||
| //   %cond = icmp slt i32 %tid, 10
 | ||
| //   br i1 %cond, label %then, label %else
 | ||
| // then:
 | ||
| //   br label %merge
 | ||
| // else:
 | ||
| //   br label %merge
 | ||
| // merge:
 | ||
| //   %a = phi i32 [ 0, %then ], [ 1, %else ]
 | ||
| //
 | ||
| // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
 | ||
| // because %tid is not on its use-def chains, %a is sync dependent on %tid
 | ||
| // because the branch "br i1 %cond" depends on %tid and affects which value %a
 | ||
| // is assigned to.
 | ||
| //
 | ||
| //
 | ||
| // -- Reduction to SSA construction --
 | ||
| // There are two disjoint paths from A to X, if a certain variant of SSA
 | ||
| // construction places a phi node in X under the following set-up scheme.
 | ||
| //
 | ||
| // This variant of SSA construction ignores incoming undef values.
 | ||
| // That is paths from the entry without a definition do not result in
 | ||
| // phi nodes.
 | ||
| //
 | ||
| //       entry
 | ||
| //     /      \
 | ||
| //    A        \
 | ||
| //  /   \       Y
 | ||
| // B     C     /
 | ||
| //  \   /  \  /
 | ||
| //    D     E
 | ||
| //     \   /
 | ||
| //       F
 | ||
| //
 | ||
| // Assume that A contains a divergent branch. We are interested
 | ||
| // in the set of all blocks where each block is reachable from A
 | ||
| // via two disjoint paths. This would be the set {D, F} in this
 | ||
| // case.
 | ||
| // To generally reduce this query to SSA construction we introduce
 | ||
| // a virtual variable x and assign to x different values in each
 | ||
| // successor block of A.
 | ||
| //
 | ||
| //           entry
 | ||
| //         /      \
 | ||
| //        A        \
 | ||
| //      /   \       Y
 | ||
| // x = 0   x = 1   /
 | ||
| //      \  /   \  /
 | ||
| //        D     E
 | ||
| //         \   /
 | ||
| //           F
 | ||
| //
 | ||
| // Our flavor of SSA construction for x will construct the following
 | ||
| //
 | ||
| //            entry
 | ||
| //          /      \
 | ||
| //         A        \
 | ||
| //       /   \       Y
 | ||
| // x0 = 0   x1 = 1  /
 | ||
| //       \   /   \ /
 | ||
| //     x2 = phi   E
 | ||
| //         \     /
 | ||
| //         x3 = phi
 | ||
| //
 | ||
| // The blocks D and F contain phi nodes and are thus each reachable
 | ||
| // by two disjoins paths from A.
 | ||
| //
 | ||
| // -- Remarks --
 | ||
| // * In case of loop exits we need to check the disjoint path criterion for loops.
 | ||
| //   To this end, we check whether the definition of x differs between the
 | ||
| //   loop exit and the loop header (_after_ SSA construction).
 | ||
| //
 | ||
| // -- Known Limitations & Future Work --
 | ||
| // * The algorithm requires reducible loops because the implementation
 | ||
| //   implicitly performs a single iteration of the underlying data flow analysis.
 | ||
| //   This was done for pragmatism, simplicity and speed.
 | ||
| //
 | ||
| //   Relevant related work for extending the algorithm to irreducible control:
 | ||
| //     A simple algorithm for global data flow analysis problems.
 | ||
| //     Matthew S. Hecht and Jeffrey D. Ullman.
 | ||
| //     SIAM Journal on Computing, 4(4):519–532, December 1975.
 | ||
| //
 | ||
| // * Another reason for requiring reducible loops is that points of
 | ||
| //   synchronization in irreducible loops aren't 'obvious' - there is no unique
 | ||
| //   header where threads 'should' synchronize when entering or coming back
 | ||
| //   around from the latch.
 | ||
| //
 | ||
| //===----------------------------------------------------------------------===//
 | ||
| 
 | ||
| #include "llvm/Analysis/SyncDependenceAnalysis.h"
 | ||
| #include "llvm/ADT/SmallPtrSet.h"
 | ||
| #include "llvm/Analysis/LoopInfo.h"
 | ||
| #include "llvm/IR/BasicBlock.h"
 | ||
| #include "llvm/IR/CFG.h"
 | ||
| #include "llvm/IR/Dominators.h"
 | ||
| #include "llvm/IR/Function.h"
 | ||
| 
 | ||
| #include <functional>
 | ||
| 
 | ||
| #define DEBUG_TYPE "sync-dependence"
 | ||
| 
 | ||
| // The SDA algorithm operates on a modified CFG - we modify the edges leaving
 | ||
| // loop headers as follows:
 | ||
| //
 | ||
| // * We remove all edges leaving all loop headers.
 | ||
| // * We add additional edges from the loop headers to their exit blocks.
 | ||
| //
 | ||
| // The modification is virtual, that is whenever we visit a loop header we
 | ||
| // pretend it had different successors.
 | ||
| namespace {
 | ||
| using namespace llvm;
 | ||
| 
 | ||
| // Custom Post-Order Traveral
 | ||
| //
 | ||
| // We cannot use the vanilla (R)PO computation of LLVM because:
 | ||
| // * We (virtually) modify the CFG.
 | ||
| // * We want a loop-compact block enumeration, that is the numbers assigned to
 | ||
| //   blocks of a loop form an interval
 | ||
| //   
 | ||
| using POCB = std::function<void(const BasicBlock &)>;
 | ||
| using VisitedSet = std::set<const BasicBlock *>;
 | ||
| using BlockStack = std::vector<const BasicBlock *>;
 | ||
| 
 | ||
| // forward
 | ||
| static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
 | ||
|                           VisitedSet &Finalized);
 | ||
| 
 | ||
| // for a nested region (top-level loop or nested loop)
 | ||
| static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
 | ||
|                            POCB CallBack, VisitedSet &Finalized) {
 | ||
|   const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
 | ||
|   while (!Stack.empty()) {
 | ||
|     const auto *NextBB = Stack.back();
 | ||
| 
 | ||
|     auto *NestedLoop = LI.getLoopFor(NextBB);
 | ||
|     bool IsNestedLoop = NestedLoop != Loop;
 | ||
| 
 | ||
|     // Treat the loop as a node
 | ||
|     if (IsNestedLoop) {
 | ||
|       SmallVector<BasicBlock *, 3> NestedExits;
 | ||
|       NestedLoop->getUniqueExitBlocks(NestedExits);
 | ||
|       bool PushedNodes = false;
 | ||
|       for (const auto *NestedExitBB : NestedExits) {
 | ||
|         if (NestedExitBB == LoopHeader)
 | ||
|           continue;
 | ||
|         if (Loop && !Loop->contains(NestedExitBB))
 | ||
|           continue;
 | ||
|         if (Finalized.count(NestedExitBB))
 | ||
|           continue;
 | ||
|         PushedNodes = true;
 | ||
|         Stack.push_back(NestedExitBB);
 | ||
|       }
 | ||
|       if (!PushedNodes) {
 | ||
|         // All loop exits finalized -> finish this node
 | ||
|         Stack.pop_back();
 | ||
|         computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
 | ||
|       }
 | ||
|       continue;
 | ||
|     }
 | ||
| 
 | ||
|     // DAG-style
 | ||
|     bool PushedNodes = false;
 | ||
|     for (const auto *SuccBB : successors(NextBB)) {
 | ||
|       if (SuccBB == LoopHeader)
 | ||
|         continue;
 | ||
|       if (Loop && !Loop->contains(SuccBB))
 | ||
|         continue;
 | ||
|       if (Finalized.count(SuccBB))
 | ||
|         continue;
 | ||
|       PushedNodes = true;
 | ||
|       Stack.push_back(SuccBB);
 | ||
|     }
 | ||
|     if (!PushedNodes) {
 | ||
|       // Never push nodes twice
 | ||
|       Stack.pop_back();
 | ||
|       if (!Finalized.insert(NextBB).second)
 | ||
|         continue;
 | ||
|       CallBack(*NextBB);
 | ||
|     }
 | ||
|   }
 | ||
| }
 | ||
| 
 | ||
| static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
 | ||
|   VisitedSet Finalized;
 | ||
|   BlockStack Stack;
 | ||
|   Stack.reserve(24); // FIXME made-up number
 | ||
|   Stack.push_back(&F.getEntryBlock());
 | ||
|   computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
 | ||
| }
 | ||
| 
 | ||
| static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
 | ||
|                           VisitedSet &Finalized) {
 | ||
|   /// Call CallBack on all loop blocks.
 | ||
|   std::vector<const BasicBlock *> Stack;
 | ||
|   const auto *LoopHeader = Loop.getHeader();
 | ||
| 
 | ||
|   // Visit the header last
 | ||
|   Finalized.insert(LoopHeader);
 | ||
|   CallBack(*LoopHeader);
 | ||
| 
 | ||
|   // Initialize with immediate successors
 | ||
|   for (const auto *BB : successors(LoopHeader)) {
 | ||
|     if (!Loop.contains(BB))
 | ||
|       continue;
 | ||
|     if (BB == LoopHeader)
 | ||
|       continue;
 | ||
|     Stack.push_back(BB);
 | ||
|   }
 | ||
| 
 | ||
|   // Compute PO inside region
 | ||
|   computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
 | ||
| }
 | ||
| 
 | ||
| } // namespace
 | ||
| 
 | ||
| namespace llvm {
 | ||
| 
 | ||
| ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
 | ||
| 
 | ||
| SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
 | ||
|                                                const PostDominatorTree &PDT,
 | ||
|                                                const LoopInfo &LI)
 | ||
|     : DT(DT), PDT(PDT), LI(LI) {
 | ||
|   computeTopLevelPO(*DT.getRoot()->getParent(), LI,
 | ||
|                     [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
 | ||
| }
 | ||
| 
 | ||
| SyncDependenceAnalysis::~SyncDependenceAnalysis() = default;
 | ||
| 
 | ||
| // divergence propagator for reducible CFGs
 | ||
| struct DivergencePropagator {
 | ||
|   const ModifiedPO &LoopPOT;
 | ||
|   const DominatorTree &DT;
 | ||
|   const PostDominatorTree &PDT;
 | ||
|   const LoopInfo &LI;
 | ||
|   const BasicBlock &DivTermBlock;
 | ||
| 
 | ||
|   // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
 | ||
|   //   block B
 | ||
|   // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
 | ||
|   // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
 | ||
|   // from X or B is an immediate successor of X (initial value).
 | ||
|   using BlockLabelVec = std::vector<const BasicBlock *>;
 | ||
|   BlockLabelVec BlockLabels;
 | ||
|   // divergent join and loop exit descriptor.
 | ||
|   std::unique_ptr<ControlDivergenceDesc> DivDesc;
 | ||
| 
 | ||
|   DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
 | ||
|                        const PostDominatorTree &PDT, const LoopInfo &LI,
 | ||
|                        const BasicBlock &DivTermBlock)
 | ||
|       : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
 | ||
|         BlockLabels(LoopPOT.size(), nullptr),
 | ||
|         DivDesc(new ControlDivergenceDesc) {}
 | ||
| 
 | ||
|   void printDefs(raw_ostream &Out) {
 | ||
|     Out << "Propagator::BlockLabels {\n";
 | ||
|     for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
 | ||
|       const auto *Label = BlockLabels[BlockIdx];
 | ||
|       Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
 | ||
|           << ") : ";
 | ||
|       if (!Label) {
 | ||
|         Out << "<null>\n";
 | ||
|       } else {
 | ||
|         Out << Label->getName() << "\n";
 | ||
|       }
 | ||
|     }
 | ||
|     Out << "}\n";
 | ||
|   }
 | ||
| 
 | ||
|   // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
 | ||
|   // causes a divergent join.
 | ||
|   bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
 | ||
|     auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
 | ||
| 
 | ||
|     // unset or same reaching label
 | ||
|     const auto *OldLabel = BlockLabels[SuccIdx];
 | ||
|     if (!OldLabel || (OldLabel == &PushedLabel)) {
 | ||
|       BlockLabels[SuccIdx] = &PushedLabel;
 | ||
|       return false;
 | ||
|     }
 | ||
| 
 | ||
|     // Update the definition
 | ||
|     BlockLabels[SuccIdx] = &SuccBlock;
 | ||
|     return true;
 | ||
|   }
 | ||
| 
 | ||
|   // visiting a virtual loop exit edge from the loop header --> temporal
 | ||
|   // divergence on join
 | ||
|   bool visitLoopExitEdge(const BasicBlock &ExitBlock,
 | ||
|                          const BasicBlock &DefBlock, bool FromParentLoop) {
 | ||
|     // Pushing from a non-parent loop cannot cause temporal divergence.
 | ||
|     if (!FromParentLoop)
 | ||
|       return visitEdge(ExitBlock, DefBlock);
 | ||
| 
 | ||
|     if (!computeJoin(ExitBlock, DefBlock))
 | ||
|       return false;
 | ||
| 
 | ||
|     // Identified a divergent loop exit
 | ||
|     DivDesc->LoopDivBlocks.insert(&ExitBlock);
 | ||
|     LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
 | ||
|                       << "\n");
 | ||
|     return true;
 | ||
|   }
 | ||
| 
 | ||
|   // process \p SuccBlock with reaching definition \p DefBlock
 | ||
|   bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
 | ||
|     if (!computeJoin(SuccBlock, DefBlock))
 | ||
|       return false;
 | ||
| 
 | ||
|     // Divergent, disjoint paths join.
 | ||
|     DivDesc->JoinDivBlocks.insert(&SuccBlock);
 | ||
|     LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
 | ||
|     return true;
 | ||
|   }
 | ||
| 
 | ||
|   std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
 | ||
|     assert(DivDesc);
 | ||
| 
 | ||
|     LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
 | ||
|                       << "\n");
 | ||
| 
 | ||
|     const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
 | ||
| 
 | ||
|     // Early stopping criterion
 | ||
|     int FloorIdx = LoopPOT.size() - 1;
 | ||
|     const BasicBlock *FloorLabel = nullptr;
 | ||
| 
 | ||
|     // bootstrap with branch targets
 | ||
|     int BlockIdx = 0;
 | ||
| 
 | ||
|     for (const auto *SuccBlock : successors(&DivTermBlock)) {
 | ||
|       auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
 | ||
|       BlockLabels[SuccIdx] = SuccBlock;
 | ||
| 
 | ||
|       // Find the successor with the highest index to start with
 | ||
|       BlockIdx = std::max<int>(BlockIdx, SuccIdx);
 | ||
|       FloorIdx = std::min<int>(FloorIdx, SuccIdx);
 | ||
| 
 | ||
|       // Identify immediate divergent loop exits
 | ||
|       if (!DivBlockLoop)
 | ||
|         continue;
 | ||
| 
 | ||
|       const auto *BlockLoop = LI.getLoopFor(SuccBlock);
 | ||
|       if (BlockLoop && DivBlockLoop->contains(BlockLoop))
 | ||
|         continue;
 | ||
|       DivDesc->LoopDivBlocks.insert(SuccBlock);
 | ||
|       LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
 | ||
|                         << SuccBlock->getName() << "\n");
 | ||
|     }
 | ||
| 
 | ||
|     // propagate definitions at the immediate successors of the node in RPO
 | ||
|     for (; BlockIdx >= FloorIdx; --BlockIdx) {
 | ||
|       LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
 | ||
| 
 | ||
|       // Any label available here
 | ||
|       const auto *Label = BlockLabels[BlockIdx];
 | ||
|       if (!Label)
 | ||
|         continue;
 | ||
| 
 | ||
|       // Ok. Get the block
 | ||
|       const auto *Block = LoopPOT.getBlockAt(BlockIdx);
 | ||
|       LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
 | ||
| 
 | ||
|       auto *BlockLoop = LI.getLoopFor(Block);
 | ||
|       bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
 | ||
|       bool CausedJoin = false;
 | ||
|       int LoweredFloorIdx = FloorIdx;
 | ||
|       if (IsLoopHeader) {
 | ||
|         // Disconnect from immediate successors and propagate directly to loop
 | ||
|         // exits.
 | ||
|         SmallVector<BasicBlock *, 4> BlockLoopExits;
 | ||
|         BlockLoop->getExitBlocks(BlockLoopExits);
 | ||
| 
 | ||
|         bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
 | ||
|         for (const auto *BlockLoopExit : BlockLoopExits) {
 | ||
|           CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
 | ||
|           LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
 | ||
|                                           LoopPOT.getIndexOf(*BlockLoopExit));
 | ||
|         }
 | ||
|       } else {
 | ||
|         // Acyclic successor case
 | ||
|         for (const auto *SuccBlock : successors(Block)) {
 | ||
|           CausedJoin |= visitEdge(*SuccBlock, *Label);
 | ||
|           LoweredFloorIdx =
 | ||
|               std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
 | ||
|         }
 | ||
|       }
 | ||
| 
 | ||
|       // Floor update
 | ||
|       if (CausedJoin) {
 | ||
|         // 1. Different labels pushed to successors
 | ||
|         FloorIdx = LoweredFloorIdx;
 | ||
|       } else if (FloorLabel != Label) {
 | ||
|         // 2. No join caused BUT we pushed a label that is different than the
 | ||
|         // last pushed label
 | ||
|         FloorIdx = LoweredFloorIdx;
 | ||
|         FloorLabel = Label;
 | ||
|       }
 | ||
|     }
 | ||
| 
 | ||
|     LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
 | ||
| 
 | ||
|     return std::move(DivDesc);
 | ||
|   }
 | ||
| };
 | ||
| 
 | ||
| #ifndef NDEBUG
 | ||
| static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
 | ||
|   Out << "[";
 | ||
|   ListSeparator LS;
 | ||
|   for (const auto *BB : Blocks)
 | ||
|     Out << LS << BB->getName();
 | ||
|   Out << "]";
 | ||
| }
 | ||
| #endif
 | ||
| 
 | ||
| const ControlDivergenceDesc &
 | ||
| SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
 | ||
|   // trivial case
 | ||
|   if (Term.getNumSuccessors() <= 1) {
 | ||
|     return EmptyDivergenceDesc;
 | ||
|   }
 | ||
| 
 | ||
|   // already available in cache?
 | ||
|   auto ItCached = CachedControlDivDescs.find(&Term);
 | ||
|   if (ItCached != CachedControlDivDescs.end())
 | ||
|     return *ItCached->second;
 | ||
| 
 | ||
|   // compute all join points
 | ||
|   // Special handling of divergent loop exits is not needed for LCSSA
 | ||
|   const auto &TermBlock = *Term.getParent();
 | ||
|   DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
 | ||
|   auto DivDesc = Propagator.computeJoinPoints();
 | ||
| 
 | ||
|   LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
 | ||
|              dbgs() << "JoinDivBlocks: ";
 | ||
|              printBlockSet(DivDesc->JoinDivBlocks, dbgs());
 | ||
|              dbgs() << "\nLoopDivBlocks: ";
 | ||
|              printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
 | ||
| 
 | ||
|   auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
 | ||
|   assert(ItInserted.second);
 | ||
|   return *ItInserted.first->second;
 | ||
| }
 | ||
| 
 | ||
| } // namespace llvm
 |