1413 lines
56 KiB
C++
1413 lines
56 KiB
C++
//===- StandardToSPIRV.cpp - Standard to SPIR-V Patterns ------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements patterns to convert standard dialect to SPIR-V dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
|
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
|
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "std-to-spirv-pattern"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utility functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns true if the given `type` is a boolean scalar or vector type.
|
|
static bool isBoolScalarOrVector(Type type) {
|
|
if (type.isInteger(1))
|
|
return true;
|
|
if (auto vecType = type.dyn_cast<VectorType>())
|
|
return vecType.getElementType().isInteger(1);
|
|
return false;
|
|
}
|
|
|
|
/// Converts the given `srcAttr` into a boolean attribute if it holds an
|
|
/// integral value. Returns null attribute if conversion fails.
|
|
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
|
|
if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
|
|
return boolAttr;
|
|
if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
|
|
return builder.getBoolAttr(intAttr.getValue().getBoolValue());
|
|
return BoolAttr();
|
|
}
|
|
|
|
/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
|
|
/// Returns null attribute if conversion fails.
|
|
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
|
|
Builder builder) {
|
|
// If the source number uses less active bits than the target bitwidth, then
|
|
// it should be safe to convert.
|
|
if (srcAttr.getValue().isIntN(dstType.getWidth()))
|
|
return builder.getIntegerAttr(dstType, srcAttr.getInt());
|
|
|
|
// XXX: Try again by interpreting the source number as a signed value.
|
|
// Although integers in the standard dialect are signless, they can represent
|
|
// a signed number. It's the operation decides how to interpret. This is
|
|
// dangerous, but it seems there is no good way of handling this if we still
|
|
// want to change the bitwidth. Emit a message at least.
|
|
if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
|
|
auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
|
|
LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
|
|
<< dstAttr << "' for type '" << dstType << "'\n");
|
|
return dstAttr;
|
|
}
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
|
|
<< "' illegal: cannot fit into target type '"
|
|
<< dstType << "'\n");
|
|
return IntegerAttr();
|
|
}
|
|
|
|
/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
|
|
/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
|
|
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
|
|
Builder builder) {
|
|
// Only support converting to float for now.
|
|
if (!dstType.isF32())
|
|
return FloatAttr();
|
|
|
|
// Try to convert the source floating-point number to single precision.
|
|
APFloat dstVal = srcAttr.getValue();
|
|
bool losesInfo = false;
|
|
APFloat::opStatus status =
|
|
dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
|
|
if (status != APFloat::opOK || losesInfo) {
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< srcAttr << " illegal: cannot fit into converted type '"
|
|
<< dstType << "'\n");
|
|
return FloatAttr();
|
|
}
|
|
|
|
return builder.getF32FloatAttr(dstVal.convertToFloat());
|
|
}
|
|
|
|
/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
|
|
/// the sign of `signOperand`.
|
|
///
|
|
/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
|
|
/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
|
|
/// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
|
|
/// if either operand can be negative. Emulate it via spv.UMod.
|
|
static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
|
|
Value signOperand, OpBuilder &builder) {
|
|
assert(lhs.getType() == rhs.getType());
|
|
assert(lhs == signOperand || rhs == signOperand);
|
|
|
|
Type type = lhs.getType();
|
|
|
|
// Calculate the remainder with spv.UMod.
|
|
Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
|
|
Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
|
|
Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
|
|
|
|
// Fix the sign.
|
|
Value isPositive;
|
|
if (lhs == signOperand)
|
|
isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
|
|
else
|
|
isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
|
|
Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
|
|
return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
|
|
}
|
|
|
|
/// Returns the offset of the value in `targetBits` representation.
|
|
///
|
|
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
|
|
/// It's assumed to be non-negative.
|
|
///
|
|
/// When accessing an element in the array treating as having elements of
|
|
/// `targetBits`, multiple values are loaded in the same time. The method
|
|
/// returns the offset where the `srcIdx` locates in the value. For example, if
|
|
/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
|
|
/// located at (x % 4) * 8. Because there are four elements in one i32, and one
|
|
/// element has 8 bits.
|
|
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
|
|
int targetBits, OpBuilder &builder) {
|
|
assert(targetBits % sourceBits == 0);
|
|
IntegerType targetType = builder.getIntegerType(targetBits);
|
|
IntegerAttr idxAttr =
|
|
builder.getIntegerAttr(targetType, targetBits / sourceBits);
|
|
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
|
|
IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
|
|
auto srcBitsValue =
|
|
builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
|
|
auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
|
|
return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
|
|
}
|
|
|
|
/// Returns an adjusted spirv::AccessChainOp. Based on the
|
|
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
|
|
/// supported. During conversion if a memref of an unsupported type is used,
|
|
/// load/stores to this memref need to be modified to use a supported higher
|
|
/// bitwidth `targetBits` and extracting the required bits. For an accessing a
|
|
/// 1D array (spv.array or spv.rt_array), the last index is modified to load the
|
|
/// bits needed. The extraction of the actual bits needed are handled
|
|
/// separately. Note that this only works for a 1-D tensor.
|
|
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
|
|
spirv::AccessChainOp op,
|
|
int sourceBits, int targetBits,
|
|
OpBuilder &builder) {
|
|
assert(targetBits % sourceBits == 0);
|
|
const auto loc = op.getLoc();
|
|
IntegerType targetType = builder.getIntegerType(targetBits);
|
|
IntegerAttr attr =
|
|
builder.getIntegerAttr(targetType, targetBits / sourceBits);
|
|
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
|
|
auto lastDim = op->getOperand(op.getNumOperands() - 1);
|
|
auto indices = llvm::to_vector<4>(op.indices());
|
|
// There are two elements if this is a 1-D tensor.
|
|
assert(indices.size() == 2);
|
|
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
|
|
Type t = typeConverter.convertType(op.component_ptr().getType());
|
|
return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
|
|
}
|
|
|
|
/// Returns the shifted `targetBits`-bit value with the given offset.
|
|
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
|
|
int targetBits, OpBuilder &builder) {
|
|
Type targetType = builder.getIntegerType(targetBits);
|
|
Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
|
|
return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
|
|
offset);
|
|
}
|
|
|
|
/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
|
|
static bool isAllocationSupported(MemRefType t) {
|
|
// Currently only support workgroup local memory allocations with static
|
|
// shape and int or float or vector of int or float element type.
|
|
if (!(t.hasStaticShape() &&
|
|
SPIRVTypeConverter::getMemorySpaceForStorageClass(
|
|
spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
|
|
return false;
|
|
Type elementType = t.getElementType();
|
|
if (auto vecType = elementType.dyn_cast<VectorType>())
|
|
elementType = vecType.getElementType();
|
|
return elementType.isIntOrFloat();
|
|
}
|
|
|
|
/// Returns the scope to use for atomic operations use for emulating store
|
|
/// operations of unsupported integer bitwidths, based on the memref
|
|
/// type. Returns None on failure.
|
|
static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
|
|
Optional<spirv::StorageClass> storageClass =
|
|
SPIRVTypeConverter::getStorageClassForMemorySpace(
|
|
t.getMemorySpaceAsInt());
|
|
if (!storageClass)
|
|
return {};
|
|
switch (*storageClass) {
|
|
case spirv::StorageClass::StorageBuffer:
|
|
return spirv::Scope::Device;
|
|
case spirv::StorageClass::Workgroup:
|
|
return spirv::Scope::Workgroup;
|
|
default: {
|
|
}
|
|
}
|
|
return {};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operation conversion
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Note that DRR cannot be used for the patterns in this file: we may need to
|
|
// convert type along the way, which requires ConversionPattern. DRR generates
|
|
// normal RewritePattern.
|
|
|
|
namespace {
|
|
|
|
/// Converts an allocation operation to SPIR-V. Currently only supports lowering
|
|
/// to Workgroup memory when the size is constant. Note that this pattern needs
|
|
/// to be applied in a pass that runs at least at spv.module scope since it wil
|
|
/// ladd global variables into the spv.module.
|
|
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
|
|
public:
|
|
using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::AllocOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
MemRefType allocType = operation.getType();
|
|
if (!isAllocationSupported(allocType))
|
|
return operation.emitError("unhandled allocation type");
|
|
|
|
// Get the SPIR-V type for the allocation.
|
|
Type spirvType = getTypeConverter()->convertType(allocType);
|
|
|
|
// Insert spv.GlobalVariable for this allocation.
|
|
Operation *parent =
|
|
SymbolTable::getNearestSymbolTable(operation->getParentOp());
|
|
if (!parent)
|
|
return failure();
|
|
Location loc = operation.getLoc();
|
|
spirv::GlobalVariableOp varOp;
|
|
{
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
Block &entryBlock = *parent->getRegion(0).begin();
|
|
rewriter.setInsertionPointToStart(&entryBlock);
|
|
auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
|
|
std::string varName =
|
|
std::string("__workgroup_mem__") +
|
|
std::to_string(std::distance(varOps.begin(), varOps.end()));
|
|
varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
|
|
/*initializer=*/nullptr);
|
|
}
|
|
|
|
// Get pointer to global variable at the current scope.
|
|
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Removed a deallocation if it is a supported allocation. Currently only
|
|
/// removes deallocation if the memory space is workgroup memory.
|
|
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
|
|
public:
|
|
using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::DeallocOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
|
|
if (!isAllocationSupported(deallocType))
|
|
return operation.emitError("unhandled deallocation type");
|
|
rewriter.eraseOp(operation);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts unary and binary standard operations to SPIR-V operations.
|
|
template <typename StdOp, typename SPIRVOp>
|
|
class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
|
|
public:
|
|
using OpConversionPattern<StdOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
assert(operands.size() <= 2);
|
|
auto dstType = this->getTypeConverter()->convertType(operation.getType());
|
|
if (!dstType)
|
|
return failure();
|
|
if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
|
|
dstType != operation.getType()) {
|
|
return operation.emitError(
|
|
"bitwidth emulation is not implemented yet on unsigned op");
|
|
}
|
|
rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts math.log1p to SPIR-V ops.
|
|
///
|
|
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
|
|
/// these operations.
|
|
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
|
|
public:
|
|
using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
assert(operands.size() == 1);
|
|
Location loc = operation.getLoc();
|
|
auto type =
|
|
this->getTypeConverter()->convertType(operation.operand().getType());
|
|
auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
|
|
auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]);
|
|
rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts std.remi_signed to SPIR-V ops.
|
|
///
|
|
/// This cannot be merged into the template unary/binary pattern due to
|
|
/// Vulkan restrictions over spv.SRem and spv.SMod.
|
|
class SignedRemIOpPattern final : public OpConversionPattern<SignedRemIOp> {
|
|
public:
|
|
using OpConversionPattern<SignedRemIOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts bitwise standard operations to SPIR-V operations. This is a special
|
|
/// pattern other than the BinaryOpPatternPattern because if the operands are
|
|
/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
|
|
/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
|
|
template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
|
|
class BitwiseOpPattern final : public OpConversionPattern<StdOp> {
|
|
public:
|
|
using OpConversionPattern<StdOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
assert(operands.size() == 2);
|
|
auto dstType =
|
|
this->getTypeConverter()->convertType(operation.getResult().getType());
|
|
if (!dstType)
|
|
return failure();
|
|
if (isBoolScalarOrVector(operands.front().getType())) {
|
|
rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
|
|
operands);
|
|
} else {
|
|
rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
|
|
operands);
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts composite std.constant operation to spv.Constant.
|
|
class ConstantCompositeOpPattern final
|
|
: public OpConversionPattern<ConstantOp> {
|
|
public:
|
|
using OpConversionPattern<ConstantOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts scalar std.constant operation to spv.Constant.
|
|
class ConstantScalarOpPattern final : public OpConversionPattern<ConstantOp> {
|
|
public:
|
|
using OpConversionPattern<ConstantOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts floating-point comparison operations to SPIR-V ops.
|
|
class CmpFOpPattern final : public OpConversionPattern<CmpFOp> {
|
|
public:
|
|
using OpConversionPattern<CmpFOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts floating point NaN check to SPIR-V ops. This pattern requires
|
|
/// Kernel capability.
|
|
class CmpFOpNanKernelPattern final : public OpConversionPattern<CmpFOp> {
|
|
public:
|
|
using OpConversionPattern<CmpFOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts floating point NaN check to SPIR-V ops. This pattern does not
|
|
/// require additional capability.
|
|
class CmpFOpNanNonePattern final : public OpConversionPattern<CmpFOp> {
|
|
public:
|
|
using OpConversionPattern<CmpFOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts integer compare operation on i1 type operands to SPIR-V ops.
|
|
class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> {
|
|
public:
|
|
using OpConversionPattern<CmpIOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts integer compare operation to SPIR-V ops.
|
|
class CmpIOpPattern final : public OpConversionPattern<CmpIOp> {
|
|
public:
|
|
using OpConversionPattern<CmpIOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts memref.load to spv.Load.
|
|
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
|
|
public:
|
|
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts memref.load to spv.Load.
|
|
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
|
|
public:
|
|
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.return to spv.Return.
|
|
class ReturnOpPattern final : public OpConversionPattern<ReturnOp> {
|
|
public:
|
|
using OpConversionPattern<ReturnOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.select to spv.Select.
|
|
class SelectOpPattern final : public OpConversionPattern<SelectOp> {
|
|
public:
|
|
using OpConversionPattern<SelectOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.splat to spv.CompositeConstruct.
|
|
class SplatPattern final : public OpConversionPattern<SplatOp> {
|
|
public:
|
|
using OpConversionPattern<SplatOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts memref.store to spv.Store on integers.
|
|
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
|
|
public:
|
|
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts memref.store to spv.Store.
|
|
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
|
|
public:
|
|
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.zexti to spv.Select if the type of source is i1 or vector of
|
|
/// i1.
|
|
class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
|
|
public:
|
|
using OpConversionPattern<ZeroExtendIOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto srcType = operands.front().getType();
|
|
if (!isBoolScalarOrVector(srcType))
|
|
return failure();
|
|
|
|
auto dstType =
|
|
this->getTypeConverter()->convertType(op.getResult().getType());
|
|
Location loc = op.getLoc();
|
|
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
|
|
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
|
|
rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
|
|
op, dstType, operands.front(), one, zero);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts tensor.extract into loading using access chains from SPIR-V local
|
|
/// variables.
|
|
class TensorExtractPattern final
|
|
: public OpConversionPattern<tensor::ExtractOp> {
|
|
public:
|
|
TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context,
|
|
int64_t threshold, PatternBenefit benefit = 1)
|
|
: OpConversionPattern(typeConverter, context, benefit),
|
|
byteCountThreshold(threshold) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(tensor::ExtractOp extractOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
TensorType tensorType = extractOp.tensor().getType().cast<TensorType>();
|
|
|
|
if (!tensorType.hasStaticShape())
|
|
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
|
|
|
|
if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
|
|
byteCountThreshold * 8)
|
|
return rewriter.notifyMatchFailure(extractOp,
|
|
"exceeding byte count threshold");
|
|
|
|
Location loc = extractOp.getLoc();
|
|
tensor::ExtractOp::Adaptor adaptor(operands);
|
|
|
|
int64_t rank = tensorType.getRank();
|
|
SmallVector<int64_t, 4> strides(rank, 1);
|
|
for (int i = rank - 2; i >= 0; --i) {
|
|
strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
|
|
}
|
|
|
|
Type varType = spirv::PointerType::get(adaptor.tensor().getType(),
|
|
spirv::StorageClass::Function);
|
|
|
|
spirv::VariableOp varOp;
|
|
if (adaptor.tensor().getDefiningOp<spirv::ConstantOp>()) {
|
|
varOp = rewriter.create<spirv::VariableOp>(
|
|
loc, varType, spirv::StorageClass::Function,
|
|
/*initializer=*/adaptor.tensor());
|
|
} else {
|
|
// Need to store the value to the local variable. It's questionable
|
|
// whether we want to support such case though.
|
|
return failure();
|
|
}
|
|
|
|
Value index = spirv::linearizeIndex(adaptor.indices(), strides,
|
|
/*offset=*/0, loc, rewriter);
|
|
auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
|
|
|
|
rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
|
|
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
int64_t byteCountThreshold;
|
|
};
|
|
|
|
/// Converts std.trunci to spv.Select if the type of result is i1 or vector of
|
|
/// i1.
|
|
class TruncI1Pattern final : public OpConversionPattern<TruncateIOp> {
|
|
public:
|
|
using OpConversionPattern<TruncateIOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(TruncateIOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto dstType =
|
|
this->getTypeConverter()->convertType(op.getResult().getType());
|
|
if (!isBoolScalarOrVector(dstType))
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
auto srcType = operands.front().getType();
|
|
// Check if (x & 1) == 1.
|
|
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
|
|
Value maskedSrc =
|
|
rewriter.create<spirv::BitwiseAndOp>(loc, srcType, operands[0], mask);
|
|
Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
|
|
|
|
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
|
|
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
|
|
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
|
|
/// i1.
|
|
class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
|
|
public:
|
|
using OpConversionPattern<UIToFPOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto srcType = operands.front().getType();
|
|
if (!isBoolScalarOrVector(srcType))
|
|
return failure();
|
|
|
|
auto dstType =
|
|
this->getTypeConverter()->convertType(op.getResult().getType());
|
|
Location loc = op.getLoc();
|
|
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
|
|
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
|
|
rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
|
|
op, dstType, operands.front(), one, zero);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts type-casting standard operations to SPIR-V operations.
|
|
template <typename StdOp, typename SPIRVOp>
|
|
class TypeCastingOpPattern final : public OpConversionPattern<StdOp> {
|
|
public:
|
|
using OpConversionPattern<StdOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
assert(operands.size() == 1);
|
|
auto srcType = operands.front().getType();
|
|
auto dstType =
|
|
this->getTypeConverter()->convertType(operation.getResult().getType());
|
|
if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
|
|
return failure();
|
|
if (dstType == srcType) {
|
|
// Due to type conversion, we are seeing the same source and target type.
|
|
// Then we can just erase this operation by forwarding its operand.
|
|
rewriter.replaceOp(operation, operands.front());
|
|
} else {
|
|
rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
|
|
operands);
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts std.xor to SPIR-V operations.
|
|
class XOrOpPattern final : public OpConversionPattern<XOrOp> {
|
|
public:
|
|
using OpConversionPattern<XOrOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector
|
|
/// of i1.
|
|
class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> {
|
|
public:
|
|
using OpConversionPattern<XOrOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SignedRemIOpPattern
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Replaces the signed remainder with the spv.UMod-based emulation; the result
// takes the sign of the first operand.
LogicalResult SignedRemIOpPattern::matchAndRewrite(
    SignedRemIOp remOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  Value dividend = operands[0];
  Value divisor = operands[1];
  Value remainder =
      emulateSignedRemainder(remOp.getLoc(), dividend, divisor,
                             /*signOperand=*/dividend, rewriter);
  rewriter.replaceOp(remOp, remainder);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp with composite type.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TODO: This probably should be split into the vector case and tensor case,
|
|
// so that the tensor case can be moved to TensorToSPIRV conversion. But,
|
|
// std.constant is for the standard dialect though.
|
|
LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
|
|
ConstantOp constOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
auto srcType = constOp.getType().dyn_cast<ShapedType>();
|
|
if (!srcType)
|
|
return failure();
|
|
|
|
// std.constant should only have vector or tenor types.
|
|
assert((srcType.isa<VectorType, RankedTensorType>()));
|
|
|
|
auto dstType = getTypeConverter()->convertType(srcType);
|
|
if (!dstType)
|
|
return failure();
|
|
|
|
auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
|
|
ShapedType dstAttrType = dstElementsAttr.getType();
|
|
if (!dstElementsAttr)
|
|
return failure();
|
|
|
|
// If the composite type has more than one dimensions, perform linearization.
|
|
if (srcType.getRank() > 1) {
|
|
if (srcType.isa<RankedTensorType>()) {
|
|
dstAttrType = RankedTensorType::get(srcType.getNumElements(),
|
|
srcType.getElementType());
|
|
dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
|
|
} else {
|
|
// TODO: add support for large vectors.
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
Type srcElemType = srcType.getElementType();
|
|
Type dstElemType;
|
|
// Tensor types are converted to SPIR-V array types; vector types are
|
|
// converted to SPIR-V vector/array types.
|
|
if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
|
|
dstElemType = arrayType.getElementType();
|
|
else
|
|
dstElemType = dstType.cast<VectorType>().getElementType();
|
|
|
|
// If the source and destination element types are different, perform
|
|
// attribute conversion.
|
|
if (srcElemType != dstElemType) {
|
|
SmallVector<Attribute, 8> elements;
|
|
if (srcElemType.isa<FloatType>()) {
|
|
for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
|
|
FloatAttr dstAttr = convertFloatAttr(
|
|
srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
|
|
if (!dstAttr)
|
|
return failure();
|
|
elements.push_back(dstAttr);
|
|
}
|
|
} else if (srcElemType.isInteger(1)) {
|
|
return failure();
|
|
} else {
|
|
for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
|
|
IntegerAttr dstAttr =
|
|
convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
|
|
dstElemType.cast<IntegerType>(), rewriter);
|
|
if (!dstAttr)
|
|
return failure();
|
|
elements.push_back(dstAttr);
|
|
}
|
|
}
|
|
|
|
// Unfortunately, we cannot use dialect-specific types for element
|
|
// attributes; element attributes only works with builtin types. So we need
|
|
// to prepare another converted builtin types for the destination elements
|
|
// attribute.
|
|
if (dstAttrType.isa<RankedTensorType>())
|
|
dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
|
|
else
|
|
dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
|
|
|
|
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
|
|
dstElementsAttr);
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp with scalar type.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a scalar `ConstantOp` (integer, index, or floating-point result)
/// into a `spirv::ConstantOp` of the type chosen by the type converter,
/// converting the constant's value attribute to that type where needed.
LogicalResult ConstantScalarOpPattern::matchAndRewrite(
    ConstantOp constOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  Type srcType = constOp.getType();
  if (!srcType.isIntOrIndexOrFloat())
    return failure();

  Type dstType = getTypeConverter()->convertType(srcType);
  if (!dstType)
    return failure();

  // Select the attribute for the new constant based on the source type. A
  // null attribute coming back from a conversion helper is treated as failure.
  Attribute convertedAttr;
  if (srcType.isa<FloatType>()) {
    // Floating-point types not supported in the target environment are all
    // converted to float type; the attribute only needs converting when the
    // type actually changed.
    FloatAttr floatAttr = constOp.value().cast<FloatAttr>();
    if (srcType != dstType)
      floatAttr =
          convertFloatAttr(floatAttr, dstType.cast<FloatType>(), rewriter);
    convertedAttr = floatAttr;
  } else if (srcType.isInteger(1)) {
    // std.constant can use 0/1 instead of true/false for i1 values, so go
    // through the boolean attribute conversion.
    convertedAttr = convertBoolAttr(constOp.value(), rewriter);
  } else {
    // IndexType or IntegerType. Index values are converted to 32-bit integer
    // values when converting to SPIR-V.
    convertedAttr =
        convertIntegerAttr(constOp.value().cast<IntegerAttr>(),
                           dstType.cast<IntegerType>(), rewriter);
  }
  if (!convertedAttr)
    return failure();

  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
                                                 convertedAttr);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a `CmpFOp` whose predicate is one of the ordered/unordered
/// relational comparisons into the matching SPIR-V floating-point comparison
/// op. Any other predicate makes the pattern fail.
LogicalResult
CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpFOpAdaptor adaptor(operands);
  Type resultType = cmpFOp.getResult().getType();
  Value lhs = adaptor.lhs();
  Value rhs = adaptor.rhs();

// Emits one `case` that replaces the op with `SPIRV_OP` and succeeds.
#define LOWER_CMPF(PRED, SPIRV_OP)                                             \
  case CmpFPredicate::PRED:                                                    \
    rewriter.replaceOpWithNewOp<SPIRV_OP>(cmpFOp, resultType, lhs, rhs);       \
    return success()

  switch (cmpFOp.getPredicate()) {
    // Ordered.
    LOWER_CMPF(OEQ, spirv::FOrdEqualOp);
    LOWER_CMPF(OGT, spirv::FOrdGreaterThanOp);
    LOWER_CMPF(OGE, spirv::FOrdGreaterThanEqualOp);
    LOWER_CMPF(OLT, spirv::FOrdLessThanOp);
    LOWER_CMPF(OLE, spirv::FOrdLessThanEqualOp);
    LOWER_CMPF(ONE, spirv::FOrdNotEqualOp);
    // Unordered.
    LOWER_CMPF(UEQ, spirv::FUnordEqualOp);
    LOWER_CMPF(UGT, spirv::FUnordGreaterThanOp);
    LOWER_CMPF(UGE, spirv::FUnordGreaterThanEqualOp);
    LOWER_CMPF(ULT, spirv::FUnordLessThanOp);
    LOWER_CMPF(ULE, spirv::FUnordLessThanEqualOp);
    LOWER_CMPF(UNE, spirv::FUnordNotEqualOp);
  default:
    break;
  }

#undef LOWER_CMPF

  // Predicates not listed above are not handled by this pattern.
  return failure();
}
|
|
|
|
/// Rewrites a `CmpFOp` with the `ORD` / `UNO` predicate into
/// `spirv::OrderedOp` / `spirv::UnorderedOp` respectively. Fails for every
/// other predicate.
LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
    CmpFOp cmpFOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  CmpFOpAdaptor adaptor(operands);

  switch (cmpFOp.getPredicate()) {
  case CmpFPredicate::ORD:
    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, adaptor.lhs(),
                                                  adaptor.rhs());
    return success();
  case CmpFPredicate::UNO:
    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(cmpFOp, adaptor.lhs(),
                                                    adaptor.rhs());
    return success();
  default:
    break;
  }

  return failure();
}
|
|
|
|
/// Rewrites a `CmpFOp` with the `ORD` / `UNO` predicate using `spirv::IsNanOp`
/// on each operand: `UNO` becomes "lhs is NaN or rhs is NaN" and `ORD` becomes
/// the logical negation of that. Fails for every other predicate.
LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
    CmpFOp cmpFOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  const CmpFPredicate predicate = cmpFOp.getPredicate();
  const bool isOrdered = predicate == CmpFPredicate::ORD;
  if (!isOrdered && predicate != CmpFPredicate::UNO)
    return failure();

  CmpFOpAdaptor adaptor(operands);
  Location loc = cmpFOp.getLoc();

  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.lhs());
  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.rhs());

  // "Unordered" holds when at least one side is NaN; "ordered" is its inverse.
  Value replacement =
      rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
  if (isOrdered)
    replacement = rewriter.create<spirv::LogicalNotOp>(loc, replacement);

  rewriter.replaceOp(cmpFOp, replacement);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites `eq` / `ne` `CmpIOp`s on boolean scalars or vectors into
/// `spirv::LogicalEqualOp` / `spirv::LogicalNotEqualOp`. Fails for
/// non-boolean operands and for any other predicate.
LogicalResult
BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // Only boolean comparisons are handled by this pattern.
  if (!isBoolScalarOrVector(cmpIOp.lhs().getType()))
    return failure();

  CmpIOpAdaptor adaptor(operands);
  Type resultType = cmpIOp.getResult().getType();
  const CmpIPredicate predicate = cmpIOp.getPredicate();

  if (predicate == CmpIPredicate::eq) {
    rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(
        cmpIOp, resultType, adaptor.lhs(), adaptor.rhs());
    return success();
  }
  if (predicate == CmpIPredicate::ne) {
    rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
        cmpIOp, resultType, adaptor.lhs(), adaptor.rhs());
    return success();
  }

  return failure();
}
|
|
|
|
/// Rewrites a `CmpIOp` on non-boolean operands into the SPIR-V integer
/// comparison op matching its predicate. For target ops carrying the
/// `UnsignedOp` trait, an error is emitted when the operand type differs from
/// its converted type.
LogicalResult
CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  Type operandType = cmpIOp.lhs().getType();
  // Boolean comparisons are left to a different pattern.
  if (isBoolScalarOrVector(operandType))
    return failure();

  CmpIOpAdaptor adaptor(operands);
  Type resultType = cmpIOp.getResult().getType();
  Value lhs = adaptor.lhs();
  Value rhs = adaptor.rhs();

  // Evaluated lazily: only queried for ops with the UnsignedOp trait.
  auto isTypeEmulated = [&]() {
    return operandType != this->getTypeConverter()->convertType(operandType);
  };

// Emits one `case` that diagnoses unsupported unsigned emulation or replaces
// the op with `SPIRV_OP`.
#define LOWER_CMPI(PRED, SPIRV_OP)                                             \
  case CmpIPredicate::PRED:                                                    \
    if (SPIRV_OP::template hasTrait<OpTrait::spirv::UnsignedOp>() &&           \
        isTypeEmulated())                                                      \
      return cmpIOp.emitError(                                                 \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    rewriter.replaceOpWithNewOp<SPIRV_OP>(cmpIOp, resultType, lhs, rhs);       \
    return success()

  switch (cmpIOp.getPredicate()) {
    LOWER_CMPI(eq, spirv::IEqualOp);
    LOWER_CMPI(ne, spirv::INotEqualOp);
    LOWER_CMPI(slt, spirv::SLessThanOp);
    LOWER_CMPI(sle, spirv::SLessThanEqualOp);
    LOWER_CMPI(sgt, spirv::SGreaterThanOp);
    LOWER_CMPI(sge, spirv::SGreaterThanEqualOp);
    LOWER_CMPI(ult, spirv::ULessThanOp);
    LOWER_CMPI(ule, spirv::ULessThanEqualOp);
    LOWER_CMPI(ugt, spirv::UGreaterThanOp);
    LOWER_CMPI(uge, spirv::UGreaterThanEqualOp);
  }

#undef LOWER_CMPI

  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LoadOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a `memref::LoadOp` on a memref of signless integers into SPIR-V.
///
/// When the converted element type has the same bit width as the source
/// element type (with i1 counted as `boolNumBits` bits), a plain
/// `spirv::LoadOp` is emitted. Otherwise the wider element containing the
/// value is loaded and the requested bits are extracted with shifts and a
/// mask. Fails for non-signless-integer element types and when no access
/// chain could be built.
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
                                  ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  memref::LoadOpAdaptor loadOperands(operands);
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loc, rewriter);
  // NOTE(review): getElementPtr is expected to yield a null op when it cannot
  // compute an element pointer for this memref; bail out instead of
  // dereferencing a null handle below.
  if (!accessChainOp)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  // i1 values occupy `boolNumBits` bits in memory.
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;
  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast<spirv::PointerType>()
                         .getPointeeType();
  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
                                               accessChainOp.getResult());
    return success();
  }

  // Assume that getElementPtr() produces a linearized access. If it's a
  // scalar, the method still returns a linearized access. If the access is not
  // linearized, there will be offset issues.
  assert(accessChainOp.indices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
      loc, dstType, adjustedPtr,
      loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      loadOp->getAttrOfType<IntegerAttr>("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantic is carried in the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
                                                      shiftValue);
  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                          shiftValue);

  if (isBool) {
    // Turn the extracted integer back into the converted boolean type by
    // comparing it against one and selecting the matching constant.
    dstType = typeConverter.convertType(loadOp.getType());
    mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
    Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
    result = rewriter.create<spirv::SelectOp>(loc, dstType, isOne, one, zero);
  } else if (result.getType().getIntOrFloatBitWidth() !=
             static_cast<unsigned>(dstBits)) {
    result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
  }
  rewriter.replaceOp(loadOp, result);

  // The original access chain was superseded by the adjusted pointer.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
|
|
|
|
/// Rewrites a `memref::LoadOp` whose memref element type is not a signless
/// integer into a `spirv::LoadOp` through the element pointer computed by
/// `spirv::getElementPtr`. Fails for signless-integer element types and when
/// no element pointer could be built.
LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  memref::LoadOpAdaptor loadOperands(operands);
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType,
      loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
  // NOTE(review): getElementPtr is expected to yield a null op when it cannot
  // compute an element pointer for this memref; do not build a load from it.
  if (!loadPtr)
    return failure();
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReturnOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a `ReturnOp` with no operands into `spirv::ReturnOp` and one with
/// a single operand into `spirv::ReturnValueOp`. Fails when there is more
/// than one operand.
LogicalResult
ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  switch (returnOp.getNumOperands()) {
  case 0:
    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
    return success();
  case 1:
    rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
    return success();
  default:
    break;
  }
  // Multiple return values are not handled.
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SelectOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a `SelectOp` into a `spirv::SelectOp` built from the converted
/// condition, true value, and false value.
LogicalResult
SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  SelectOpAdaptor adaptor(operands);
  Value condition = adaptor.condition();
  Value trueValue = adaptor.true_value();
  Value falseValue = adaptor.false_value();
  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, condition, trueValue,
                                               falseValue);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SplatOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a `SplatOp` producing a vector that is a valid SPIR-V composite
/// type into a `spirv::CompositeConstructOp` that repeats the input value once
/// per vector element. Fails for any other result type.
LogicalResult
SplatPattern::matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
                              ConversionPatternRewriter &rewriter) const {
  auto vectorType = op.getType().dyn_cast<VectorType>();
  if (!vectorType)
    return failure();
  if (!spirv::CompositeType::isValid(vectorType))
    return failure();

  SplatOp::Adaptor adaptor(operands);
  // One copy of the scalar input per element of the result vector.
  SmallVector<Value, 4> constituents(vectorType.getNumElements(),
                                     adaptor.input());
  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, vectorType,
                                                           constituents);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StoreOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites a `memref::StoreOp` on a memref of signless integers into SPIR-V.
///
/// When the converted element type has the same bit width as the source
/// element type (with i1 counted as `boolNumBits` bits), a plain
/// `spirv::StoreOp` is emitted. Otherwise the store is emulated on the wider
/// element with an atomic AND that clears the target bits followed by an
/// atomic OR that writes the new bits. Fails for non-signless-integer element
/// types, when no access chain could be built, and when no atomic scope is
/// available for the memref type.
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
                                   ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  memref::StoreOpAdaptor storeOperands(operands);
  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                           storeOperands.indices(), loc, rewriter);
  // NOTE(review): getElementPtr is expected to yield a null op when it cannot
  // compute an element pointer for this memref; bail out instead of
  // dereferencing a null handle below.
  if (!accessChainOp)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  // i1 values occupy `boolNumBits` bits in memory.
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;
  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast<spirv::PointerType>()
                         .getPointeeType();
  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // Same bit width: store the value directly.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
        storeOp, accessChainOp.getResult(), storeOperands.value());
    return success();
  }

  // Since multiple threads may be involved, the emulation is done with atomic
  // operations. E.g., if the stored value is i8, rewrite the StoreOp to
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) store the 32-bit value back
  // 4) load a 32-bit integer
  // 5) modify 8 bits in the loaded value
  // 6) store the 32-bit value back
  // Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6
  // are done by AtomicOr as another atomic step.
  assert(accessChainOp.indices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask =
      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = storeOperands.value();
  if (isBool) {
    // Materialize the boolean as a 0/1 integer of the destination type.
    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
    storeVal =
        rewriter.create<spirv::SelectOp>(loc, dstType, storeVal, one, zero);
  }
  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return failure();
  rewriter.create<spirv::AtomicAndOp>(loc, dstType, adjustedPtr, *scope,
                                      spirv::MemorySemantics::AcquireRelease,
                                      clearBitsMask);
  rewriter.create<spirv::AtomicOrOp>(loc, dstType, adjustedPtr, *scope,
                                     spirv::MemorySemantics::AcquireRelease,
                                     storeVal);

  // The values produced by the atomic ops are not needed. The atomic ops are
  // already inserted, so the original StoreOp is simply erased. Note that
  // rewriter.replaceOp() doesn't work here because it requires the number of
  // results to match.
  rewriter.eraseOp(storeOp);

  // The original access chain was superseded by the adjusted pointer.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
|
|
|
|
/// Rewrites a `memref::StoreOp` whose memref element type is not a signless
/// integer into a `spirv::StoreOp` through the element pointer computed by
/// `spirv::getElementPtr`. Fails for signless-integer element types and when
/// no element pointer could be built.
LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
                                ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const {
  memref::StoreOpAdaptor storeOperands(operands);
  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto storePtr =
      spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
                           storeOperands.memref(), storeOperands.indices(),
                           storeOp.getLoc(), rewriter);
  // NOTE(review): getElementPtr is expected to yield a null op when it cannot
  // compute an element pointer for this memref; do not build a store from it.
  if (!storePtr)
    return failure();
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                              storeOperands.value());
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// XorOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Rewrites an `XOrOp` on non-boolean operands into `spirv::BitwiseXorOp`
/// with the converted result type. Fails for boolean operands or when the
/// result type cannot be converted.
LogicalResult
XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
                              ConversionPatternRewriter &rewriter) const {
  assert(operands.size() == 2);

  // Boolean xor is left to a different pattern.
  const bool isBoolean = isBoolScalarOrVector(operands.front().getType());
  if (isBoolean)
    return failure();

  Type resultType = getTypeConverter()->convertType(xorOp.getType());
  if (resultType) {
    rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, resultType,
                                                     operands);
    return success();
  }
  return failure();
}
|
|
|
|
/// Rewrites an `XOrOp` on boolean scalars or vectors into
/// `spirv::LogicalNotEqualOp` with the converted result type. Fails for
/// non-boolean operands or when the result type cannot be converted.
LogicalResult
BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  assert(operands.size() == 2);

  // Only boolean xor is handled by this pattern.
  const bool isBoolean = isBoolScalarOrVector(operands.front().getType());
  if (!isBoolean)
    return failure();

  Type resultType = getTypeConverter()->convertType(xorOp.getType());
  if (resultType) {
    rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, resultType,
                                                          operands);
    return success();
  }
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern population
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace mlir {
|
|
void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns) {
|
|
MLIRContext *context = patterns.getContext();
|
|
|
|
patterns.add<
|
|
// Math dialect operations.
|
|
// TODO: Move to separate pass.
|
|
UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
|
|
UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
|
|
UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
|
|
UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
|
|
UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
|
|
UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
|
|
UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
|
|
UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
|
|
|
|
// Unary and binary patterns
|
|
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
|
|
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
|
|
UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
|
|
UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
|
|
UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
|
|
UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
|
|
UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
|
|
UnaryAndBinaryOpPattern<FloorFOp, spirv::GLSLFloorOp>,
|
|
UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
|
|
UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
|
|
UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
|
|
UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
|
|
UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
|
|
UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
|
|
UnaryAndBinaryOpPattern<SignedShiftRightOp,
|
|
spirv::ShiftRightArithmeticOp>,
|
|
UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
|
|
UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
|
|
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
|
|
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
|
|
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
|
|
Log1pOpPattern, SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
|
|
|
|
// Comparison patterns
|
|
BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
|
|
|
|
// Constant patterns
|
|
ConstantCompositeOpPattern, ConstantScalarOpPattern,
|
|
|
|
// Memory patterns
|
|
AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
|
|
LoadOpPattern, StoreOpPattern,
|
|
|
|
ReturnOpPattern, SelectOpPattern, SplatPattern,
|
|
|
|
// Type cast patterns
|
|
UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern,
|
|
TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
|
|
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
|
|
TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,
|
|
TypeCastingOpPattern<SignExtendIOp, spirv::SConvertOp>,
|
|
TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
|
|
TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
|
|
TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
|
|
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
|
|
TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(typeConverter,
|
|
context);
|
|
|
|
// Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
|
|
// capability is available.
|
|
patterns.add<CmpFOpNanKernelPattern>(typeConverter, context,
|
|
/*benefit=*/2);
|
|
}
|
|
|
|
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|
int64_t byteCountThreshold,
|
|
RewritePatternSet &patterns) {
|
|
patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(),
|
|
byteCountThreshold);
|
|
}
|
|
|
|
} // namespace mlir
|