//===- ConvertToLLVMDialect.cpp - conversion from Linalg to LLVM dialect --===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
|
|
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/EDSC/Builders.h"
|
|
#include "mlir/EDSC/Intrinsics.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/StandardTypes.h"
|
|
#include "mlir/IR/Types.h"
|
|
#include "mlir/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
#include "llvm/IR/DerivedTypes.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/Support/Allocator.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
#include "linalg1/Common.h"
|
|
#include "linalg1/ConvertToLLVMDialect.h"
|
|
#include "linalg1/LLVMIntrinsics.h"
|
|
#include "linalg1/Ops.h"
|
|
#include "linalg1/Passes.h"
|
|
|
|
using namespace mlir;
|
|
|
|
// Convert the given type to the LLVM IR Dialect type. The following
|
|
// conversions are supported:
|
|
// - an Index type is converted into an LLVM integer type with pointer
|
|
// bitwidth (analogous to intptr_t in C);
|
|
// - an Integer type is converted into an LLVM integer type of the same width;
|
|
// - an F32 type is converted into an LLVM float type
|
|
// - a Range or View is converted into an LLVM structure type containing the
|
|
// respective dynamic values.
|
|
Type linalg::convertLinalgType(Type t) {
|
|
auto *context = t.getContext();
|
|
auto *dialect = context->getRegisteredDialect<LLVM::LLVMDialect>();
|
|
|
|
// Simple conversions.
|
|
if (t.isa<IndexType>()) {
|
|
int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
|
|
return LLVM::LLVMType::getIntNTy(dialect, width);
|
|
}
|
|
if (auto intTy = t.dyn_cast<IntegerType>())
|
|
return LLVM::LLVMType::getIntNTy(dialect, intTy.getWidth());
|
|
if (t.isF32())
|
|
return LLVM::LLVMType::getFloatTy(dialect);
|
|
if (t.isF64())
|
|
return LLVM::LLVMType::getDoubleTy(dialect);
|
|
|
|
// Range descriptor contains the range bounds and the step as 64-bit integers.
|
|
//
|
|
// struct {
|
|
// int64_t min;
|
|
// int64_t max;
|
|
// int64_t step;
|
|
// };
|
|
if (auto rangeTy = t.dyn_cast<linalg::RangeType>()) {
|
|
auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
|
|
return LLVM::LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
|
}
|
|
|
|
// View descriptor contains the pointer to the data buffer, followed by a
|
|
// 64-bit integer containing the distance between the beginning of the buffer
|
|
// and the first element to be accessed through the view, followed by two
|
|
// arrays, each containing as many 64-bit integers as the rank of the View.
|
|
// The first array represents the size, in number of original elements, of the
|
|
// view along the given dimension. When taking the view, the size is the
|
|
// difference between the upper and the lower bound of the range. The second
|
|
// array represents the "stride" (in tensor abstraction sense), i.e. the
|
|
// number of consecutive elements of the underlying buffer that separate two
|
|
// consecutive elements addressable through the view along the given
|
|
// dimension. When taking the view, the strides are constructed as products
|
|
// of the original sizes along the trailing dimensions, multiplied by the view
|
|
// step. For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1},
|
|
// i.e. the view of a complete memref, will have strides N and 1. A view with
|
|
// ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
|
|
//
|
|
// template <typename Elem, size_t Rank>
|
|
// struct {
|
|
// Elem *ptr;
|
|
// int64_t offset;
|
|
// int64_t sizes[Rank];
|
|
// int64_t strides[Rank];
|
|
// };
|
|
if (auto viewTy = t.dyn_cast<linalg::ViewType>()) {
|
|
auto elemTy = linalg::convertLinalgType(viewTy.getElementType())
|
|
.cast<LLVM::LLVMType>()
|
|
.getPointerTo();
|
|
auto int64Ty = LLVM::LLVMType::getInt64Ty(dialect);
|
|
auto arrayTy = LLVM::LLVMType::getArrayTy(int64Ty, viewTy.getRank());
|
|
return LLVM::LLVMType::getStructTy(elemTy, int64Ty, arrayTy, arrayTy);
|
|
}
|
|
|
|
// All other types are kept as is.
|
|
return t;
|
|
}
|
|
|
|
// Create an array attribute containing integer attributes with values provided
|
|
// in `position`.
|
|
static ArrayAttr makePositionAttr(OpBuilder &builder, ArrayRef<int> position) {
|
|
SmallVector<Attribute, 4> attrs;
|
|
attrs.reserve(position.size());
|
|
for (auto p : position)
|
|
attrs.push_back(builder.getI64IntegerAttr(p));
|
|
return builder.getArrayAttr(attrs);
|
|
}
|
|
|
|
// RangeOp creates a new range descriptor.
|
|
class RangeOpConversion : public ConversionPattern {
|
|
public:
|
|
explicit RangeOpConversion(MLIRContext *context)
|
|
: ConversionPattern(linalg::RangeOp::getOperationName(), 1, context) {}
|
|
|
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
|
PatternRewriter &rewriter) const override {
|
|
auto rangeOp = cast<linalg::RangeOp>(op);
|
|
auto rangeDescriptorType =
|
|
linalg::convertLinalgType(rangeOp.getResult()->getType());
|
|
|
|
using namespace intrinsics;
|
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
|
|
|
// Fill in an aggregate value of the descriptor.
|
|
Value *rangeDescriptor = undef(rangeDescriptorType);
|
|
rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
|
|
operands[0], makePositionAttr(rewriter, 0));
|
|
rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
|
|
operands[1], makePositionAttr(rewriter, 1));
|
|
rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
|
|
operands[2], makePositionAttr(rewriter, 2));
|
|
rewriter.replaceOp(op, rangeDescriptor);
|
|
return matchSuccess();
|
|
}
|
|
};
|
|
|
|
class ViewOpConversion : public ConversionPattern {
|
|
public:
|
|
explicit ViewOpConversion(MLIRContext *context)
|
|
: ConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {}
|
|
|
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
|
PatternRewriter &rewriter) const override {
|
|
auto viewOp = cast<linalg::ViewOp>(op);
|
|
auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
|
|
auto memrefType =
|
|
viewOp.getSupportingMemRef()->getType().cast<MemRefType>();
|
|
auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));
|
|
|
|
// Helper function to create an integer array attribute out of a list of
|
|
// values.
|
|
auto pos = [&rewriter](ArrayRef<int> values) {
|
|
return makePositionAttr(rewriter, values);
|
|
};
|
|
|
|
// Helper function to emit an LLVMIR Dialect 64-bit integer constant given
|
|
// its value.
|
|
auto i64cst = [&rewriter, int64Ty](int64_t value) {
|
|
return intrinsics::constant(
|
|
int64Ty, IntegerAttr::get(rewriter.getIndexType(), value));
|
|
};
|
|
|
|
// Helper function to obtain the size of the given `memref` along the
|
|
// dimension `dim`. For static dimensions, emits a constant; for dynamic
|
|
// dimensions, extracts the size from the memref descriptor.
|
|
auto memrefSize = [int64Ty, pos, i64cst](MemRefType type, Value *memref,
|
|
int dim) -> Value * {
|
|
assert(dim < type.getRank());
|
|
if (type.getShape()[dim] != -1) {
|
|
return i64cst(type.getShape()[dim]);
|
|
}
|
|
int dynamicDimPos = 0;
|
|
for (int i = 0; i < dim; ++i)
|
|
if (type.getShape()[i] == -1)
|
|
++dynamicDimPos;
|
|
return intrinsics::extractvalue(int64Ty, memref, pos(1 + dynamicDimPos));
|
|
};
|
|
|
|
// Helper function to obtain the data pointer of the given `memref`.
|
|
auto memrefPtr = [pos](MemRefType type, Value *memref) -> Value * {
|
|
if (type.hasStaticShape())
|
|
return memref;
|
|
|
|
auto elementTy = linalg::convertLinalgType(type.getElementType())
|
|
.cast<LLVM::LLVMType>()
|
|
.getPointerTo();
|
|
return intrinsics::extractvalue(elementTy, memref, pos(0));
|
|
};
|
|
|
|
using namespace intrinsics;
|
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
|
|
|
// Declare the view descriptor.
|
|
Value *viewDescriptor = undef(viewDescriptorType);
|
|
// Insert the data pointer.
|
|
Value *bufferPtr = memrefPtr(memrefType, operands[0]);
|
|
viewDescriptor =
|
|
insertvalue(viewDescriptorType, viewDescriptor, bufferPtr, pos(0));
|
|
|
|
// Collect all memref sizes but the first, which are needed for further
|
|
// computation.
|
|
SmallVector<Value *, 4> trueSizes(memrefType.getRank());
|
|
for (int i = 1, e = memrefType.getRank(); i < e; ++i) {
|
|
trueSizes[i] = memrefSize(memrefType, operands[0], i);
|
|
}
|
|
|
|
// Compute all strides of the memref.
|
|
SmallVector<Value *, 4> trueStrides(memrefType.getRank());
|
|
if (viewOp.getRank() != 0)
|
|
trueStrides[memrefType.getRank() - 1] = i64cst(1);
|
|
for (int i = memrefType.getRank() - 2; i >= 0; --i)
|
|
trueStrides[i] = mul(trueStrides[i + 1], trueSizes[i + 1]);
|
|
|
|
// Compute and insert the base offset.
|
|
Value *baseOffset = i64cst(0);
|
|
for (int j = 0, e = memrefType.getRank(); j < e; ++j) {
|
|
Value *indexing = operands[1 + j];
|
|
Value *min = viewOp.getIndexing(j)->getType().isa<linalg::RangeType>()
|
|
? (Value *)extractvalue(int64Ty, indexing, pos(0))
|
|
: indexing;
|
|
Value *product = mul(min, trueStrides[j]);
|
|
baseOffset = add(baseOffset, product);
|
|
}
|
|
viewDescriptor =
|
|
insertvalue(viewDescriptorType, viewDescriptor, baseOffset, pos(1));
|
|
|
|
// Compute and insert view sizes (max - min along the range). Skip the
|
|
// non-range operands as they will be projected away from the view.
|
|
int i = 0;
|
|
for (Value *index : viewOp.getIndexings()) {
|
|
if (!index->getType().isa<linalg::RangeType>())
|
|
continue;
|
|
|
|
Value *rangeDescriptor = operands[1 + i];
|
|
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
|
|
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
|
|
Value *size = sub(max, min);
|
|
|
|
viewDescriptor =
|
|
insertvalue(viewDescriptorType, viewDescriptor, size, pos({2, i}));
|
|
++i;
|
|
}
|
|
|
|
// Compute and insert view strides. Step over the strides that correspond
|
|
// to non-range operands as they are projected away from the view.
|
|
i = 0;
|
|
for (int j = 0, e = trueStrides.size(); j < e; ++j) {
|
|
if (!viewOp.getIndexing(j)->getType().isa<linalg::RangeType>())
|
|
continue;
|
|
Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
|
|
Value *stride = mul(trueStrides[j], step);
|
|
viewDescriptor =
|
|
insertvalue(viewDescriptorType, viewDescriptor, stride, pos({3, i}));
|
|
++i;
|
|
}
|
|
|
|
rewriter.replaceOp(op, viewDescriptor);
|
|
return matchSuccess();
|
|
}
|
|
};
|
|
|
|
// SliceOp derives a new view descriptor from an existing one by restricting
// (or, for a single-index slice, dropping) one dimension.
class SliceOpConversion : public ConversionPattern {
public:
  explicit SliceOpConversion(MLIRContext *context)
      : ConversionPattern(linalg::SliceOp::getOperationName(), 1, context) {}

  // Lower linalg.slice to LLVM dialect ops that build the new view descriptor
  // from the old one. Operand 0 is the old view descriptor; operand 1 is the
  // slicing value (a range descriptor, or a plain index when rank-decreasing).
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                                     PatternRewriter &rewriter) const override {
    auto sliceOp = cast<linalg::SliceOp>(op);
    auto newViewDescriptorType =
        linalg::convertLinalgType(sliceOp.getViewType());
    // Pointer-to-element type used when extracting the buffer pointer below.
    auto elementType = linalg::convertLinalgType(sliceOp.getElementType())
                           .cast<LLVM::LLVMType>()
                           .getPointerTo();
    auto int64Ty = linalg::convertLinalgType(rewriter.getIntegerType(64));

    // Helper to build an integer array attribute out of a list of values.
    auto pos = [&rewriter](ArrayRef<int> values) {
      return makePositionAttr(rewriter, values);
    };

    // First operand to `slice` is the old view descriptor.
    Value *oldViewDescriptor = operands[0];

    // Properties of the slice. A slice is rank-decreasing iff its indexing is
    // a single index rather than a range (asserted via exclusive-or).
    bool isRankDecreasing = sliceOp.isRankDecreasing();
    int dim = sliceOp.getSlicingDim();
    assert(isRankDecreasing ^
           sliceOp.getIndexing()->getType().isa<linalg::RangeType>());

    // Declare the descriptor of the new view.
    using namespace intrinsics;
    edsc::ScopedContext context(rewriter, op->getLoc());
    Value *newViewDescriptor = undef(newViewDescriptorType);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value *buffer = extractvalue(elementType, oldViewDescriptor, pos(0));
    newViewDescriptor =
        insertvalue(newViewDescriptorType, newViewDescriptor, buffer, pos(0));

    // Update the base offset:
    //   base_offset' = base_offset + min_d * stride_d
    // where d is the dimension being sliced, min_d is the minimum value of the
    // range (in case of a single-value slice, that value), stride_d is the
    // stride along this dimension.
    Value *baseOffset = extractvalue(int64Ty, oldViewDescriptor, pos(1));
    Value *slicingValue = operands[1];
    // If `slice` is not rank-decreasing, we need to extract the "min" value
    // from the range descriptor. Otherwise, we take the value directly.
    Value *min = !isRankDecreasing
                     ? (Value *)extractvalue(int64Ty, slicingValue, pos(0))
                     : slicingValue;
    // Strides live in field 3 of the view descriptor.
    Value *stride = extractvalue(int64Ty, oldViewDescriptor, pos({3, dim}));
    baseOffset = add(baseOffset, mul(min, stride));
    newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
                                    baseOffset, pos(1));

    // Copy the sizes and strides into the new descriptor, updating or dropping
    // the affected dimension. If the `slice` is rank-decreasing, the resulting
    // view no longer has that dimension, so its size and stride become
    // unnecessary and can be dropped. Otherwise, the size of the affected
    // dimension is updated to the size of the range and its stride is
    // multiplied with the step of the range.
    for (int i = 0, e = sliceOp.getRank(); i < e; ++i) {
      // `i` indexes the new view's dimensions; `originalPos` maps it back to
      // the old view's dimensions, skipping the dropped one if any.
      int originalPos = (isRankDecreasing && i >= dim) ? i + 1 : i;
      Value *size;
      Value *stride;
      if (!isRankDecreasing && i == dim) {
        // Sliced (but kept) dimension: size = max - min, stride *= step.
        Value *upper = extractvalue(int64Ty, slicingValue, pos(1));
        Value *lower = extractvalue(int64Ty, slicingValue, pos(0));
        size = sub(upper, lower);

        Value *previousStride =
            extractvalue(int64Ty, oldViewDescriptor, pos({3, originalPos}));
        Value *step = extractvalue(int64Ty, slicingValue, pos(2));
        stride = mul(previousStride, step);
      } else {
        // Untouched dimension: copy size and stride through unchanged.
        size = extractvalue(int64Ty, oldViewDescriptor, pos({2, originalPos}));
        stride =
            extractvalue(int64Ty, oldViewDescriptor, pos({3, originalPos}));
      }
      newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
                                      size, pos({2, i}));
      newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
                                      stride, pos({3, i}));
    }

    rewriter.replaceOp(op, newViewDescriptor);
    return matchSuccess();
  }
};
|
|
|
|
// When converting the "some_consumer" operation, don't emit anything and
// effectively drop it.
class DropConsumer : public ConversionPattern {
public:
  explicit DropConsumer(MLIRContext *context)
      : ConversionPattern("some_consumer", 1, context) {}

  // Replace the op with no results, erasing it from the IR.
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                                     PatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};
|
|
|
|
/// Populate `patterns` with the conversion patterns defined above, which
/// lower the linalg1 ops (range, view, slice) and the test-only
/// "some_consumer" op to the LLVM IR dialect.
void linalg::populateLinalg1ToLLVMConversionPatterns(
    mlir::OwningRewritePatternList &patterns, mlir::MLIRContext *context) {
  RewriteListBuilder<DropConsumer, RangeOpConversion, SliceOpConversion,
                     ViewOpConversion>::build(patterns, context);
}
|
|
|
|
namespace {
|
|
/// A type conversion class that converts Linalg and Std types to LLVM.
|
|
struct LinalgTypeConverter : public LLVMTypeConverter {
|
|
using LLVMTypeConverter::LLVMTypeConverter;
|
|
|
|
// This gets called for block and region arguments, and attributes.
|
|
Type convertType(Type t) override {
|
|
if (auto result = LLVMTypeConverter::convertType(t))
|
|
return result;
|
|
return linalg::convertLinalgType(t);
|
|
}
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Lower `module` to the LLVM IR dialect: first strip affine constructs with
/// an existing pass, then run the Std and Linalg conversion patterns.
void linalg::convertToLLVM(mlir::Module module) {
  // Remove affine constructs if any by using an existing pass.
  PassManager pm;
  pm.addPass(createLowerAffinePass());
  auto affineLoweringResult = pm.run(module);
  (void)affineLoweringResult;
  assert(succeeded(affineLoweringResult) && "affine loop lowering failed");

  // Convert Linalg ops to the LLVM IR dialect using the converter defined
  // above.
  LinalgTypeConverter converter(module.getContext());
  OwningRewritePatternList patterns;
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());

  // Everything must end up in the LLVM dialect for the conversion to succeed.
  ConversionTarget target(*module.getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  auto conversionResult =
      applyConversionPatterns(module, target, converter, std::move(patterns));
  (void)conversionResult;
  assert(succeeded(conversionResult) && "conversion failed");
}
|
|
|
|
namespace {
/// Module pass wrapping linalg::convertToLLVM so the lowering can be run from
/// a pass pipeline.
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
  void runOnModule() { linalg::convertToLLVM(getModule()); }
};
} // namespace
|
|
|
|
/// Create a pass instance that lowers the linalg dialect to the LLVM IR
/// dialect. Ownership is transferred to the caller.
ModulePassBase *linalg::createLowerLinalgToLLVMPass() {
  auto *loweringPass = new LowerLinalgToLLVMPass();
  return loweringPass;
}
|
|
|
|
// Register the pass with the global registry so it is available to
// command-line tools under -lower-linalg-to-llvm.
static PassRegistration<LowerLinalgToLLVMPass>
    pass("lower-linalg-to-llvm",
         "Lower the operations from the linalg dialect into the LLVM dialect");
|