[Matrix] Refactor tiled loops in a struct. NFC
The three loops have the same structure: index, header, latch.
This commit is contained in:
		
							parent
							
								
									39d431d811
								
							
						
					
					
						commit
						2c6e8b4636
					
				| 
						 | 
					@ -25,9 +25,9 @@ class IRBuilderBase;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// A helper struct to create IR loop nests for tiling in IR of the following
 | 
					/// A helper struct to create IR loop nests for tiling in IR of the following
 | 
				
			||||||
/// form:
 | 
					/// form:
 | 
				
			||||||
///   for CurrentColumn = 0..NumColumns
 | 
					///   for ColumnLoop.Index = 0..NumColumns
 | 
				
			||||||
///     for CurrentRow = 0..NumRows
 | 
					///     for RowLoop.Index = 0..NumRows
 | 
				
			||||||
///       for CurrentInner = 0..NumInner
 | 
					///       for KLoop.Index = 0..NumInner
 | 
				
			||||||
struct TileInfo {
 | 
					struct TileInfo {
 | 
				
			||||||
  /// Number of rows of the matrix.
 | 
					  /// Number of rows of the matrix.
 | 
				
			||||||
  unsigned NumRows;
 | 
					  unsigned NumRows;
 | 
				
			||||||
| 
						 | 
					@ -42,26 +42,21 @@ struct TileInfo {
 | 
				
			||||||
  /// Number of rows/columns in a tile.
 | 
					  /// Number of rows/columns in a tile.
 | 
				
			||||||
  unsigned TileSize = -1;
 | 
					  unsigned TileSize = -1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Start row of the current tile to compute.
 | 
					  /// Properties of a single loop used when generating the tiled loop nest.
 | 
				
			||||||
  Value *CurrentRow;
 | 
					  struct MatrixLoop {
 | 
				
			||||||
 | 
					    /// The index updated on every iteration.
 | 
				
			||||||
 | 
					    Value *Index = nullptr;
 | 
				
			||||||
 | 
					    /// The header and latch of the loop.
 | 
				
			||||||
 | 
					    BasicBlock *Header = nullptr;
 | 
				
			||||||
 | 
					    BasicBlock *Latch = nullptr;
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Start column of the current tile to compute.
 | 
					  /// The loop iterating on the rows.
 | 
				
			||||||
  Value *CurrentCol;
 | 
					  MatrixLoop RowLoop;
 | 
				
			||||||
 | 
					  /// The loop iterating on the columns.
 | 
				
			||||||
  /// Current tile offset during the tile computation.
 | 
					  MatrixLoop ColumnLoop;
 | 
				
			||||||
  Value *CurrentK;
 | 
					  /// The loop iterating on k (inner dimension).
 | 
				
			||||||
 | 
					  MatrixLoop KLoop;
 | 
				
			||||||
  /// Header of the outermost loop iterating from 0..NumColumns.
 | 
					 | 
				
			||||||
  BasicBlock *ColumnLoopHeader = nullptr;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /// Header of the second loop iterating from 0..NumRows.
 | 
					 | 
				
			||||||
  BasicBlock *RowLoopHeader = nullptr;
 | 
					 | 
				
			||||||
  /// Latch of the second loop iterating from 0..NumRows.
 | 
					 | 
				
			||||||
  BasicBlock *RowLoopLatch = nullptr;
 | 
					 | 
				
			||||||
  /// Header of the innermost loop iterating from 0..NumInner.
 | 
					 | 
				
			||||||
  BasicBlock *InnerLoopHeader = nullptr;
 | 
					 | 
				
			||||||
  /// Latch of the innermost loop iterating from 0..NumInner.
 | 
					 | 
				
			||||||
  BasicBlock *InnerLoopLatch = nullptr;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
 | 
					  TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
 | 
				
			||||||
           unsigned TileSize)
 | 
					           unsigned TileSize)
 | 
				
			||||||
| 
						 | 
					@ -72,9 +67,9 @@ struct TileInfo {
 | 
				
			||||||
  /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
 | 
					  /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
 | 
				
			||||||
  /// fields.
 | 
					  /// fields.
 | 
				
			||||||
  ///
 | 
					  ///
 | 
				
			||||||
  /// for CurrentColumn = 0..NumColumns
 | 
					  /// for ColumnLoop.Index = 0..NumColumns
 | 
				
			||||||
  ///   for CurrentRow = 0..NumRows
 | 
					  ///   for RowLoop.Index = 0..NumRows
 | 
				
			||||||
  ///     for CurrentInner = 0..NumInner
 | 
					  ///     for InnerLoop.Index = 0..NumInner
 | 
				
			||||||
  BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
 | 
					  BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
 | 
				
			||||||
                               IRBuilderBase &B, DomTreeUpdater &DTU,
 | 
					                               IRBuilderBase &B, DomTreeUpdater &DTU,
 | 
				
			||||||
                               LoopInfo &LI);
 | 
					                               LoopInfo &LI);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1423,13 +1423,13 @@ public:
 | 
				
			||||||
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
 | 
					        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
 | 
				
			||||||
    MatrixTy TileResult;
 | 
					    MatrixTy TileResult;
 | 
				
			||||||
    // Insert in the inner loop header.
 | 
					    // Insert in the inner loop header.
 | 
				
			||||||
    Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
 | 
					    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
 | 
				
			||||||
    // Create PHI nodes for the result columns to accumulate across iterations.
 | 
					    // Create PHI nodes for the result columns to accumulate across iterations.
 | 
				
			||||||
    SmallVector<PHINode *, 4> ColumnPhis;
 | 
					    SmallVector<PHINode *, 4> ColumnPhis;
 | 
				
			||||||
    for (unsigned I = 0; I < TileSize; I++) {
 | 
					    for (unsigned I = 0; I < TileSize; I++) {
 | 
				
			||||||
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
 | 
					      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
 | 
				
			||||||
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
 | 
					      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
 | 
				
			||||||
                       TI.RowLoopHeader->getSingleSuccessor());
 | 
					                       TI.RowLoop.Header->getSingleSuccessor());
 | 
				
			||||||
      TileResult.addVector(Phi);
 | 
					      TileResult.addVector(Phi);
 | 
				
			||||||
      ColumnPhis.push_back(Phi);
 | 
					      ColumnPhis.push_back(Phi);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -1438,27 +1438,29 @@ public:
 | 
				
			||||||
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
 | 
					    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
 | 
				
			||||||
    Builder.SetInsertPoint(InnerBody->getTerminator());
 | 
					    Builder.SetInsertPoint(InnerBody->getTerminator());
 | 
				
			||||||
    // Load tiles of the operands.
 | 
					    // Load tiles of the operands.
 | 
				
			||||||
    MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
 | 
					    MatrixTy A =
 | 
				
			||||||
                            {TileSize, TileSize}, EltType, Builder);
 | 
					        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
 | 
				
			||||||
    MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
 | 
					                   {TileSize, TileSize}, EltType, Builder);
 | 
				
			||||||
                            {TileSize, TileSize}, EltType, Builder);
 | 
					    MatrixTy B =
 | 
				
			||||||
 | 
					        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
 | 
				
			||||||
 | 
					                   {TileSize, TileSize}, EltType, Builder);
 | 
				
			||||||
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
 | 
					    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
 | 
				
			||||||
                       getFastMathFlags(MatMul));
 | 
					                       getFastMathFlags(MatMul));
 | 
				
			||||||
    // Store result after the inner loop is done.
 | 
					    // Store result after the inner loop is done.
 | 
				
			||||||
    Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
 | 
					    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
 | 
				
			||||||
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
 | 
					    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
 | 
				
			||||||
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
 | 
					                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
 | 
				
			||||||
                TI.CurrentRow, TI.CurrentCol, EltType, Builder);
 | 
					                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
 | 
					    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
 | 
				
			||||||
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
 | 
					      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Force unrolling of a few iterations of the inner loop, to make sure there
 | 
					    // Force unrolling of a few iterations of the inner loop, to make sure there
 | 
				
			||||||
    // is enough work per iteration.
 | 
					    // is enough work per iteration.
 | 
				
			||||||
    // FIXME: The unroller should make this decision directly instead, but
 | 
					    // FIXME: The unroller should make this decision directly instead, but
 | 
				
			||||||
    // currently the cost-model is not up to the task.
 | 
					    // currently the cost-model is not up to the task.
 | 
				
			||||||
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
 | 
					    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
 | 
				
			||||||
    addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
 | 
					    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
 | 
				
			||||||
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
 | 
					                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -70,35 +70,35 @@ BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
 | 
				
			||||||
BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
 | 
					BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
 | 
				
			||||||
                                       IRBuilderBase &B, DomTreeUpdater &DTU,
 | 
					                                       IRBuilderBase &B, DomTreeUpdater &DTU,
 | 
				
			||||||
                                       LoopInfo &LI) {
 | 
					                                       LoopInfo &LI) {
 | 
				
			||||||
  Loop *ColLoop = LI.AllocateLoop();
 | 
					  Loop *ColumnLoopInfo = LI.AllocateLoop();
 | 
				
			||||||
  Loop *RowLoop = LI.AllocateLoop();
 | 
					  Loop *RowLoopInfo = LI.AllocateLoop();
 | 
				
			||||||
  Loop *InnerLoop = LI.AllocateLoop();
 | 
					  Loop *KLoopInfo = LI.AllocateLoop();
 | 
				
			||||||
  RowLoop->addChildLoop(InnerLoop);
 | 
					  RowLoopInfo->addChildLoop(KLoopInfo);
 | 
				
			||||||
  ColLoop->addChildLoop(RowLoop);
 | 
					  ColumnLoopInfo->addChildLoop(RowLoopInfo);
 | 
				
			||||||
  if (Loop *ParentL = LI.getLoopFor(Start))
 | 
					  if (Loop *ParentL = LI.getLoopFor(Start))
 | 
				
			||||||
    ParentL->addChildLoop(ColLoop);
 | 
					    ParentL->addChildLoop(ColumnLoopInfo);
 | 
				
			||||||
  else
 | 
					  else
 | 
				
			||||||
    LI.addTopLevelLoop(ColLoop);
 | 
					    LI.addTopLevelLoop(ColumnLoopInfo);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  BasicBlock *ColBody =
 | 
					  BasicBlock *ColBody =
 | 
				
			||||||
      CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
 | 
					      CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
 | 
				
			||||||
                 "cols", B, DTU, ColLoop, LI);
 | 
					                 "cols", B, DTU, ColumnLoopInfo, LI);
 | 
				
			||||||
  BasicBlock *ColLatch = ColBody->getSingleSuccessor();
 | 
					  ColumnLoop.Latch = ColBody->getSingleSuccessor();
 | 
				
			||||||
  BasicBlock *RowBody =
 | 
					  BasicBlock *RowBody =
 | 
				
			||||||
      CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize),
 | 
					      CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
 | 
				
			||||||
                 "rows", B, DTU, RowLoop, LI);
 | 
					                 B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
 | 
				
			||||||
  RowLoopLatch = RowBody->getSingleSuccessor();
 | 
					  RowLoop.Latch = RowBody->getSingleSuccessor();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  BasicBlock *InnerBody =
 | 
					  BasicBlock *InnerBody =
 | 
				
			||||||
      CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner),
 | 
					      CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
 | 
				
			||||||
                 B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI);
 | 
					                 B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
 | 
				
			||||||
  InnerLoopLatch = InnerBody->getSingleSuccessor();
 | 
					  KLoop.Latch = InnerBody->getSingleSuccessor();
 | 
				
			||||||
  ColumnLoopHeader = ColBody->getSinglePredecessor();
 | 
					  ColumnLoop.Header = ColBody->getSinglePredecessor();
 | 
				
			||||||
  RowLoopHeader = RowBody->getSinglePredecessor();
 | 
					  RowLoop.Header = RowBody->getSinglePredecessor();
 | 
				
			||||||
  InnerLoopHeader = InnerBody->getSinglePredecessor();
 | 
					  KLoop.Header = InnerBody->getSinglePredecessor();
 | 
				
			||||||
  CurrentRow = &*RowLoopHeader->begin();
 | 
					  RowLoop.Index = &*RowLoop.Header->begin();
 | 
				
			||||||
  CurrentCol = &*ColumnLoopHeader->begin();
 | 
					  ColumnLoop.Index = &*ColumnLoop.Header->begin();
 | 
				
			||||||
  CurrentK = &*InnerLoopHeader->begin();
 | 
					  KLoop.Index = &*KLoop.Header->begin();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return InnerBody;
 | 
					  return InnerBody;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue