[Matrix] Refactor tiled loops into a struct. NFC

The three loops have the same structure: index, header, latch.
This commit is contained in:
Francis Visoiu Mistrih 2022-07-20 11:12:30 +02:00
parent 39d431d811
commit 2c6e8b4636
3 changed files with 53 additions and 56 deletions

View File

@ -25,9 +25,9 @@ class IRBuilderBase;
/// A helper struct to create IR loop nests for tiling in IR of the following /// A helper struct to create IR loop nests for tiling in IR of the following
/// form: /// form:
/// for CurrentColumn = 0..NumColumns /// for ColumnLoop.Index = 0..NumColumns
/// for CurrentRow = 0..NumRows /// for RowLoop.Index = 0..NumRows
/// for CurrentInner = 0..NumInner /// for KLoop.Index = 0..NumInner
struct TileInfo { struct TileInfo {
/// Number of rows of the matrix. /// Number of rows of the matrix.
unsigned NumRows; unsigned NumRows;
@ -42,26 +42,21 @@ struct TileInfo {
/// Number of rows/columns in a tile. /// Number of rows/columns in a tile.
unsigned TileSize = -1; unsigned TileSize = -1;
/// Start row of the current tile to compute. /// Properties of a single loop used when generating the tiled loop nest.
Value *CurrentRow; struct MatrixLoop {
/// The index updated on every iteration.
Value *Index = nullptr;
/// The header and latch of the loop.
BasicBlock *Header = nullptr;
BasicBlock *Latch = nullptr;
};
/// Start column of the current tile to compute. /// The loop iterating on the rows.
Value *CurrentCol; MatrixLoop RowLoop;
/// The loop iterating on the columns.
/// Current tile offset during the tile computation. MatrixLoop ColumnLoop;
Value *CurrentK; /// The loop iterating on k (inner dimension).
MatrixLoop KLoop;
/// Header of the outermost loop iterating from 0..NumColumns.
BasicBlock *ColumnLoopHeader = nullptr;
/// Header of the second loop iterating from 0..NumRows.
BasicBlock *RowLoopHeader = nullptr;
/// Latch of the second loop iterating from 0..NumRows.
BasicBlock *RowLoopLatch = nullptr;
/// Header of the innermost loop iterating from 0..NumInner.
BasicBlock *InnerLoopHeader = nullptr;
/// Latch of the innermost loop iterating from 0..NumInner.
BasicBlock *InnerLoopLatch = nullptr;
TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner, TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
unsigned TileSize) unsigned TileSize)
@ -72,9 +67,9 @@ struct TileInfo {
/// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch /// for the inner loop body and sets the Index, Header and Latch fields of
/// fields. /// ColumnLoop, RowLoop and KLoop.
/// ///
/// for CurrentColumn = 0..NumColumns /// for ColumnLoop.Index = 0..NumColumns
/// for CurrentRow = 0..NumRows /// for RowLoop.Index = 0..NumRows
/// for CurrentInner = 0..NumInner /// for KLoop.Index = 0..NumInner
BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End, BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
IRBuilderBase &B, DomTreeUpdater &DTU, IRBuilderBase &B, DomTreeUpdater &DTU,
LoopInfo &LI); LoopInfo &LI);

View File

@ -1423,13 +1423,13 @@ public:
FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
MatrixTy TileResult; MatrixTy TileResult;
// Insert in the inner loop header. // Insert in the inner loop header.
Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator()); Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
// Create PHI nodes for the result columns to accumulate across iterations. // Create PHI nodes for the result columns to accumulate across iterations.
SmallVector<PHINode *, 4> ColumnPhis; SmallVector<PHINode *, 4> ColumnPhis;
for (unsigned I = 0; I < TileSize; I++) { for (unsigned I = 0; I < TileSize; I++) {
auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I)); auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
TI.RowLoopHeader->getSingleSuccessor()); TI.RowLoop.Header->getSingleSuccessor());
TileResult.addVector(Phi); TileResult.addVector(Phi);
ColumnPhis.push_back(Phi); ColumnPhis.push_back(Phi);
} }
@ -1438,27 +1438,29 @@ public:
// Res += Load(CurrentRow, K) * Load(K, CurrentColumn) // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
Builder.SetInsertPoint(InnerBody->getTerminator()); Builder.SetInsertPoint(InnerBody->getTerminator());
// Load tiles of the operands. // Load tiles of the operands.
MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK, MatrixTy A =
{TileSize, TileSize}, EltType, Builder); loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol, {TileSize, TileSize}, EltType, Builder);
{TileSize, TileSize}, EltType, Builder); MatrixTy B =
loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
{TileSize, TileSize}, EltType, Builder);
emitMatrixMultiply(TileResult, A, B, Builder, true, false, emitMatrixMultiply(TileResult, A, B, Builder, true, false,
getFastMathFlags(MatMul)); getFastMathFlags(MatMul));
// Store result after the inner loop is done. // Store result after the inner loop is done.
Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator()); Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
TI.CurrentRow, TI.CurrentCol, EltType, Builder); TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
for (unsigned I = 0; I < TileResult.getNumVectors(); I++) for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch); ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
// Force unrolling of a few iterations of the inner loop, to make sure there // Force unrolling of a few iterations of the inner loop, to make sure there
// is enough work per iteration. // is enough work per iteration.
// FIXME: The unroller should make this decision directly instead, but // FIXME: The unroller should make this decision directly instead, but
// currently the cost-model is not up to the task. // currently the cost-model is not up to the task.
unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader), addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
"llvm.loop.unroll.count", InnerLoopUnrollCount); "llvm.loop.unroll.count", InnerLoopUnrollCount);
} }

View File

@ -70,35 +70,35 @@ BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End, BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
IRBuilderBase &B, DomTreeUpdater &DTU, IRBuilderBase &B, DomTreeUpdater &DTU,
LoopInfo &LI) { LoopInfo &LI) {
Loop *ColLoop = LI.AllocateLoop(); Loop *ColumnLoopInfo = LI.AllocateLoop();
Loop *RowLoop = LI.AllocateLoop(); Loop *RowLoopInfo = LI.AllocateLoop();
Loop *InnerLoop = LI.AllocateLoop(); Loop *KLoopInfo = LI.AllocateLoop();
RowLoop->addChildLoop(InnerLoop); RowLoopInfo->addChildLoop(KLoopInfo);
ColLoop->addChildLoop(RowLoop); ColumnLoopInfo->addChildLoop(RowLoopInfo);
if (Loop *ParentL = LI.getLoopFor(Start)) if (Loop *ParentL = LI.getLoopFor(Start))
ParentL->addChildLoop(ColLoop); ParentL->addChildLoop(ColumnLoopInfo);
else else
LI.addTopLevelLoop(ColLoop); LI.addTopLevelLoop(ColumnLoopInfo);
BasicBlock *ColBody = BasicBlock *ColBody =
CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize), CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
"cols", B, DTU, ColLoop, LI); "cols", B, DTU, ColumnLoopInfo, LI);
BasicBlock *ColLatch = ColBody->getSingleSuccessor(); ColumnLoop.Latch = ColBody->getSingleSuccessor();
BasicBlock *RowBody = BasicBlock *RowBody =
CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize), CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
"rows", B, DTU, RowLoop, LI); B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
RowLoopLatch = RowBody->getSingleSuccessor(); RowLoop.Latch = RowBody->getSingleSuccessor();
BasicBlock *InnerBody = BasicBlock *InnerBody =
CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner), CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI); B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
InnerLoopLatch = InnerBody->getSingleSuccessor(); KLoop.Latch = InnerBody->getSingleSuccessor();
ColumnLoopHeader = ColBody->getSinglePredecessor(); ColumnLoop.Header = ColBody->getSinglePredecessor();
RowLoopHeader = RowBody->getSinglePredecessor(); RowLoop.Header = RowBody->getSinglePredecessor();
InnerLoopHeader = InnerBody->getSinglePredecessor(); KLoop.Header = InnerBody->getSinglePredecessor();
CurrentRow = &*RowLoopHeader->begin(); RowLoop.Index = &*RowLoop.Header->begin();
CurrentCol = &*ColumnLoopHeader->begin(); ColumnLoop.Index = &*ColumnLoop.Header->begin();
CurrentK = &*InnerLoopHeader->begin(); KLoop.Index = &*KLoop.Header->begin();
return InnerBody; return InnerBody;
} }