//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::arm_sve;

// Extract an LLVM IR type from the LLVM IR dialect type.
|
|
static Type unwrap(Type type) {
|
|
if (!type)
|
|
return nullptr;
|
|
auto *mlirContext = type.getContext();
|
|
if (!LLVM::isCompatibleType(type))
|
|
emitError(UnknownLoc::get(mlirContext),
|
|
"conversion resulted in a non-LLVM type");
|
|
return type;
|
|
}
|
|
|
|
/// Converts an ArmSVE scalable vector type to the corresponding LLVM dialect
/// scalable vector type, using `converter` for the element type. Returns None
/// when the element type cannot be converted.
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  Type eltTy = unwrap(converter.convertType(svType.getElementType()));
  if (!eltTy)
    return {};

  // The runtime vector length is a multiple of the innermost static dimension.
  return Type(
      LLVM::LLVMScalableVectorType::get(eltTy, svType.getShape().back()));
}

template <typename OpTy>
|
|
class ForwardOperands : public OpConversionPattern<OpTy> {
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(OpTy op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
|
|
return rewriter.notifyMatchFailure(op, "operand types already match");
|
|
|
|
rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Unconditionally replaces a `ReturnOp`'s operands with the converted ones,
/// updating the op in place.
/// NOTE(review): this pattern is not registered by any of the populate*
/// functions in this file (ForwardOperands<ReturnOp> is used instead) —
/// possibly dead code; confirm before removing.
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // No type check here: always forwards the converted operands.
    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
    return success();
  }
};

/// Source materialization callback: bridges a single LLVM scalable-vector
/// value back to the ArmSVE scalable-vector type via an
/// unrealized_conversion_cast. Returns a null Value when the inputs do not
/// match that shape, letting the framework try other materializations.
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  bool isSingleLLVMScalableVector =
      inputs.size() == 1 &&
      inputs[0].getType().isa<LLVM::LLVMScalableVectorType>();
  if (!isSingleLLVMScalableVector)
    return Value();

  auto cast = builder.create<UnrealizedConversionCastOp>(loc, svType, inputs);
  return cast.getResult(0);
}

// One-to-one lowerings from ArmSVE ops to their LLVM-dialect intrinsic
// counterparts (same operands, same result), registered below in
// populateArmSVELegalizeForLLVMExportPatterns.
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
    OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
// Masked (predicated) arithmetic on scalable vectors.
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;

// Load operation is lowered to code that obtains a pointer to the indexed
|
|
// element and loads from it.
|
|
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
|
|
using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = loadOp.getMemRefType();
|
|
if (!isConvertibleAndHasIdentityMaps(type))
|
|
return failure();
|
|
|
|
ScalableLoadOp::Adaptor transformed(operands);
|
|
LLVMTypeConverter converter(loadOp.getContext());
|
|
|
|
auto resultType = loadOp.result().getType();
|
|
LLVM::LLVMPointerType llvmDataTypePtr;
|
|
if (resultType.isa<VectorType>()) {
|
|
llvmDataTypePtr =
|
|
LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
|
|
} else if (resultType.isa<ScalableVectorType>()) {
|
|
llvmDataTypePtr = LLVM::LLVMPointerType::get(
|
|
convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
|
|
converter)
|
|
.getValue());
|
|
}
|
|
Value dataPtr =
|
|
getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
|
|
transformed.index(), rewriter);
|
|
Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
|
|
loadOp.getLoc(), llvmDataTypePtr, dataPtr);
|
|
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Store operation is lowered to code that obtains a pointer to the indexed
|
|
// element, and stores the given value to it.
|
|
struct ScalableStoreOpLowering
|
|
: public ConvertOpToLLVMPattern<ScalableStoreOp> {
|
|
using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = storeOp.getMemRefType();
|
|
if (!isConvertibleAndHasIdentityMaps(type))
|
|
return failure();
|
|
|
|
ScalableStoreOp::Adaptor transformed(operands);
|
|
LLVMTypeConverter converter(storeOp.getContext());
|
|
|
|
auto resultType = storeOp.value().getType();
|
|
LLVM::LLVMPointerType llvmDataTypePtr;
|
|
if (resultType.isa<VectorType>()) {
|
|
llvmDataTypePtr =
|
|
LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
|
|
} else if (resultType.isa<ScalableVectorType>()) {
|
|
llvmDataTypePtr = LLVM::LLVMPointerType::get(
|
|
convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
|
|
converter)
|
|
.getValue());
|
|
}
|
|
Value dataPtr =
|
|
getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
|
|
transformed.index(), rewriter);
|
|
Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
|
|
storeOp.getLoc(), llvmDataTypePtr, dataPtr);
|
|
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
|
|
bitCastedPtr);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Registers lowerings of the unmasked scalable arithmetic ops to the plain
/// LLVM arithmetic ops (which accept scalable vector operands directly).
static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
              >(converter);
  // clang-format on
}

/// Marks the unmasked scalable arithmetic ops illegal so the conversion
/// driver forces the patterns above to fire; mirrors
/// populateBasicSVEArithmeticExportPatterns.
static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableAddIOp,
                      ScalableAddFOp,
                      ScalableSubIOp,
                      ScalableSubFOp,
                      ScalableMulIOp,
                      ScalableMulFOp,
                      ScalableSDivIOp,
                      ScalableUDivIOp,
                      ScalableDivFOp>();
  // clang-format on
}

/// Registers lowerings of the scalable comparison ops to the plain LLVM
/// compare ops, which produce the mask vectors.
static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
                                        OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
               OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
              >(converter);
  // clang-format on
}

/// Marks the scalable comparison ops illegal; mirrors
/// populateSVEMaskGenerationExportPatterns.
static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableCmpFOp,
                      ScalableCmpIOp>();
  // clang-format on
}

/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Remove any ArmSVE-specific types from function signatures and results.
  populateFuncOpTypeConversionPattern(patterns, converter);
  // Teach the converter how to map ScalableVectorType, and how to
  // re-materialize the ArmSVE type from already-converted LLVM values.
  converter.addConversion([&converter](ScalableVectorType svType) {
    return convertScalableVectorTypeToLLVM(svType, converter);
  });
  converter.addSourceMaterialization(addUnrealizedCast);

  // Forward converted operands through calls and returns so no ArmSVE-typed
  // values cross those ops.
  // clang-format off
  patterns.add<ForwardOperands<CallOp>,
               ForwardOperands<CallIndirectOp>,
               ForwardOperands<ReturnOp>>(converter,
                                          &converter.getContext());
  // One-to-one intrinsic lowerings (aliases defined above).
  patterns.add<SdotOpLowering,
               SmmlaOpLowering,
               UdotOpLowering,
               UmmlaOpLowering,
               VectorScaleOpLowering,
               ScalableMaskedAddIOpLowering,
               ScalableMaskedAddFOpLowering,
               ScalableMaskedSubIOpLowering,
               ScalableMaskedSubFOpLowering,
               ScalableMaskedMulIOpLowering,
               ScalableMaskedMulFOpLowering,
               ScalableMaskedSDivIOpLowering,
               ScalableMaskedUDivIOpLowering,
               ScalableMaskedDivFOpLowering>(converter);
  // Memory access lowerings.
  patterns.add<ScalableLoadOpLowering,
               ScalableStoreOpLowering>(converter);
  // clang-format on
  populateBasicSVEArithmeticExportPatterns(converter, patterns);
  populateSVEMaskGenerationExportPatterns(converter, patterns);
}

/// Configure the conversion target for the patterns above: intrinsic wrapper
/// ops are legal, the ArmSVE source ops are illegal, and function/call ops
/// are legal only once no ArmSVE scalable-vector types remain in their
/// signatures or operand/result types.
void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // Targets of the one-to-one lowerings: always legal.
  // clang-format off
  target.addLegalOp<SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    VectorScaleIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp>();
  // Sources of the lowerings: must be converted away.
  target.addIllegalOp<SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      VectorScaleOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp,
                      ScalableLoadOp,
                      ScalableStoreOp>();
  // clang-format on
  // True if any type in the range is an ArmSVE scalable vector type.
  auto hasScalableVectorType = [](TypeRange types) {
    for (Type type : types)
      if (type.isa<arm_sve::ScalableVectorType>())
        return true;
    return false;
  };
  // Functions are legal once their signature carries no ArmSVE types.
  target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
    return !hasScalableVectorType(op.getType().getInputs()) &&
           !hasScalableVectorType(op.getType().getResults());
  });
  // Likewise for calls and returns, over their operand/result types.
  target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
      [hasScalableVectorType](Operation *op) {
        return !hasScalableVectorType(op->getOperandTypes()) &&
               !hasScalableVectorType(op->getResultTypes());
      });
  configureBasicSVEArithmeticLegalizations(target);
  configureSVEMaskGenerationLegalizations(target);
}
