llvm-project/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

157 lines
6.3 KiB
C++

//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now till the use of `TilingInterface` is
/// on-par with the current Linalg tiling + fusion patterns. Once it is, it
/// may be possible to move this into the op-definition (though there are
/// advantages to leaving it as an external model).
///
/// `LinalgOpTy` is the concrete op class the model is attached to. Every
/// method receives the op as a generic `Operation *` and casts it back before
/// use.
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
  /// Return the destination operands, i.e. the output operands of the op.
  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
    return llvm::cast<LinalgOp>(op).getOutputOperands();
  }

  /// Return the loop iterator types: the string value of each entry of the
  /// op's `iterator_types` array, in order.
  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return llvm::to_vector(
        llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
          return strAttr.cast<StringAttr>().getValue();
        }));
  }

  /// Return the iteration domain: one `Range{zero, bound, one}` per value
  /// obtained by applying the op's shapes-to-loops map to the flat list of
  /// operand dimensions. The index constants (and whatever the helpers need)
  /// are created through `b` at the location of `op`.
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    return llvm::to_vector(llvm::map_range(
        applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
          return Range{zero, v, one};
        }));
  }

  /// Instantiate the tiled implementation of the operation: tile all input and
  /// output operands with the given `offsets` and `sizes`, then clone the op
  /// onto the tiled operands. The clone's result types are the types of the
  /// tiled output tensor operands. Returns a single-element vector holding the
  /// clone. `dest` and `tileDestOperands` are not consulted here.
  SmallVector<Operation *>
  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         bool tileDestOperands) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();

    // Materialize offsets before sizes in separate statements. Nesting both
    // calls in the argument list below would leave their relative evaluation
    // order unspecified, so the order of any index constants they create could
    // differ between host compilers.
    SmallVector<Value> offsetVals =
        getValueOrCreateConstantIndexOp(b, loc, offsets);
    SmallVector<Value> sizeVals = getValueOrCreateConstantIndexOp(b, loc, sizes);

    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
    // specified could lead to out of bounds accesses.
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile, offsetVals, sizeVals, {}, true);

    SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
        linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
          return tiledOperands[opOperand->getOperandNumber()].getType();
        }));

    Operation *tiledOp =
        linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
    return {tiledOp};
  }

  /// Return the details of the output tile generated by the tiled
  /// implementation for result `resultNumber`.
  ///
  /// The corresponding output operand is sliced with `makeTiledShape` using the
  /// given `offsets` and `sizes`. If the produced value is defined by a
  /// `tensor::ExtractSliceOp`, its mixed offsets and sizes are stored in
  /// `resultOffsets` / `resultSizes` and success is returned; otherwise failure
  /// is returned and the two output vectors are left untouched.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);

    AffineExpr d0;
    bindDims(b.getContext(), d0);

    // Build a single-result map from `expr`, compose and canonicalize it with
    // its operands, and emit (or fold) the resulting affine apply.
    auto composeAndApply = [](OpBuilder &builder, Location applyLoc,
                              AffineExpr expr, ValueRange operands) -> Value {
      AffineMap map = AffineMap::inferFromExprList({expr}).front();
      SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
      mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
      canonicalizeMapAndOperands(&map, &normalizedOperands);
      return builder.createOrFold<AffineApplyOp>(applyLoc, map,
                                                 normalizedOperands);
    };

    SmallVector<Value> sizeVals =
        getValueOrCreateConstantIndexOp(b, loc, sizes);
    // Each sub-shape size is the corresponding tile size minus one (`d0 - 1`).
    SmallVector<Value> subShapeSizes =
        llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
          return composeAndApply(b, loc, d0 - 1, v);
        }));

    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
    SmallVector<Value> offsetVals =
        getValueOrCreateConstantIndexOp(b, loc, offsets);
    Value sliceOpResult = makeTiledShape(
        b, loc, outOperand->get(), sizeVals,
        linalgOp.getTiedIndexingMap(outOperand), offsetVals,
        /*ubs*/ {}, subShapeSizes, true);

    auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp)
      return failure();
    resultOffsets = sliceOp.getMixedOffsets();
    resultSizes = sliceOp.getMixedSizes();
    return success();
  }
};
} // namespace
/// Attach the Linalg tiling external model to the single op type `OpType` in
/// the context `ctx`.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  using TilingModel = LinalgOpTilingInterface<OpType>;
  OpType::template attachInterface<TilingModel>(*ctx);
}
/// Variadic helper: calls `registerOne` once for every type in `OpTypes`, in
/// pack order.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  // Elements of a braced initializer are evaluated left to right, so expanding
  // the pack inside an array initializer performs one registration per type.
  // The leading 0 keeps the array non-empty when the pack is empty.
  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
  const int expanded[] = {0, (registerOne<OpTypes>(ctx), 0)...};
  (void)expanded;
}
// The generated `.cpp.inc` included below is expanded inside the template
// argument list of `registerAll<...>`, so it must produce a comma-separated
// list of op types there. NOTE(review): `GET_OP_LIST` presumably selects that
// op-list section of the generated file -- confirm against the generator.
#define GET_OP_LIST
/// Register the tiling external models for Linalg ops with `registry`.
///
/// Adds an extension callback that receives the context and the Linalg
/// dialect; the callback attaches `LinalgOpTilingInterface` to
/// `linalg::GenericOp` and to every op type listed by the generated include.
void mlir::linalg::registerTilingInterfaceExternalModels(
DialectRegistry &registry) {
// The unary `+` converts the captureless lambda to a plain function pointer
// before it is handed to `addExtension`. `dialect` is unused in the body.
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
// GenericOp is registered explicitly; presumably it is not part of the
// generated list below -- TODO confirm.
registerOne<linalg::GenericOp>(ctx);
registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(ctx);
});
}