// Source path: llvm-project/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
// (file viewer metadata: 250 lines, 9.7 KiB, C++)

//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the tiling using TilingInterface.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "tile-using-interface"
using namespace mlir;
/// Set the tile sizes to a fixed list of integers.
///
/// Installs a tile size computation function that, when invoked, creates one
/// index-typed `arith.constant` per entry of `ts` and returns the resulting
/// values in order. The constants are created at the start of the entry block
/// of the enclosing `func.func`, so that they are defined outside any loop in
/// that function. If the operation is not nested inside a `func.func`, the
/// constants are created immediately before the operation instead (previously
/// this case dereferenced a null function op).
///
/// Asserts that no tile size computation function has been set already.
scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  // Copy the sizes: `ts` does not own its storage, and the stored callback
  // may be invoked after this call has returned.
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    if (auto funcOp = op->getParentOfType<func::FuncOp>())
      b.setInsertionPointToStart(&funcOp.getBody().front());
    else
      b.setInsertionPoint(op);
    return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the untiled iteration space. For each dimension,
///   `offset` is used as the lower bound and `size` as the upper bound of the
///   generated loop; the `stride` field is not read.
///   NOTE(review): using `size` (rather than `offset + size`) as the upper
///   bound is only equivalent when `offset` is zero — confirm this holds for
///   the iteration domains passed in.
/// - `tileSizeVals` holds one tile size per loop range. A tile size that is a
///   constant zero leaves that dimension untiled: no loop is generated and the
///   range's offset and size are forwarded unchanged.
/// - On return, `offsets` and `sizes` hold the multi-dimensional offset and
///   size of the tile processed within the innermost loop. For a tiled
///   dimension the offset is the loop induction variable `iv` and the size is
///   `min(tileSize, size - iv)`, so the last (partial) tile stays in bounds.
/// Returns the generated loops, outermost first. The first loop is created at
/// the builder's current insertion point and each following loop is nested
/// in the body of the previous one. The builder's insertion point is restored
/// on return.
static SmallVector<scf::ForOp>
generateTileLoopNest(OpBuilder &builder, Location loc,
                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
                     SmallVector<OpFoldResult> &offsets,
                     SmallVector<OpFoldResult> &sizes) {
  assert(!loopRanges.empty() && "expected at least one loop range");
  assert(loopRanges.size() == tileSizeVals.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp> loops;
  offsets.resize(loopRanges.size());
  sizes.resize(loopRanges.size());

  // The tile size to use (to avoid out of bounds access) is the minimum of
  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the
  // tiled loop. Map operands: d0 = iv, s0 = tileSize, s1 = ub.
  AffineExpr s0, s1, d0;
  bindDims(builder.getContext(), d0);
  bindSymbols(builder.getContext(), s0, s1);
  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());

  for (const auto &loopRange : llvm::enumerate(loopRanges)) {
    size_t dim = loopRange.index();
    const Range &range = loopRange.value();
    Value tileSize = tileSizeVals[dim];

    // No loop if the tile size is zero: forward the range's offset and size.
    if (matchPattern(tileSize, m_Zero())) {
      offsets[dim] = range.offset;
      sizes[dim] = range.size;
      continue;
    }

    auto loop = builder.create<scf::ForOp>(
        loc, range.offset, range.size, tileSize, ValueRange{},
        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
            ValueRange /*iterArgs*/) {
          // Build the body with the builder and location handed to the
          // callback (positioned inside the new loop body) rather than
          // relying on the outer builder being the same object.
          Value boundedTileSize = bodyBuilder.create<AffineMinOp>(
              bodyLoc, minMap, ValueRange{iv, tileSize, range.size});
          sizes[dim] = boundedTileSize;
          bodyBuilder.create<scf::YieldOp>(bodyLoc);
        });
    offsets[dim] = loop.getInductionVar();
    loops.push_back(loop);
    // Nest the next loop (if any) inside the body of this one.
    builder.setInsertionPoint(loop.getBody()->getTerminator());
  }
  return loops;
}
/// Create the tiling pattern. It matches every operation that implements
/// `TilingInterface` and takes ownership of the tiling `options`.
scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
                                          scf::SCFTilingOptions options,
                                          PatternBenefit benefit)
    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
      options(std::move(options)) {}

/// Variant of the constructor above that also receives an operation name.
/// It forwards everything else to that constructor, so the resulting pattern
/// is configured identically.
/// NOTE(review): `opName` is currently ignored — the pattern is not
/// restricted to the named operation and still matches any
/// `TilingInterface` op.
scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
                                          MLIRContext *context,
                                          scf::SCFTilingOptions options,
                                          PatternBenefit benefit)
    : TileUsingSCFForOp(context, std::move(options), benefit) {}
/// Tile `op` with a nest of `scf.for` loops.
///
/// Steps:
///  1. Query the iteration domain of `op`.
///  2. Materialize the tile sizes via `options.tileSizeComputationFunction`.
///     Missing trailing sizes are padded with zero (untiled) and sizes beyond
///     the number of loops are dropped.
///  3. Generate the (empty) tile loop nest right after `op`.
///  4. Generate the tiled implementation of `op` in the innermost loop.
///  5. If `op` has results, rewrite the loop nest so that it yields them via
///     `tensor.insert_slice` destructive updates, then replace `op`.
///
/// Returns the generated loops and tiled op on success. Failures are
/// reported through `rewriter.notifyMatchFailure`.
FailureOr<scf::SCFTilingResult>
scf::TileUsingSCFForOp::returningMatchAndRewrite(
    TilingInterface op, PatternRewriter &rewriter) const {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(op);

  if (!options.tileSizeComputationFunction) {
    return rewriter.notifyMatchFailure(
        op, "missing tile size computation function");
  }

  // 1. Get the range of the loops that are represented by the operation.
  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
  size_t numLoops = iterationDomain.size();
  if (numLoops == 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to tile op with no iteration domain");
  }

  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
  // skips tiling a particular dimension. This convention is significantly
  // simpler to handle instead of adjusting affine maps to account for missing
  // dimensions.
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(rewriter, op);
  if (tileSizeVector.size() < numLoops) {
    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
  } else if (tileSizeVector.size() > numLoops) {
    // Only the first `numLoops` tile sizes are meaningful; keeping extra ones
    // would violate the "one tile size per loop" precondition of
    // `generateTileLoopNest`.
    tileSizeVector.resize(numLoops);
  }

  // Materialize the destination operands exactly once, here (after `op`,
  // before any generated loop), and reuse them both for the tiled
  // implementation and as the `iter_args` init values in step 5b. Querying
  // them again later would happen while the rewriter is positioned inside the
  // innermost loop, where any newly created ops could not dominate the
  // outermost loop's init operands.
  auto destinationOperands = op.getDestinationOperands(rewriter);

  scf::SCFTilingResult tilingResult;
  SmallVector<OpFoldResult> offsets, sizes;

  // 3. Materialize an empty loop nest that iterates over the tiles. These
  // loops for now do not return any values even if the original operation has
  // results.
  tilingResult.loops = generateTileLoopNest(
      rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
  LLVM_DEBUG({
    if (!tilingResult.loops.empty()) {
      llvm::errs() << "LoopNest shell :\n";
      tilingResult.loops.front().dump();
      llvm::errs() << "\n";
    }
  });

  // 4. Generate the tiled implementation within the innermost loop (or right
  // after `op` when every dimension is untiled and no loop was created).
  if (!tilingResult.loops.empty())
    rewriter.setInsertionPoint(
        tilingResult.loops.back().getBody()->getTerminator());
  SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
      rewriter, destinationOperands, offsets, sizes, /*tileDestOperands=*/true);
  if (tiledImplementation.size() != 1) {
    // NOTE(review): the loop nest created in step 3 is left in the IR on this
    // failure path.
    return rewriter.notifyMatchFailure(
        op, "expected tiled implementation to return a single op");
  }
  tilingResult.tiledOp = tiledImplementation[0];
  LLVM_DEBUG({
    if (!tilingResult.loops.empty()) {
      llvm::errs() << "After tiled implementation :\n";
      tilingResult.loops.front().dump();
      llvm::errs() << "\n";
    }
  });

  if (op->getNumResults() == 0) {
    rewriter.eraseOp(op);
    return tilingResult;
  }

  // 5. If the original operation has results, modify the loop nest to yield
  // the replacement values.
  if (tilingResult.loops.empty()) {
    // 5a. If there were no loops, the tiled implementation results are the
    // replacements.
    rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
    return tilingResult;
  }

  // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
  // replacement values using destructive updates. Use the `TilingInterface`
  // to get the position of the result tiles and use that to generate the
  // destructive update pattern, i.e.,
  //
  // ```mlir
  // scf.for %iv0 = ... {
  //   %0 = tiled_op
  // }
  // ```
  //
  // is transformed to
  //
  // ```mlir
  // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
  //   %0 = tiled_op
  //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
  //   scf.yield %1
  // }
  // ```
  NewYieldValueFn yieldValueFn =
      [&](OpBuilder &b, Location loc,
          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
    SmallVector<Value> yieldedValues;
    Attribute one = b.getIndexAttr(1);
    for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
      SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
      if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
                                          resultTileOffsets,
                                          resultTileSizes))) {
        op.emitOpError("unable to get position of result ")
            << resultNum << " of the tiled implementation";
        return {};
      }
      SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
                                                  one);
      Value yieldedValue = b.create<tensor::InsertSliceOp>(
          op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
          newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
          resultTileStrides);
      yieldedValues.push_back(yieldedValue);
    }
    return yieldedValues;
  };
  SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
      rewriter, tilingResult.loops, destinationOperands, yieldValueFn);
  // Drop the original (result-less) loops and record the new ones.
  for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
    rewriter.eraseOp(loop.value());
    tilingResult.loops[loop.index()] = newLoops[loop.index()];
  }
  rewriter.replaceOp(op, tilingResult.loops.front().getResults());
  return tilingResult;
}