[Matrix] Refactor tiled loops in a struct. NFC
The three loops have the same structure: index, header, latch.
This commit is contained in:
parent
39d431d811
commit
2c6e8b4636
|
|
@ -25,9 +25,9 @@ class IRBuilderBase;
|
|||
|
||||
/// A helper struct to create IR loop nests for tiling in IR of the following
|
||||
/// form:
|
||||
/// for CurrentColumn = 0..NumColumns
|
||||
/// for CurrentRow = 0..NumRows
|
||||
/// for CurrentInner = 0..NumInner
|
||||
/// for ColumnLoop.Index = 0..NumColumns
|
||||
/// for RowLoop.Index = 0..NumRows
|
||||
/// for KLoop.Index = 0..NumInner
|
||||
struct TileInfo {
|
||||
/// Number of rows of the matrix.
|
||||
unsigned NumRows;
|
||||
|
|
@ -42,26 +42,21 @@ struct TileInfo {
|
|||
/// Number of rows/columns in a tile.
|
||||
unsigned TileSize = -1;
|
||||
|
||||
/// Start row of the current tile to compute.
|
||||
Value *CurrentRow;
|
||||
/// Properties of a single loop used when generating the tiled loop nest.
|
||||
struct MatrixLoop {
|
||||
/// The index updated on every iteration.
|
||||
Value *Index = nullptr;
|
||||
/// The header and latch of the loop.
|
||||
BasicBlock *Header = nullptr;
|
||||
BasicBlock *Latch = nullptr;
|
||||
};
|
||||
|
||||
/// Start column of the current tile to compute.
|
||||
Value *CurrentCol;
|
||||
|
||||
/// Current tile offset during the tile computation.
|
||||
Value *CurrentK;
|
||||
|
||||
/// Header of the outermost loop iterating from 0..NumColumns.
|
||||
BasicBlock *ColumnLoopHeader = nullptr;
|
||||
|
||||
/// Header of the second loop iterating from 0..NumRows.
|
||||
BasicBlock *RowLoopHeader = nullptr;
|
||||
/// Latch of the second loop iterating from 0..NumRows.
|
||||
BasicBlock *RowLoopLatch = nullptr;
|
||||
/// Header of the innermost loop iterating from 0..NumInner.
|
||||
BasicBlock *InnerLoopHeader = nullptr;
|
||||
/// Latch of the innermost loop iterating from 0..NumInner.
|
||||
BasicBlock *InnerLoopLatch = nullptr;
|
||||
/// The loop iterating on the rows.
|
||||
MatrixLoop RowLoop;
|
||||
/// The loop iterating on the columns.
|
||||
MatrixLoop ColumnLoop;
|
||||
/// The loop iterating on k (inner dimension).
|
||||
MatrixLoop KLoop;
|
||||
|
||||
TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
|
||||
unsigned TileSize)
|
||||
|
|
@ -72,9 +67,9 @@ struct TileInfo {
|
|||
/// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
|
||||
/// fields.
|
||||
///
|
||||
/// for CurrentColumn = 0..NumColumns
|
||||
/// for CurrentRow = 0..NumRows
|
||||
/// for CurrentInner = 0..NumInner
|
||||
/// for ColumnLoop.Index = 0..NumColumns
|
||||
/// for RowLoop.Index = 0..NumRows
|
||||
/// for InnerLoop.Index = 0..NumInner
|
||||
BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
|
||||
IRBuilderBase &B, DomTreeUpdater &DTU,
|
||||
LoopInfo &LI);
|
||||
|
|
|
|||
|
|
@ -1423,13 +1423,13 @@ public:
|
|||
FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
|
||||
MatrixTy TileResult;
|
||||
// Insert in the inner loop header.
|
||||
Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
|
||||
Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
|
||||
// Create PHI nodes for the result columns to accumulate across iterations.
|
||||
SmallVector<PHINode *, 4> ColumnPhis;
|
||||
for (unsigned I = 0; I < TileSize; I++) {
|
||||
auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
|
||||
Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
|
||||
TI.RowLoopHeader->getSingleSuccessor());
|
||||
TI.RowLoop.Header->getSingleSuccessor());
|
||||
TileResult.addVector(Phi);
|
||||
ColumnPhis.push_back(Phi);
|
||||
}
|
||||
|
|
@ -1438,27 +1438,29 @@ public:
|
|||
// Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
|
||||
Builder.SetInsertPoint(InnerBody->getTerminator());
|
||||
// Load tiles of the operands.
|
||||
MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
|
||||
{TileSize, TileSize}, EltType, Builder);
|
||||
MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
|
||||
{TileSize, TileSize}, EltType, Builder);
|
||||
MatrixTy A =
|
||||
loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
|
||||
{TileSize, TileSize}, EltType, Builder);
|
||||
MatrixTy B =
|
||||
loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
|
||||
{TileSize, TileSize}, EltType, Builder);
|
||||
emitMatrixMultiply(TileResult, A, B, Builder, true, false,
|
||||
getFastMathFlags(MatMul));
|
||||
// Store result after the inner loop is done.
|
||||
Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
|
||||
Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
|
||||
storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
|
||||
Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
|
||||
TI.CurrentRow, TI.CurrentCol, EltType, Builder);
|
||||
TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
|
||||
|
||||
for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
|
||||
ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);
|
||||
ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
|
||||
|
||||
// Force unrolling of a few iterations of the inner loop, to make sure there
|
||||
// is enough work per iteration.
|
||||
// FIXME: The unroller should make this decision directly instead, but
|
||||
// currently the cost-model is not up to the task.
|
||||
unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
|
||||
addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
|
||||
addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
|
||||
"llvm.loop.unroll.count", InnerLoopUnrollCount);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -70,35 +70,35 @@ BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
|
|||
BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
|
||||
IRBuilderBase &B, DomTreeUpdater &DTU,
|
||||
LoopInfo &LI) {
|
||||
Loop *ColLoop = LI.AllocateLoop();
|
||||
Loop *RowLoop = LI.AllocateLoop();
|
||||
Loop *InnerLoop = LI.AllocateLoop();
|
||||
RowLoop->addChildLoop(InnerLoop);
|
||||
ColLoop->addChildLoop(RowLoop);
|
||||
Loop *ColumnLoopInfo = LI.AllocateLoop();
|
||||
Loop *RowLoopInfo = LI.AllocateLoop();
|
||||
Loop *KLoopInfo = LI.AllocateLoop();
|
||||
RowLoopInfo->addChildLoop(KLoopInfo);
|
||||
ColumnLoopInfo->addChildLoop(RowLoopInfo);
|
||||
if (Loop *ParentL = LI.getLoopFor(Start))
|
||||
ParentL->addChildLoop(ColLoop);
|
||||
ParentL->addChildLoop(ColumnLoopInfo);
|
||||
else
|
||||
LI.addTopLevelLoop(ColLoop);
|
||||
LI.addTopLevelLoop(ColumnLoopInfo);
|
||||
|
||||
BasicBlock *ColBody =
|
||||
CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
|
||||
"cols", B, DTU, ColLoop, LI);
|
||||
BasicBlock *ColLatch = ColBody->getSingleSuccessor();
|
||||
"cols", B, DTU, ColumnLoopInfo, LI);
|
||||
ColumnLoop.Latch = ColBody->getSingleSuccessor();
|
||||
BasicBlock *RowBody =
|
||||
CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize),
|
||||
"rows", B, DTU, RowLoop, LI);
|
||||
RowLoopLatch = RowBody->getSingleSuccessor();
|
||||
CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
|
||||
B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
|
||||
RowLoop.Latch = RowBody->getSingleSuccessor();
|
||||
|
||||
BasicBlock *InnerBody =
|
||||
CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner),
|
||||
B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI);
|
||||
InnerLoopLatch = InnerBody->getSingleSuccessor();
|
||||
ColumnLoopHeader = ColBody->getSinglePredecessor();
|
||||
RowLoopHeader = RowBody->getSinglePredecessor();
|
||||
InnerLoopHeader = InnerBody->getSinglePredecessor();
|
||||
CurrentRow = &*RowLoopHeader->begin();
|
||||
CurrentCol = &*ColumnLoopHeader->begin();
|
||||
CurrentK = &*InnerLoopHeader->begin();
|
||||
CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
|
||||
B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
|
||||
KLoop.Latch = InnerBody->getSingleSuccessor();
|
||||
ColumnLoop.Header = ColBody->getSinglePredecessor();
|
||||
RowLoop.Header = RowBody->getSinglePredecessor();
|
||||
KLoop.Header = InnerBody->getSinglePredecessor();
|
||||
RowLoop.Index = &*RowLoop.Header->begin();
|
||||
ColumnLoop.Index = &*ColumnLoop.Header->begin();
|
||||
KLoop.Index = &*KLoop.Header->begin();
|
||||
|
||||
return InnerBody;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue