//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
 | 
						|
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
 | 
						|
 | 
						|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
 | 
						|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 | 
						|
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 | 
						|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
 | 
						|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
 | 
						|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
 | 
						|
#include "mlir/Interfaces/TilingInterface.h"
 | 
						|
 | 
						|
using namespace mlir;
 | 
						|
using namespace mlir::linalg;
 | 
						|
 | 
						|
namespace {
 | 
						|
 | 
						|
/// External model implementation of TilingInterface for LinalgOps. An external
 | 
						|
/// model implementation is used for now till the use of `TilingInterface` is
 | 
						|
/// on-par with the current Linalg tiling + fusion patterns. Once it is
 | 
						|
/// maybe possible to move this into the op-definition (though there are
 | 
						|
/// advantages to leaving it as an external model)
 | 
						|
template <typename LinalgOpTy>
 | 
						|
struct LinalgOpTilingInterface
 | 
						|
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
 | 
						|
                                            LinalgOpTy> {
 | 
						|
 | 
						|
  /// Return the destination operands.
 | 
						|
  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
 | 
						|
    return llvm::cast<LinalgOp>(op).getOutputOperands();
 | 
						|
  }
 | 
						|
 | 
						|
  /// Return the loop iterator type.
 | 
						|
  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
 | 
						|
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
 | 
						|
    return llvm::to_vector(
 | 
						|
        llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
 | 
						|
          return strAttr.cast<StringAttr>().getValue();
 | 
						|
        }));
 | 
						|
  }
 | 
						|
 | 
						|
  /// Return the iteration domain range.
 | 
						|
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
 | 
						|
    Location loc = op->getLoc();
 | 
						|
    LinalgOp linalgOp = cast<LinalgOp>(op);
 | 
						|
    auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
 | 
						|
    AffineMap map = linalgOp.getShapesToLoopsMap();
 | 
						|
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
 | 
						|
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
 | 
						|
    return llvm::to_vector(llvm::map_range(
 | 
						|
        applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
 | 
						|
          return Range{zero, v, one};
 | 
						|
        }));
 | 
						|
  }
 | 
						|
 | 
						|
  // Instantiate the tiled implementation of the operation.
 | 
						|
  SmallVector<Operation *>
 | 
						|
  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
 | 
						|
                         ArrayRef<OpFoldResult> offsets,
 | 
						|
                         ArrayRef<OpFoldResult> sizes,
 | 
						|
                         bool tileDestOperands) const {
 | 
						|
    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
 | 
						|
    // specified could lead to out of bounds accesses.
 | 
						|
    Location loc = op->getLoc();
 | 
						|
    LinalgOp linalgOp = cast<LinalgOp>(op);
 | 
						|
    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
 | 
						|
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
 | 
						|
        b, loc, linalgOp, valuesToTile,
 | 
						|
        getValueOrCreateConstantIndexOp(b, loc, offsets),
 | 
						|
        getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
 | 
						|
 | 
						|
    SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
 | 
						|
        linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
 | 
						|
          return tiledOperands[opOperand->getOperandNumber()].getType();
 | 
						|
        }));
 | 
						|
 | 
						|
    Operation *tiledOp =
 | 
						|
        linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
 | 
						|
 | 
						|
    return {tiledOp};
 | 
						|
  }
 | 
						|
 | 
						|
  // Return the details of the output tile generated by the tiled
 | 
						|
  // implementation.
 | 
						|
  LogicalResult
 | 
						|
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
 | 
						|
                        ArrayRef<OpFoldResult> offsets,
 | 
						|
                        ArrayRef<OpFoldResult> sizes,
 | 
						|
                        SmallVector<OpFoldResult> &resultOffsets,
 | 
						|
                        SmallVector<OpFoldResult> &resultSizes) const {
 | 
						|
    Location loc = op->getLoc();
 | 
						|
    LinalgOp linalgOp = cast<LinalgOp>(op);
 | 
						|
 | 
						|
    AffineExpr d0;
 | 
						|
    bindDims(b.getContext(), d0);
 | 
						|
 | 
						|
    auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
 | 
						|
                                               AffineExpr expr,
 | 
						|
                                               ValueRange operands) -> Value {
 | 
						|
      AffineMap map = AffineMap::inferFromExprList({expr}).front();
 | 
						|
      SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
 | 
						|
      mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
 | 
						|
      canonicalizeMapAndOperands(&map, &normalizedOperands);
 | 
						|
      return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
 | 
						|
    };
 | 
						|
 | 
						|
    SmallVector<Value> sizeVals =
 | 
						|
        getValueOrCreateConstantIndexOp(b, loc, sizes);
 | 
						|
    SmallVector<Value> subShapeSizes =
 | 
						|
        llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
 | 
						|
          return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
 | 
						|
        }));
 | 
						|
    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
 | 
						|
    Value sliceOpResult =
 | 
						|
        makeTiledShape(b, loc, outOperand->get(), sizeVals,
 | 
						|
                       linalgOp.getTiedIndexingMap(outOperand),
 | 
						|
                       getValueOrCreateConstantIndexOp(b, loc, offsets),
 | 
						|
                       /*ubs*/ {}, subShapeSizes, true);
 | 
						|
    auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
 | 
						|
    if (!sliceOp)
 | 
						|
      return failure();
 | 
						|
    resultOffsets = sliceOp.getMixedOffsets();
 | 
						|
    resultSizes = sliceOp.getMixedSizes();
 | 
						|
    return success();
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
} // namespace
 | 
						|
 | 
						|
/// Attach the `LinalgOpTilingInterface` external model for `OpType` to that
/// op in the given context.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  using TilingModel = LinalgOpTilingInterface<OpType>;
  OpType::template attachInterface<TilingModel>(*ctx);
}
 | 
						|
 | 
						|
/// Variadic helper function.
 | 
						|
template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
 | 
						|
  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
 | 
						|
  (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
 | 
						|
}
 | 
						|
 | 
						|
// Makes the LinalgStructuredOps.cpp.inc include below usable as the template
// argument list of `registerAll`.
#define GET_OP_LIST

/// Register the `TilingInterface` external models for Linalg ops.
///
/// Adds an extension to `registry` that attaches `LinalgOpTilingInterface` to
/// `linalg::GenericOp` and to every op listed in LinalgStructuredOps.cpp.inc.
void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  // The unary `+` converts the captureless lambda to a plain function pointer.
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    registerOne<linalg::GenericOp>(ctx);
    registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
 |