circt/lib/Dialect/FIRRTL/FIRRTLFolds.cpp

3291 lines
117 KiB
C++

//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the folding and canonicalizations for FIRRTL ops.
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/FIRRTL/FIRRTLAttributes.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Support/APInt.h"
#include "circt/Support/LLVM.h"
#include "circt/Support/Naming.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace circt;
using namespace firrtl;
// Erase every connect-like user whose destination is `old`, then hand back
// `passthrough` unchanged so the call can be chained inside a pattern.
static Value dropWrite(PatternRewriter &rewriter, OpResult old,
                       Value passthrough) {
  // Snapshot the users into a set before erasing anything; the set also
  // de-duplicates them.
  SmallPtrSet<Operation *, 8> uniqueUsers;
  for (Operation *candidate : old.getUsers())
    uniqueUsers.insert(candidate);

  for (Operation *candidate : uniqueUsers) {
    auto connect = dyn_cast<FConnectLike>(candidate);
    // Only connects that write *to* `old` are dropped.
    if (connect && connect.getDest() == old)
      rewriter.eraseOp(candidate);
  }
  return passthrough;
}
// Copy the "name" attribute of the operation that owns `old` (an operation
// that is about to be deleted) onto the operation defining `passthrough`, and
// return `passthrough` so patterns can chain the call. A name cannot be moved
// onto a port (block argument): that would require rewriting all instance
// sites as well as the module.
static Value moveNameHint(OpResult old, Value passthrough) {
  Operation *target = passthrough.getDefiningOp();
  // This should handle ports, but it isn't clear we can change those in
  // canonicalizers.
  assert(target && "passthrough must be an operation");

  // Only a present, non-empty name is worth transferring.
  if (auto hint = old.getOwner()->getAttrOfType<StringAttr>("name"))
    if (!hint.getValue().empty())
      target->setAttr("name", hint);
  return passthrough;
}
// Declarative canonicalization patterns
//
// The generated header is included inside `circt::firrtl::patterns`, so the
// pattern classes it defines are referred to below as `patterns::<Name>` when
// they are registered in the `getCanonicalizationPatterns` hooks.
namespace circt {
namespace firrtl {
namespace patterns {
#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
} // namespace patterns
} // namespace firrtl
} // namespace circt
/// Check that the first result and every operand of `op` carry a known width.
/// All of these types are cast to `IntType`, so this is only usable on
/// operations whose operands and result are integers.
static bool hasKnownWidthIntTypes(Operation *op) {
  if (!type_cast<IntType>(op->getResult(0).getType()).hasWidth())
    return false;

  for (Value operand : op->getOperands()) {
    auto operandType = type_cast<IntType>(operand.getType());
    if (!operandType.hasWidth())
      return false;
  }
  return true;
}
/// Return true if `type` is a UInt whose width is known and equal to 1.
static bool isUInt1(Type type) {
  auto uintType = type_dyn_cast<UIntType>(type);
  return uintType && uintType.hasWidth() && uintType.getWidth() == 1;
}
/// Give `op` the better of two names: the one it already has (if any) and the
/// `name` passed in, as decided by `chooseName`. The attribute is only
/// rewritten when the chosen name differs from the current one.
static void updateName(PatternRewriter &rewriter, Operation *op,
                       StringAttr name) {
  // Should never rename InstanceOp
  assert(!isa<InstanceOp>(op));

  // Nothing to propagate if the incoming name is missing or empty.
  if (!name)
    return;
  auto incoming = name.getValue();
  if (incoming.empty())
    return;

  auto current = op->getAttrOfType<StringAttr>("name");
  auto chosen = incoming;
  if (current) {
    // The existing name might not be interesting; let chooseName arbitrate.
    chosen = chooseName(current.getValue(), incoming);
    // The op already carries the chosen name: leave it untouched.
    if (current.getValue() == chosen)
      return;
  }

  rewriter.modifyOpInPlace(
      op, [&] { op->setAttr("name", rewriter.getStringAttr(chosen)); });
}
/// Replace `op` with `newValue` via `PatternRewriter::replaceOp`, first
/// offering the "name" attribute of `op` to the operation that defines
/// `newValue`. Nothing is renamed when `newValue` has no defining operation.
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
                                 Value newValue) {
  Operation *definingOp = newValue.getDefiningOp();
  if (definingOp)
    updateName(rewriter, definingOp, op->getAttrOfType<StringAttr>("name"));
  rewriter.replaceOp(op, newValue);
}
/// Replace `op` with a freshly built `OpTy` via
/// `PatternRewriter::replaceOpWithNewOp`, forwarding `args` to the builder,
/// and offer the "name" attribute of `op` to the new operation. Returns the
/// new operation.
template <typename OpTy, typename... Args>
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
                                          Operation *op, Args &&...args) {
  // Capture the name before the replacement happens.
  auto savedName = op->getAttrOfType<StringAttr>("name");
  auto replacement =
      rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
  updateName(rewriter, replacement, savedName);
  return replacement;
}
/// Return true if the name of `op` is droppable. Operations that do not
/// implement `FNamableOp` never have a droppable name. Note that this is
/// different from `isUselessName` because non-useless names may be also
/// droppable.
bool circt::firrtl::hasDroppableName(Operation *op) {
  auto namable = dyn_cast<firrtl::FNamableOp>(op);
  return namable && namable.hasDroppableName();
}
/// Produce the constant value of a fold operand, extended or truncated to
/// `destWidth` bits.
///
/// - If `constant` is an IntegerAttr, its value is resized to `destWidth` with
///   `extOrTruncZeroWidth` (extension signedness follows the constant's sign).
/// - Otherwise, if `operand` has a known width of zero, a zero of `destWidth`
///   bits is returned, carrying the operand's signedness. This lets callers
///   treat zero-width dynamic values as the constant 0.
/// - Otherwise (including when `destWidth` is negative, i.e. unknown),
///   `std::nullopt` is returned.
///
/// `operand` must have an integer type.
static std::optional<APSInt>
getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
  // Use type_isa here: type_cast would itself assert on a mismatch before this
  // message could ever be reported.
  assert(type_isa<IntType>(operand.getType()) &&
         "getExtendedConstant is limited to integer types");

  // We never support constant folding to unknown width values.
  if (destWidth < 0)
    return {};

  // Extension signedness follows the operand sign.
  if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
    return extOrTruncZeroWidth(result.getAPSInt(), destWidth);

  // If the operand is zero bits, then we can return a zero of the result
  // type.
  auto operandType = type_cast<IntType>(operand.getType());
  if (operandType.getWidth() == 0)
    return APSInt(destWidth, operandType.isUnsigned());
  return {};
}
/// Extract the integer value of a constant operand for constant folding.
/// A BoolAttr yields a 1-bit value, an IntegerAttr yields its APSInt, and
/// anything else (including a null attribute) yields `std::nullopt`.
static std::optional<APSInt> getConstant(Attribute operand) {
  std::optional<APSInt> value;
  if (operand) {
    // BoolAttr is tested first, as in the declaration order of the cases.
    if (auto boolAttr = dyn_cast<BoolAttr>(operand))
      value = APSInt(APInt(1, boolAttr.getValue()));
    else if (auto intAttr = dyn_cast<IntegerAttr>(operand))
      value = intAttr.getAPSInt();
  }
  return value;
}
/// Return true if `operand` is a constant recognized by `getConstant` whose
/// value is zero.
/// NOTE(review): earlier documentation said `invalidvalue` is considered zero
/// here; that only holds if such values reach this function as attributes
/// `getConstant` understands -- confirm.
static bool isConstantZero(Attribute operand) {
  auto cst = getConstant(operand);
  return cst.has_value() && cst->isZero();
}
/// Return true if `operand` is a constant recognized by `getConstant` whose
/// value is one.
static bool isConstantOne(Attribute operand) {
  auto cst = getConstant(operand);
  return cst.has_value() && cst->isOne();
}
/// This is the policy for folding, which depends on the sort of operator we're
/// processing. It selects the width at which `constFoldFIRRTLBinaryOp`
/// evaluates the operands.
enum class BinOpKind {
/// Operands are extended to the width of the result type.
Normal,
/// Operands are extended to the widest operand (at least 1 bit), not to the
/// result type.
Compare,
/// Operands are extended to the widest of both operands and the result; the
/// computed value is truncated to the result width afterwards.
DivideOrShift,
};
/// Applies the constant folding function `calculate` to the given operands.
///
/// Both operands are sign or zero extended (following each operand's own
/// signedness) to a common width chosen by \p opKind:
///  - BinOpKind::Normal: the width of the result type;
///  - BinOpKind::Compare: the widest operand, but at least one bit;
///  - BinOpKind::DivideOrShift: the widest of both operands and the result;
///    the computed value is truncated to the result width afterwards.
///
/// \p op supplies the operand and result types, \p operands the constant
/// attributes of the two operands (null when an operand is not constant).
///
/// Returns a null attribute when the result width is unknown or when an
/// operand is neither an integer constant nor a zero-width value. A
/// zero-width result always folds to a zero-width zero.
static Attribute constFoldFIRRTLBinaryOp(
    Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
    const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
  assert(operands.size() == 2 && "binary op takes two operands");

  // We cannot fold something to an unknown width.
  auto resultType = type_cast<IntType>(op->getResult(0).getType());
  const auto resultWidth = resultType.getWidthOrSentinel();
  if (resultWidth < 0)
    return {};

  // Any binary op returning i0 is 0.
  if (resultWidth == 0)
    return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));

  // Determine the operand widths. This is either dictated by the operand type,
  // or if that type is an unsized integer, by the actual bits necessary to
  // represent the constant value.
  auto lhsWidth =
      type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
  auto rhsWidth =
      type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
  if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
    lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
  if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
    rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());

  // Pick the width at which the computation is carried out. The initial value
  // matches the BinOpKind::Normal policy so the variable is never read
  // uninitialized.
  int32_t operandWidth = resultWidth;
  switch (opKind) {
  case BinOpKind::Normal:
    operandWidth = resultWidth;
    break;
  case BinOpKind::Compare:
    // Compares compute with the widest operand, not at the destination type
    // (which is always i1).
    operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
    break;
  case BinOpKind::DivideOrShift:
    operandWidth = std::max(std::max(lhsWidth, rhsWidth), resultWidth);
    break;
  }

  auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
  if (!lhs)
    return {};
  auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
  if (!rhs)
    return {};

  APInt resultValue = calculate(*lhs, *rhs);

  // If the result type is smaller than the computation then we need to
  // narrow the constant after the calculation.
  if (opKind == BinOpKind::DivideOrShift)
    resultValue = resultValue.trunc(resultWidth);

  assert(static_cast<unsigned>(resultWidth) == resultValue.getBitWidth());
  return getIntAttr(resultType, resultValue);
}
/// Applies the canonicalization function `canonicalize` to the given operation.
///
/// Determines which (if any) of the operation's operands are constants, and
/// provides them as arguments to the callback function (operands that are not
/// defined by a ConstantOp or SpecialConstantOp are passed as null
/// attributes). The value returned from the callback is used as the
/// replacement for `op`; a pad operation and a signedness cast are inserted if
/// necessary so the replacement has the type of `op`. Does nothing (returns
/// failure) if `op` does not have exactly one result, if the result is not a
/// FIRRTL base type, if the result is of unknown width (in which case the
/// necessity of a pad cannot be determined), or if the callback returns a null
/// result.
///
/// NOTE(review): earlier documentation stated that any `invalidvalue` in the
/// input is mapped to a constant zero; the operand scan below only recognizes
/// ConstantOp and SpecialConstantOp -- confirm whether that still applies.
static LogicalResult canonicalizePrimOp(
Operation *op, PatternRewriter &rewriter,
const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
// Can only operate on FIRRTL primitive operations.
if (op->getNumResults() != 1)
return failure();
auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
if (!type)
return failure();
// Can only operate on operations with a known result width.
auto width = type.getBitWidthOrSentinel();
if (width < 0)
return failure();
// Determine which of the operands are constants. Each operand contributes
// exactly one entry, so the callback can index by operand position.
SmallVector<Attribute, 3> constOperands;
constOperands.reserve(op->getNumOperands());
for (auto operand : op->getOperands()) {
Attribute attr;
if (auto *defOp = operand.getDefiningOp())
TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
[&](auto op) { attr = op.getValueAttr(); });
constOperands.push_back(attr);
}
// Perform the canonicalization and materialize the result if it is a
// constant.
auto result = canonicalize(constOperands);
if (!result)
return failure();
Value resultValue;
if (auto cst = dyn_cast<Attribute>(result))
resultValue = op->getDialect()
->materializeConstant(rewriter, cst, type, op->getLoc())
->getResult(0);
else
resultValue = result.get<Value>();
// Insert a pad if the type widths disagree.
if (width !=
type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);
// Insert a cast if this is a uint vs. sint or vice versa.
if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
else if (type_isa<UIntType>(type) &&
type_isa<SIntType>(resultValue.getType()))
resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
// After padding and casting the replacement must have exactly the type of
// the original result.
assert(type == resultValue.getType() && "canonicalization changed type");
replaceOpAndCopyName(rewriter, op, resultValue);
return success();
}
/// Largest unsigned value representable in `bitWidth` bits. For a `bitWidth`
/// of 0 a default-constructed APInt is returned instead (documented upstream
/// as a 1-bit zero value).
static APInt getMaxUnsignedValue(unsigned bitWidth) {
  if (bitWidth == 0)
    return APInt();
  return APInt::getMaxValue(bitWidth);
}
/// Smallest signed value representable in `bitWidth` bits. For a `bitWidth`
/// of 0 a default-constructed APInt is returned instead (documented upstream
/// as a 1-bit zero value).
static APInt getMinSignedValue(unsigned bitWidth) {
  if (bitWidth == 0)
    return APInt();
  return APInt::getSignedMinValue(bitWidth);
}
/// Largest signed value representable in `bitWidth` bits. For a `bitWidth`
/// of 0 a default-constructed APInt is returned instead (documented upstream
/// as a 1-bit zero value).
static APInt getMaxSignedValue(unsigned bitWidth) {
  if (bitWidth == 0)
    return APInt();
  return APInt::getSignedMaxValue(bitWidth);
}
//===----------------------------------------------------------------------===//
// Fold Hooks
//===----------------------------------------------------------------------===//
/// Fold hook: a constant folds to its own value attribute.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  // The op takes no operands, so the adaptor must be empty.
  assert(adaptor.getOperands().empty() && "constant has no operands");
  auto value = getValueAttr();
  return value;
}
/// Fold hook: a special constant folds to its own value attribute.
OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
  // The op takes no operands, so the adaptor must be empty.
  assert(adaptor.getOperands().empty() && "constant has no operands");
  auto value = getValueAttr();
  return value;
}
/// Fold hook: an aggregate constant folds to its fields attribute.
OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
  // The op takes no operands, so the adaptor must be empty.
  assert(adaptor.getOperands().empty() && "constant has no operands");
  auto fields = getFieldsAttr();
  return fields;
}
/// Fold hook: a string constant folds to its own value attribute.
OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
  // The op takes no operands, so the adaptor must be empty.
  assert(adaptor.getOperands().empty() && "constant has no operands");
  auto value = getValueAttr();
  return value;
}
/// Fold hook: an integer constant folds to its own value attribute.
OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
  // The op takes no operands, so the adaptor must be empty.
  assert(adaptor.getOperands().empty() && "constant has no operands");
  auto value = getValueAttr();
  return value;
}
/// Fold hook: a bool constant folds to its own value attribute.
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
  // The op takes no operands, so the adaptor must be empty.
  assert(adaptor.getOperands().empty() && "constant has no operands");
  auto value = getValueAttr();
  return value;
}
/// Fold hook: a double constant folds to its own value attribute.
OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
  // The op takes no operands, so the adaptor must be empty.
  assert(adaptor.getOperands().empty() && "constant has no operands");
  auto value = getValueAttr();
  return value;
}
//===----------------------------------------------------------------------===//
// Binary Operators
//===----------------------------------------------------------------------===//
/// Constant-fold an addition; both operands are evaluated at the width of the
/// result type (BinOpKind::Normal).
OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Normal,
      [](const APSInt &lhs, const APSInt &rhs) { return lhs + rhs; });
}
/// Register the declarative canonicalization patterns for `add`.
void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::moveConstAdd>(context);
  results.insert<patterns::AddOfZero>(context);
  results.insert<patterns::AddOfSelf>(context);
  results.insert<patterns::AddOfPad>(context);
}
/// Constant-fold a subtraction; both operands are evaluated at the width of
/// the result type (BinOpKind::Normal).
OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Normal,
      [](const APSInt &lhs, const APSInt &rhs) { return lhs - rhs; });
}
/// Register the declarative canonicalization patterns for `sub`.
void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::SubOfZero>(context);
  results.insert<patterns::SubFromZeroSigned>(context);
  results.insert<patterns::SubFromZeroUnsigned>(context);
  results.insert<patterns::SubOfSelf>(context);
  results.insert<patterns::SubOfPadL>(context);
  results.insert<patterns::SubOfPadR>(context);
}
/// Fold a multiplication: a constant-zero operand on either side yields zero,
/// otherwise two constant operands are multiplied at the result width.
OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
  // mul(x, 0) -> 0
  //
  // This is legal because it aligns with the Scala FIRRTL Compiler
  // interpretation of lowering invalid to constant zero before constant
  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
  // multiplication this way and will emit "x * 0".
  const bool hasZeroOperand =
      isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs());
  if (hasZeroOperand)
    return getIntZerosAttr(getType());

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Normal,
      [](const APSInt &lhs, const APSInt &rhs) { return lhs * rhs; });
}
/// Fold a division. In order, the following simplifications are tried:
/// div(x, x) -> 1, div(0, x) -> 0 (unless x is a constant zero),
/// div(x, 1) -> x when the type is unchanged, and finally plain constant
/// folding, where a constant division by zero folds to zero.
OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
  // div(x, x) -> 1
  //
  // Division by zero is undefined in the FIRRTL specification. This fold
  // exploits that fact to optimize self division to one. Note: this should
  // supersede any division with invalid or zero. Division of invalid by
  // invalid should be one.
  if (getLhs() == getRhs()) {
    auto width = getType().get().getWidthOrSentinel();
    // An unknown result width (sentinel -1) is replaced by a width of 2.
    if (width == -1)
      width = 2;
    // Only fold if we have at least 1 bit of width to represent the `1` value.
    if (width != 0)
      return getIntAttr(getType(), APInt(width, 1));
  }

  // div(0, x) -> 0
  //
  // This is legal because it aligns with the Scala FIRRTL Compiler
  // interpretation of lowering invalid to constant zero before constant
  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
  // division this way and will emit "0 / x".
  if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
    return getIntZerosAttr(getType());

  // div(x, 1) -> x : (uint, uint) -> uint
  //
  // UInt division by one returns the numerator. SInt division can't
  // be folded here because it increases the return type bitwidth by
  // one and requires sign extension (a new op).
  if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
    if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
      return getLhs();

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
      [](const APSInt &a, const APSInt &b) -> APInt {
        // A constant division by zero folds to zero.
        if (!!b)
          return a / b;
        return APInt(a.getBitWidth(), 0);
      });
}
/// Fold a remainder: rem(x, x) and rem(0, x) become zero; otherwise two
/// constant operands are folded, with a constant zero divisor giving zero.
OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
  // rem(x, x) -> 0
  //
  // Division by zero is undefined in the FIRRTL specification. This fold
  // exploits that fact to optimize self division remainder to zero. Note:
  // this should supersede any division with invalid or zero. Remainder of
  // division of invalid by invalid should be zero.
  const bool sameOperand = getLhs() == getRhs();
  if (sameOperand)
    return getIntZerosAttr(getType());

  // rem(0, x) -> 0
  //
  // This is legal because it aligns with the Scala FIRRTL Compiler
  // interpretation of lowering invalid to constant zero before constant
  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
  // division this way and will emit "0 % x".
  if (isConstantZero(adaptor.getLhs()))
    return getIntZerosAttr(getType());

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
      [](const APSInt &dividend, const APSInt &divisor) -> APInt {
        // A zero divisor folds to zero rather than dividing.
        if (!divisor)
          return APInt(dividend.getBitWidth(), 0);
        return dividend % divisor;
      });
}
/// Constant-fold a dynamic left shift. The shift is computed at the widest of
/// the operand and result widths and truncated to the result width.
OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
      [](const APSInt &value, const APSInt &amount) -> APInt {
        return value.shl(amount);
      });
}
/// Constant-fold a `dshlw` left shift. The shift is computed at the widest of
/// the operand and result widths and truncated to the result width.
OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
      [](const APSInt &value, const APSInt &amount) -> APInt {
        return value.shl(amount);
      });
}
/// Constant-fold a dynamic right shift: logical for unsigned results (and for
/// zero-width values), arithmetic otherwise.
OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
      [this](const APSInt &value, const APSInt &amount) -> APInt {
        if (getType().get().isUnsigned() || value.getBitWidth() == 0)
          return value.lshr(amount);
        return value.ashr(amount);
      });
}
// TODO: Move to DRR.
/// Fold a bitwise and: a constant-zero operand gives zero, an all-ones
/// constant or a repeated operand gives the other operand (when the types
/// line up with the result), otherwise constants are folded.
OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
  // Returning an operand is only valid when no implicit extension is needed.
  const bool operandsMatchResult =
      getLhs().getType() == getType() && getRhs().getType() == getType();

  if (auto rhsCst = getConstant(adaptor.getRhs())) {
    // and(x, 0) -> 0, 0 is largest or is implicit zero extended
    if (rhsCst->isZero())
      return getIntZerosAttr(getType());
    // and(x, -1) -> x
    if (operandsMatchResult && rhsCst->isAllOnes())
      return getLhs();
  }

  if (auto lhsCst = getConstant(adaptor.getLhs())) {
    // and(0, x) -> 0, 0 is largest or is implicit zero extended
    if (lhsCst->isZero())
      return getIntZerosAttr(getType());
    // and(-1, x) -> x
    if (operandsMatchResult && lhsCst->isAllOnes())
      return getRhs();
  }

  // and(x, x) -> x
  if (getRhs().getType() == getType() && getLhs() == getRhs())
    return getRhs();

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Normal,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt { return lhs & rhs; });
}
/// Register the declarative canonicalization patterns for `and`.
void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::extendAnd>(context);
  results.insert<patterns::moveConstAnd>(context);
  results.insert<patterns::AndOfZero>(context);
  results.insert<patterns::AndOfAllOne>(context);
  results.insert<patterns::AndOfSelf>(context);
  results.insert<patterns::AndOfPad>(context);
  results.insert<patterns::AndOfAsSIntL>(context);
  results.insert<patterns::AndOfAsSIntR>(context);
}
/// Fold a bitwise or: a constant zero gives the other operand, an all-ones
/// constant gives that constant, a repeated operand gives itself (each only
/// when the involved types equal the result type); otherwise constants are
/// folded.
OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
  const bool lhsMatchesResult = getLhs().getType() == getType();
  const bool rhsMatchesResult = getRhs().getType() == getType();
  const bool bothMatchResult = lhsMatchesResult && rhsMatchesResult;

  if (auto rhsCst = getConstant(adaptor.getRhs())) {
    // or(x, 0) -> x
    if (lhsMatchesResult && rhsCst->isZero())
      return getLhs();
    // or(x, -1) -> -1
    if (bothMatchResult && rhsCst->isAllOnes())
      return getRhs();
  }

  if (auto lhsCst = getConstant(adaptor.getLhs())) {
    // or(0, x) -> x
    if (rhsMatchesResult && lhsCst->isZero())
      return getRhs();
    // or(-1, x) -> -1
    if (bothMatchResult && lhsCst->isAllOnes())
      return getLhs();
  }

  // or(x, x) -> x
  if (rhsMatchesResult && getLhs() == getRhs())
    return getRhs();

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Normal,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt { return lhs | rhs; });
}
/// Register the declarative canonicalization patterns for `or`.
void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::extendOr>(context);
  results.insert<patterns::moveConstOr>(context);
  results.insert<patterns::OrOfZero>(context);
  results.insert<patterns::OrOfAllOne>(context);
  results.insert<patterns::OrOfSelf>(context);
  results.insert<patterns::OrOfPad>(context);
}
/// Fold a bitwise xor: a constant zero gives the other operand (when its type
/// is equivalent to the result type), a repeated operand gives zero; otherwise
/// constants are folded.
OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
  // xor(x, 0) -> x
  auto rhsCst = getConstant(adaptor.getRhs());
  if (rhsCst && rhsCst->isZero() &&
      firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
    return getLhs();

  // xor(0, x) -> x
  auto lhsCst = getConstant(adaptor.getLhs());
  if (lhsCst && lhsCst->isZero() &&
      firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
    return getRhs();

  // xor(x, x) -> 0
  if (getLhs() == getRhs()) {
    // An unknown width (negative sentinel) is clamped to zero bits.
    auto zeroWidth = std::max(getType().get().getWidthOrSentinel(), 0);
    return getIntAttr(getType(), APInt(zeroWidth, 0));
  }

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Normal,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt { return lhs ^ rhs; });
}
/// Register the declarative canonicalization patterns for `xor`.
void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::extendXor>(context);
  results.insert<patterns::moveConstXor>(context);
  results.insert<patterns::XorOfZero>(context);
  results.insert<patterns::XorOfSelf>(context);
  results.insert<patterns::XorOfPad>(context);
}
/// Register the single declarative canonicalization pattern for `leq`.
void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::LEQWithConstLHS>(context);
}
/// Fold `leq`: identical operands give 1, a constant right-hand side outside
/// the range of the left-hand type decides the result statically, otherwise
/// two constants are compared.
OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
  // leq(x, x) -> 1
  if (getLhs() == getRhs())
    return getIntAttr(getType(), APInt(1, 1));

  // Comparison against constant outside type bounds. This needs a known
  // left-hand width and a constant right-hand side.
  const bool isUnsigned = getLhs().getType().get().isUnsigned();
  auto width = getLhs().getType().get().getWidth();
  auto rhsCst = getConstant(adaptor.getRhs());
  if (width && rhsCst) {
    // Compare at a width that holds both the type bounds and the constant.
    const int32_t commonWidth = std::max<int32_t>(
        1, std::max<int32_t>(*width, rhsCst->getBitWidth()));

    if (isUnsigned) {
      // leq(x, const) -> 0 where const < minValue of the unsigned type of x
      // can never occur since const is unsigned and cannot be less than 0.
      // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
      if (rhsCst->zext(commonWidth)
              .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 1));
    } else {
      auto rhsExt = sextZeroWidth(*rhsCst, commonWidth);
      // leq(x, const) -> 0 where const < minValue of the signed type of x
      if (rhsExt.slt(getMinSignedValue(*width).sext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 0));
      // leq(x, const) -> 1 where const >= maxValue of the signed type of x
      if (rhsExt.sge(getMaxSignedValue(*width).sext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 1));
    }
  }

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Compare,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt {
        return APInt(1, lhs <= rhs);
      });
}
/// Register the single declarative canonicalization pattern for `lt`.
void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::LTWithConstLHS>(context);
}
/// Fold `lt`: identical operands give 0, an unsigned value is never below
/// zero, a constant right-hand side outside the range of the left-hand type
/// decides the result statically, otherwise two constants are compared.
OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
  // lt(x, x) -> 0
  if (getLhs() == getRhs())
    return getIntAttr(getType(), APInt(1, 0));

  IntType lhsType = getLhs().getType();
  const bool isUnsigned = lhsType.isUnsigned();
  auto rhsCst = getConstant(adaptor.getRhs());

  // lt(x, 0) -> 0 when x is unsigned
  if (rhsCst && rhsCst->isZero() && isUnsigned)
    return getIntAttr(getType(), APInt(1, 0));

  // Comparison against constant outside type bounds.
  auto width = lhsType.getWidth();
  if (width && rhsCst) {
    // Compare at a width that holds both the type bounds and the constant.
    const int32_t commonWidth = std::max<int32_t>(
        1, std::max<int32_t>(*width, rhsCst->getBitWidth()));

    if (isUnsigned) {
      // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
      // is handled explicitly above.
      // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
      if (rhsCst->zext(commonWidth)
              .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 1));
    } else {
      auto rhsExt = sextZeroWidth(*rhsCst, commonWidth);
      // lt(x, const) -> 0 where const <= minValue of the signed type of x
      if (rhsExt.sle(getMinSignedValue(*width).sext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 0));
      // lt(x, const) -> 1 where const > maxValue of the signed type of x
      if (rhsExt.sgt(getMaxSignedValue(*width).sext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 1));
    }
  }

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Compare,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt {
        return APInt(1, lhs < rhs);
      });
}
/// Register the single declarative canonicalization pattern for `geq`.
void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::GEQWithConstLHS>(context);
}
/// Fold `geq`: identical operands give 1, an unsigned value is always at
/// least zero, a constant right-hand side outside the range of the left-hand
/// type decides the result statically, otherwise two constants are compared.
OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
  // geq(x, x) -> 1
  if (getLhs() == getRhs())
    return getIntAttr(getType(), APInt(1, 1));

  IntType lhsType = getLhs().getType();
  const bool isUnsigned = lhsType.isUnsigned();
  auto rhsCst = getConstant(adaptor.getRhs());

  // geq(x, 0) -> 1 when x is unsigned
  if (rhsCst && rhsCst->isZero() && isUnsigned)
    return getIntAttr(getType(), APInt(1, 1));

  // Comparison against constant outside type bounds.
  auto width = lhsType.getWidth();
  if (width && rhsCst) {
    // Compare at a width that holds both the type bounds and the constant.
    const int32_t commonWidth = std::max<int32_t>(
        1, std::max<int32_t>(*width, rhsCst->getBitWidth()));

    if (isUnsigned) {
      // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
      if (rhsCst->zext(commonWidth)
              .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 0));
      // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
      // is handled explicitly above.
    } else {
      auto rhsExt = sextZeroWidth(*rhsCst, commonWidth);
      // geq(x, const) -> 0 where const > maxValue of the signed type of x
      if (rhsExt.sgt(getMaxSignedValue(*width).sext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 0));
      // geq(x, const) -> 1 where const <= minValue of the signed type of x
      if (rhsExt.sle(getMinSignedValue(*width).sext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 1));
    }
  }

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Compare,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt {
        return APInt(1, lhs >= rhs);
      });
}
/// Register the single declarative canonicalization pattern for `gt`.
void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::GTWithConstLHS>(context);
}
/// Fold `gt`: identical operands give 0, a constant right-hand side outside
/// the range of the left-hand type decides the result statically, otherwise
/// two constants are compared.
OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
  // gt(x, x) -> 0
  if (getLhs() == getRhs())
    return getIntAttr(getType(), APInt(1, 0));

  // Comparison against constant outside type bounds.
  IntType lhsType = getLhs().getType();
  const bool isUnsigned = lhsType.isUnsigned();
  auto width = lhsType.getWidth();
  auto rhsCst = getConstant(adaptor.getRhs());
  if (width && rhsCst) {
    // Compare at a width that holds both the type bounds and the constant.
    const int32_t commonWidth = std::max<int32_t>(
        1, std::max<int32_t>(*width, rhsCst->getBitWidth()));

    if (isUnsigned) {
      // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
      if (rhsCst->zext(commonWidth)
              .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 0));
      // gt(x, const) -> 1 where const < minValue of the unsigned type of x
      // can never occur since const is unsigned and cannot be less than 0.
    } else {
      auto rhsExt = sextZeroWidth(*rhsCst, commonWidth);
      // gt(x, const) -> 0 where const >= maxValue of the signed type of x
      if (rhsExt.sge(getMaxSignedValue(*width).sext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 0));
      // gt(x, const) -> 1 where const < minValue of the signed type of x
      if (rhsExt.slt(getMinSignedValue(*width).sext(commonWidth)))
        return getIntAttr(getType(), APInt(1, 1));
    }
  }

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Compare,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt {
        return APInt(1, lhs > rhs);
      });
}
/// Fold `eq`: identical operands give 1, comparing against an all-ones
/// constant returns the left operand when both operand types equal the result
/// type, otherwise two constants are compared.
OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
  // eq(x, x) -> 1
  if (getLhs() == getRhs())
    return getIntAttr(getType(), APInt(1, 1));

  // eq(x, 1) -> x when x is 1 bit.
  // TODO: Support SInt<1> on the LHS etc.
  auto rhsCst = getConstant(adaptor.getRhs());
  const bool operandsMatchResult =
      getLhs().getType() == getType() && getRhs().getType() == getType();
  if (rhsCst && rhsCst->isAllOnes() && operandsMatchResult)
    return getLhs();

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Compare,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt {
        return APInt(1, lhs == rhs);
      });
}
/// Canonicalize `eq` against a constant right-hand side into `not`, `orr` or
/// `andr` forms. The rewrite itself is delegated to `canonicalizePrimOp`.
LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
  auto rewrite = [&](ArrayRef<Attribute> operands) -> OpFoldResult {
    auto rhsCst = getConstant(operands[1]);
    if (!rhsCst)
      return {};

    auto width = op.getLhs().getType().getBitWidthOrSentinel();
    const bool rhsIsZero = rhsCst->isZero();
    const bool operandsMatchResult = op.getLhs().getType() == op.getType() &&
                                     op.getRhs().getType() == op.getType();

    // eq(x, 0) -> not(x) when x is 1 bit.
    if (rhsIsZero && operandsMatchResult) {
      auto notOp = rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs());
      return notOp.getResult();
    }

    // eq(x, 0) -> not(orr(x)) when x is >1 bit
    if (rhsIsZero && width > 1) {
      auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
      return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
    }

    // eq(x, ~0) -> andr(x) when x is >1 bit
    if (rhsCst->isAllOnes() && width > 1 &&
        op.getLhs().getType() == op.getRhs().getType()) {
      auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
      return andrOp.getResult();
    }

    return {};
  };
  return canonicalizePrimOp(op, rewriter, rewrite);
}
/// Fold `neq`: identical operands give 0, comparing against a zero constant
/// returns the left operand when both operand types equal the result type,
/// otherwise two constants are compared.
OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
  // neq(x, x) -> 0
  if (getLhs() == getRhs())
    return getIntAttr(getType(), APInt(1, 0));

  // neq(x, 0) -> x when x is 1 bit.
  // TODO: Support SInt<1> on the LHS etc.
  auto rhsCst = getConstant(adaptor.getRhs());
  const bool operandsMatchResult =
      getLhs().getType() == getType() && getRhs().getType() == getType();
  if (rhsCst && rhsCst->isZero() && operandsMatchResult)
    return getLhs();

  return constFoldFIRRTLBinaryOp(
      *this, adaptor.getOperands(), BinOpKind::Compare,
      [](const APSInt &lhs, const APSInt &rhs) -> APInt {
        return APInt(1, lhs != rhs);
      });
}
/// Canonicalize `neq` against a constant right-hand side into `not`, `orr` or
/// `not(andr)` forms. The rewrite itself is delegated to `canonicalizePrimOp`.
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  auto rewrite = [&](ArrayRef<Attribute> operands) -> OpFoldResult {
    auto rhsCst = getConstant(operands[1]);
    if (!rhsCst)
      return {};

    auto width = op.getLhs().getType().getBitWidthOrSentinel();
    const bool rhsIsAllOnes = rhsCst->isAllOnes();
    const bool operandsMatchResult = op.getLhs().getType() == op.getType() &&
                                     op.getRhs().getType() == op.getType();

    // neq(x, 1) -> not(x) when x is 1 bit
    if (rhsIsAllOnes && operandsMatchResult) {
      auto notOp = rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs());
      return notOp.getResult();
    }

    // neq(x, 0) -> orr(x) when x is >1 bit
    if (rhsCst->isZero() && width > 1) {
      auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
      return orrOp.getResult();
    }

    // neq(x, ~0) -> not(andr(x))) when x is >1 bit
    if (rhsIsAllOnes && width > 1 &&
        op.getLhs().getType() == op.getRhs().getType()) {
      auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
      return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
    }

    return {};
  };
  return canonicalizePrimOp(op, rewriter, rewrite);
}
//===----------------------------------------------------------------------===//
// Unary Operators
//===----------------------------------------------------------------------===//
/// Fold `sizeof` to a 32-bit constant holding the bit width of the input
/// type; no fold happens when that width is unknown (negative sentinel).
OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
  auto bitWidth = getInput().getType().getBitWidthOrSentinel();
  if (bitWidth < 0)
    return {};
  return getIntAttr(getType(), APInt(32, bitWidth));
}
/// Fold `isX` of a constant argument to 0.
OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
  // Nothing to do unless the argument is a known constant.
  if (!getConstant(adaptor.getArg()))
    return {};
  // No constant can be 'x' by definition.
  return getIntAttr(getType(), APInt(1, 0));
}
/// Fold `asSInt` when it is a no-op or when it casts a constant to a result of
/// known width.
OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
  // The cast has no effect if input and result types are already equivalent.
  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
    return getInput();

  // Only fold the cast into a constant when the result width is known.
  // Otherwise width inference may produce differently-sized constants if the
  // sign changes.
  if (!getType().get().hasWidth())
    return {};
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};
  return getIntAttr(getType(), *cst);
}
/// Register the canonicalization patterns rooted at `asSInt`.
void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<patterns::StoUtoS>(context);
}
/// Fold `asUInt` when it is a no-op or when it casts a constant to a result of
/// known width.
OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
  // The cast has no effect if input and result types are already equivalent.
  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
    return getInput();

  // Only fold the cast into a constant when the result width is known.
  // Otherwise width inference may produce differently-sized constants if the
  // sign changes.
  if (!getType().get().hasWidth())
    return {};
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};
  return getIntAttr(getType(), *cst);
}
/// Register the canonicalization patterns rooted at `asUInt`.
void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<patterns::UtoStoU>(context);
}
/// Fold `asAsyncReset` when the input already has the result type, or when the
/// input is a constant.
OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
  auto input = getInput();
  // Same type on both sides: the cast does nothing.
  if (input.getType() == getType())
    return input;

  // Constant input: produce the corresponding boolean attribute.
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};
  return BoolAttr::get(getContext(), cst->getBoolValue());
}
/// Fold `asClock` when the input already has the result type, or when the
/// input is a constant.
OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
  auto input = getInput();
  // Same type on both sides: the cast does nothing.
  if (input.getType() == getType())
    return input;

  // Constant input: produce the corresponding boolean attribute.
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};
  return BoolAttr::get(getContext(), cst->getBoolValue());
}
/// Constant fold `cvt` when all involved integer widths are known.
OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // Signed to signed is a noop, unsigned operands prepend a zero bit. Both are
  // covered by extending the constant to the result width.
  auto resultWidth = getType().get().getWidthOrSentinel();
  auto cst = getExtendedConstant(getOperand(), adaptor.getInput(), resultWidth);
  if (!cst)
    return {};
  return getIntAttr(getType(), *cst);
}
/// Register the canonicalization patterns rooted at `cvt`.
void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<patterns::CVTSigned, patterns::CVTUnSigned>(context);
}
/// Constant fold `neg` when all involved integer widths are known.
OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // FIRRTL negate always adds a bit, so extend the operand to the result width
  // first and then subtract it from zero:
  //   -x ---> 0-sext(x) or 0-zext(x)
  auto resultWidth = getType().get().getWidthOrSentinel();
  auto cst = getExtendedConstant(getOperand(), adaptor.getInput(), resultWidth);
  if (!cst)
    return {};
  return getIntAttr(getType(), APInt(cst->getBitWidth(), 0) - *cst);
}
/// Constant fold `not` when all involved integer widths are known.
OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // Extend the constant to the result width, then flip every bit.
  auto resultWidth = getType().get().getWidthOrSentinel();
  auto cst = getExtendedConstant(getOperand(), adaptor.getInput(), resultWidth);
  if (!cst)
    return {};
  return getIntAttr(getType(), ~*cst);
}
/// Register the canonicalization patterns rooted at `not`.
void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<patterns::NotNot>(context);
}
/// Fold the and-reduction for zero-width, constant, and single-bit unsigned
/// inputs.
OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  auto inputType = getInput().getType();

  // A zero-width input reduces to 1.
  if (inputType.getBitWidthOrSentinel() == 0)
    return getIntAttr(getType(), APInt(1, 1));

  // For a constant the reduction is the test x == -1.
  if (auto cst = getConstant(adaptor.getInput()))
    return getIntAttr(getType(), APInt(1, cst->isAllOnes()));

  // One bit is identity. Only applies to UInt since we can't make a cast
  // here.
  if (isUInt1(inputType))
    return getInput();
  return {};
}
/// Register the canonicalization patterns rooted at `andr`.
void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
              patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
              patterns::AndRCatZeroL, patterns::AndRCatZeroR>(context);
}
/// Fold the or-reduction for zero-width, constant, and single-bit unsigned
/// inputs.
OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  auto inputType = getInput().getType();

  // A zero-width input reduces to 0.
  if (inputType.getBitWidthOrSentinel() == 0)
    return getIntAttr(getType(), APInt(1, 0));

  // For a constant the reduction is the test x != 0.
  if (auto cst = getConstant(adaptor.getInput()))
    return getIntAttr(getType(), APInt(1, !cst->isZero()));

  // One bit is identity. Only applies to UInt since we can't make a cast
  // here.
  if (isUInt1(inputType))
    return getInput();
  return {};
}
/// Register the canonicalization patterns rooted at `orr`.
void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
              patterns::OrRCatZeroH, patterns::OrRCatZeroL>(context);
}
/// Fold the xor-reduction for zero-width, constant, and single-bit unsigned
/// inputs.
OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  auto inputType = getInput().getType();

  // A zero-width input reduces to 0.
  if (inputType.getBitWidthOrSentinel() == 0)
    return getIntAttr(getType(), APInt(1, 0));

  // For a constant the reduction is the parity: popcount(x) & 1.
  if (auto cst = getConstant(adaptor.getInput()))
    return getIntAttr(getType(), APInt(1, cst->popcount() & 1));

  // One bit is identity. Only applies to UInt since we can't make a cast here.
  if (isUInt1(inputType))
    return getInput();
  return {};
}
/// Register the canonicalization patterns rooted at `xorr`.
void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
              patterns::XorRCatZeroH, patterns::XorRCatZeroL>(context);
}
//===----------------------------------------------------------------------===//
// Other Operators
//===----------------------------------------------------------------------===//
/// Fold `cat` when one side is zero-width or when both sides are constants.
OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
  // cat(x, 0-width) -> x
  // cat(0-width, x) -> x
  // Limit to unsigned (result type), as cannot insert cast here. The
  // signedness check is made on the operand that is actually returned, so the
  // folded value always has the unsigned result kind.
  IntType lhsType = getLhs().getType();
  IntType rhsType = getRhs().getType();
  if (lhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
    return getRhs();
  if (rhsType.getBitWidthOrSentinel() == 0 && lhsType.isUnsigned())
    return getLhs();

  // Constant folding needs every width to be known.
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // Constant fold cat: the lhs supplies the high bits, the rhs the low bits.
  if (auto lhs = getConstant(adaptor.getLhs()))
    if (auto rhs = getConstant(adaptor.getRhs()))
      return getIntAttr(getType(), lhs->concat(*rhs));
  return {};
}
/// Register the canonicalization patterns rooted at `dshl`.
void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<patterns::DShlOfConstant>(context);
}
/// Register the canonicalization patterns rooted at `dshr`.
void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<patterns::DShrOfConstant>(context);
}
namespace {
// cat(bits(x, ...), bits(x, ...)) -> bits(x ...) when the two ...'s are
// consecutive in the input, i.e. the low bit of the left extract sits directly
// above the high bit of the right extract.
struct CatBitsBits : public mlir::RewritePattern {
  CatBitsBits(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);

    // Both operands must be produced by `bits` operations.
    auto upper = cat.getLhs().getDefiningOp<BitsPrimOp>();
    if (!upper)
      return failure();
    auto lower = cat.getRhs().getDefiningOp<BitsPrimOp>();
    if (!lower)
      return failure();

    // They must slice the same value, and the slices must be adjacent.
    if (upper.getInput() != lower.getInput() ||
        upper.getLo() - 1 != lower.getHi())
      return failure();

    // Replace the concatenation with a single, wider extract.
    replaceOpWithNewOpAndCopyName<BitsPrimOp>(rewriter, cat, cat.getType(),
                                              upper.getInput(), upper.getHi(),
                                              lower.getLo());
    return success();
  }
};
} // namespace
/// Register the canonicalization patterns rooted at `cat`.
void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<CatBitsBits, patterns::CatDoubleConst>(context);
}
/// Fold away bitcasts that do not change the type, either directly or as a
/// round trip through a second bitcast.
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  auto input = getInput();

  // BitCast is redundant if input and result types are same.
  if (getType() == input.getType())
    return input;

  // Two consecutive BitCasts are redundant if the type going into the first
  // one is the same as the final result type.
  if (auto inner = input.getDefiningOp<BitCastOp>())
    if (getType() == inner.getInput().getType())
      return inner.getInput();

  return {};
}
/// Fold `bits` when it extracts the whole input or when the input is a
/// constant.
OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
  IntType inputType = getInput().getType();
  IntType resultType = getType();

  // If we are extracting the entire input, then return it.
  if (inputType == getType() && resultType.hasWidth())
    return getInput();

  // Constant fold: pull the [hi:lo] slice out of the constant.
  if (!hasKnownWidthIntTypes(*this))
    return {};
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};
  return getIntAttr(resultType,
                    cst->extractBits(getHi() - getLo() + 1, getLo()));
}
/// Register the canonicalization patterns rooted at `bits`.
void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
              patterns::BitsOfAnd, patterns::BitsOfPad>(context);
}
/// Replace `op` with the [hiBit:loBit] slice of `value`. A `bits` op is only
/// created when the width of `value` differs from the width of the result of
/// `op`, and a sign cast is appended when the signedness of the replacement
/// does not match the original result type.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
                            unsigned loBit, PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto resultType = type_cast<IntType>(op->getResult(0).getType());

  // Extract the requested bits unless the widths already agree.
  if (type_cast<IntType>(value.getType()).getWidth() != resultType.getWidth())
    value = rewriter.create<BitsPrimOp>(loc, value, hiBit, loBit);

  // Re-establish the signedness of the original result.
  auto currentType = type_cast<IntType>(value.getType());
  if (resultType.isSigned() && !currentType.isSigned())
    value = rewriter.createOrFold<AsSIntPrimOp>(loc, resultType, value);
  else if (resultType.isUnsigned() && !currentType.isUnsigned())
    value = rewriter.createOrFold<AsUIntPrimOp>(loc, resultType, value);

  rewriter.replaceOp(op, value);
}
/// Shared folder for two-way mux-like operations. `OpTy` must provide
/// `getSel`, `getHigh`, `getLow`, and a fold adaptor with matching accessors.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  auto resultType = op.getType();
  auto resultWidth = resultType.getBitWidthOrSentinel();

  // mux : UInt<0> -> 0
  if (resultWidth == 0)
    return getIntAttr(resultType, APInt(0, 0, resultType.isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (resultWidth < 0)
    return {};

  // mux(0/1, x, y) -> x or y, provided the chosen operand already has the
  // result type.
  if (auto cond = getConstant(adaptor.getSel())) {
    bool picksLow = cond->isZero();
    if (picksLow && op.getLow().getType() == resultType)
      return op.getLow();
    if (!picksLow && op.getHigh().getType() == resultType)
      return op.getHigh();
  }

  // The remaining folds need both data operands to be constants:
  // mux(cond, c1, c2).
  // TODO: "x ? a : 0" -> sext(x) & a
  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  auto lowCst = getConstant(adaptor.getLow());
  if (!lowCst)
    return {};
  auto highCst = getConstant(adaptor.getHigh());
  if (!highCst)
    return {};

  // mux(cond, cst, cst) -> cst
  if (highCst->getBitWidth() == lowCst->getBitWidth() && *highCst == *lowCst)
    return getIntAttr(resultType, *highCst);

  // mux(cond, 1, 0) -> cond
  if (highCst->isOne() && lowCst->isZero() &&
      resultType == op.getSel().getType())
    return op.getSel();

  // TODO: x ? ~0 : 0 -> sext(x)
  // TODO: "x ? c1 : c2" -> many tricks
  return {};
}
/// Fold `mux` through the shared two-way mux folder.
OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
  return foldMux<MuxPrimOp>(*this, adaptor);
}
/// Fold the 2-input mux cell intrinsic through the shared two-way mux folder.
OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  return foldMux<Mux2CellIntrinsicOp>(*this, adaptor);
}
/// The 4-input mux cell intrinsic is never folded.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  return OpFoldResult();
}
namespace {
// When a mux has a known result width, widen its data operands to that width
// with explicit pads. Most folds on mux require that the folded operand has
// the same width as the mux itself.
class MuxPad : public mlir::RewritePattern {
public:
  MuxPad(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto targetWidth = mux.getType().getBitWidthOrSentinel();
    if (targetWidth < 0)
      return failure();

    // Returns `operand` untouched if its width is unknown or already matches,
    // otherwise a freshly created pad to the mux's width.
    auto widen = [&](Value operand) -> Value {
      auto operandWidth =
          type_cast<FIRRTLBaseType>(operand.getType()).getBitWidthOrSentinel();
      bool needsPad = operandWidth >= 0 && operandWidth != targetWidth;
      if (!needsPad)
        return operand;
      return rewriter
          .create<PadPrimOp>(mux.getLoc(), mux.getType(), operand, targetWidth)
          .getResult();
    };

    Value paddedHigh = widen(mux.getHigh());
    Value paddedLow = widen(mux.getLow());

    // Nothing was widened, so there is nothing to rewrite.
    if (paddedHigh == mux.getHigh() && paddedLow == mux.getLow())
      return failure();

    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, op, mux.getType(),
        ValueRange{mux.getSel(), paddedHigh, paddedLow}, mux->getAttrs());
    return success();
  }
};
// Find muxes which have conditions dominated by other muxes with the same
// condition.
//
// For a root `mux(c, high, low)`, the high operand is only observed when `c`
// is true and the low operand only when `c` is false. Any mux with the same
// selector `c` found while walking the mux tree feeding the high (resp. low)
// operand can therefore be replaced by its own high (resp. low) operand.
class MuxSharedCond : public mlir::RewritePattern {
public:
MuxSharedCond(MLIRContext *context)
: RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
// Bound on the recursion depth of the tryCond* walks below; the walk gives
// up once `limit` exceeds this value.
static const int depthLimit = 5;
// Apply new high/low operands to `mux`.
//  - If `updateInPlace` is set, operands 1 and 2 of `mux` are overwritten
//    through the rewriter and a null Value is returned.
//  - Otherwise a new MuxPrimOp with the same selector and result type is
//    created right after `mux` and its result is returned.
Value updateOrClone(MuxPrimOp mux, Value high, Value low,
mlir::PatternRewriter &rewriter,
bool updateInPlace) const {
if (updateInPlace) {
rewriter.modifyOpInPlace(mux, [&] {
mux.setOperand(1, high);
mux.setOperand(2, low);
});
return {};
}
rewriter.setInsertionPointAfter(mux);
return rewriter
.create<MuxPrimOp>(mux.getLoc(), mux.getType(),
ValueRange{mux.getSel(), high, low})
.getResult();
}
// Walk a dependent mux tree assuming the condition cond is true.
// Returns the value that should replace `op`, or a null Value when `op` is
// not defined by a mux, the depth limit is exceeded, nothing was found, or
// the simplification was already applied in place.
Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
bool updateInPlace, int limit) const {
MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
if (!mux)
return {};
// Same selector and cond is known true: the mux yields its high operand.
if (mux.getSel() == cond)
return mux.getHigh();
if (limit > depthLimit)
return {};
// In-place mutation is only kept when every mux on the path has exactly one
// use; otherwise the mux is cloned by updateOrClone.
updateInPlace &= mux->hasOneUse();
if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
limit + 1))
return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
if (Value v =
tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
return {};
}
// Walk a dependent mux tree assuming the condition cond is false.
// Mirror image of tryCondTrue: a mux with the same selector yields its low
// operand.
Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
bool updateInPlace, int limit) const {
MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
if (!mux)
return {};
// Same selector and cond is known false: the mux yields its low operand.
if (mux.getSel() == cond)
return mux.getLow();
if (limit > depthLimit)
return {};
updateInPlace &= mux->hasOneUse();
if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
limit + 1))
return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
limit + 1))
return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
return {};
}
// Simplify the high operand under "sel is true" and, failing that, the low
// operand under "sel is false". Only muxes with a known result width are
// handled.
//
// NOTE(review): when a nested mux is rewritten in place, updateOrClone
// returns a null Value, so the `if (Value v = ...)` checks below do not fire
// and this function can return failure() although the IR was modified
// through the rewriter. Confirm this is intended.
LogicalResult
matchAndRewrite(Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto mux = cast<MuxPrimOp>(op);
auto width = mux.getType().getBitWidthOrSentinel();
if (width < 0)
return failure();
if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
return success();
}
if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
return success();
}
return failure();
}
};
} // namespace
/// Register the canonicalization patterns rooted at `mux`.
void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
                 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ,
                 patterns::MuxNot, patterns::MuxSameTrue,
                 patterns::MuxSameFalse, patterns::NarrowMuxLHS,
                 patterns::NarrowMuxRHS>(context);
}
/// Fold `pad` when it does not change the type or when it pads a constant.
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();

  // pad(x) -> x if the width doesn't change.
  if (input.getType() == getType())
    return input;

  // Need to know the input width.
  auto inputType = input.getType().get();
  int32_t sourceWidth = inputType.getWidthOrSentinel();
  if (sourceWidth == -1)
    return {};

  // Constant fold: the destination width must be known as well.
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};
  auto destWidth = getType().get().getWidthOrSentinel();
  if (destWidth == -1)
    return {};

  // Signed, non-zero-width constants are sign extended; everything else is
  // zero extended.
  bool signExtend = inputType.isSigned() && cst->getBitWidth();
  return getIntAttr(getType(),
                    signExtend ? cst->sext(destWidth) : cst->zext(destWidth));
}
/// Fold `shl` by zero and `shl` of a constant with known width.
OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
  int shiftAmount = getAmount();

  // shl(x, 0) -> x
  if (shiftAmount == 0)
    return this->getInput();

  // Constant fold: requires a constant input of known width.
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};
  IntType inputType = this->getInput().getType();
  auto inputWidth = inputType.getWidthOrSentinel();
  if (inputWidth == -1)
    return {};

  // Widen to the result width first so that no bits are shifted out.
  auto resultWidth = inputWidth + shiftAmount;
  shiftAmount = std::min(shiftAmount, resultWidth);
  return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
}
/// Fold `shr` by zero, of zero-width inputs, of fully shifted-out unsigned
/// inputs, and of constants.
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  int shiftAmount = getAmount();

  // shr(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  IntType inputType = input.getType();
  auto inputWidth = inputType.getWidthOrSentinel();
  if (inputWidth == -1)
    return {};
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (inputType.isUnsigned() && shiftAmount >= inputWidth)
    return getIntAttr(getType(), APInt(1, 0));

  // Constant fold.
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};

  // Signed inputs shift arithmetically and keep at least the sign bit;
  // unsigned inputs shift logically.
  APInt shifted = inputType.isSigned()
                      ? cst->ashr(std::min(shiftAmount, inputWidth - 1))
                      : cst->lshr(std::min(shiftAmount, inputWidth));
  auto resultWidth = std::max(inputWidth - shiftAmount, 1);
  return getIntAttr(getType(), shifted.trunc(resultWidth));
}
/// Canonicalize `shr` of an input with known, non-zero width into a `bits`
/// extract.
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().get().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // The shift amount becomes the low bit of the extract.
  unsigned lowBit = op.getAmount();
  if (static_cast<int>(lowBit) >= inputWidth) {
    // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
    if (op.getType().get().isUnsigned())
      return failure();

    // Shifting a signed value by the full width is actually taking the
    // sign bit. If the shift amount is greater than the input width, it
    // is equivalent to shifting by the input width.
    lowBit = inputWidth - 1;
  }

  replaceWithBits(op, op.getInput(), inputWidth - 1, lowBit, rewriter);
  return success();
}
/// Canonicalize `head` of an input with known, non-zero width into a `bits`
/// extract of the top `amount` bits.
LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
                                       PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().get().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // head(x, 0) keeps no bits, so there is no `bits` op to rewrite to. Report
  // failure: a pattern must not signal success when it left the IR unchanged.
  unsigned keepAmount = op.getAmount();
  if (keepAmount == 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
                  rewriter);
  return success();
}
/// Constant fold `head` when all involved integer widths are known.
OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};

  // Shift the kept (top) bits down to bit zero, then drop everything above.
  int shiftAmount =
      getInput().getType().get().getWidthOrSentinel() - getAmount();
  return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
}
/// Constant fold `tail` when all involved integer widths are known.
OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};
  auto cst = getConstant(adaptor.getInput());
  if (!cst)
    return {};

  // Truncating to the result width drops the leading bits.
  auto resultWidth = getType().get().getWidthOrSentinel();
  return getIntAttr(getType(), cst->trunc(resultWidth));
}
/// Canonicalize `tail` of an input with known, non-zero width into a `bits`
/// extract of the remaining low bits.
LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
                                       PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().get().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // Dropping every bit leaves nothing to extract, so there is no `bits` op to
  // rewrite to. Report failure: a pattern must not signal success when it
  // left the IR unchanged.
  unsigned dropAmount = op.getAmount();
  if (dropAmount == unsigned(inputWidth))
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0, rewriter);
  return success();
}
/// Register the canonicalization patterns rooted at `subaccess`.
void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.insert<patterns::SubaccessOfConstant>(context);
}
/// Fold a multibit mux with a single input or with a constant in-range index.
OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
  // If there is only one input, just return it. Operand 0 is the index, so
  // the sole input is operand 1.
  if (adaptor.getInputs().size() == 1)
    return getOperand(1);

  auto constIndex = getConstant(adaptor.getIndex());
  if (!constIndex)
    return {};

  // Inputs are stored from the highest index down to index 0, so index `i`
  // selects the i-th input counted from the back.
  auto index = constIndex->getZExtValue();
  auto numInputs = getInputs().size();
  if (index >= numInputs)
    return {};
  return getInputs()[numInputs - 1 - index];
}
/// Canonicalize a multibit mux into one of its inputs, a subaccess, or a plain
/// two-way mux.
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  auto inputs = op.getInputs();

  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it is costly to look through all inputs
  // so it is added here.
  auto sameAsFirst = [&](auto input) { return input == inputs.front(); };
  if (llvm::all_of(inputs.drop_front(), sameAsFirst)) {
    replaceOpAndCopyName(rewriter, op, inputs.front());
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = inputs.back().getDefiningOp<SubindexOp>()) {
    // The input at position `p` must be element `n - 1 - p` of the same
    // vector.
    auto indexesSameVector = [&](auto e) {
      auto subindex = e.value().template getDefiningOp<SubindexOp>();
      return subindex && lastSubindex.getInput() == subindex.getInput() &&
             subindex.getIndex() + e.index() + 1 == op.getInputs().size();
    };
    if (llvm::all_of(llvm::enumerate(op.getInputs()), indexesSameVector)) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  if (op.getIndex().getType().getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
//===----------------------------------------------------------------------===//
// Declarations
//===----------------------------------------------------------------------===//
/// Scan all the users of the specified value, looking for exactly one connect
/// that has the value as its destination. Returns that connect if it exists
/// and every other user only reads the value. Returns null if there are no
/// connects, multiple connects to the value, if the value is used by an
/// `AttachOp` or an aggregate subelement op, if the connect isn't strict, or
/// if the connect lives in a different block than the value.
///
/// Note that this will simply return the connect, which is located *anywhere*
/// after the definition of the value. Users of this function are likely
/// interested in the source side of the returned connect, the definition of
/// which does likely not dominate the original value.
StrictConnectOp firrtl::getSingleConnectUserOf(Value value) {
  StrictConnectOp found;
  for (Operation *user : value.getUsers()) {
    // If we see an attach or aggregate subelements, just conservatively fail.
    if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
      return {};

    // Only connects that write to the value are of interest.
    auto connectLike = dyn_cast<FConnectLike>(user);
    if (!connectLike || connectLike.getDest() != value)
      continue;

    // Fail on a non-strict connect, on a second distinct strict connect, and
    // on a connect in a different block.
    auto strictConnect = dyn_cast<StrictConnectOp>(*connectLike);
    if (!strictConnect)
      return {};
    if (found && found != strictConnect)
      return {};
    if (strictConnect->getBlock() != value.getParentBlock())
      return {};
    found = strictConnect;
  }
  return found;
}
// Forward simple values through wire's and reg's.
//
// If `op` is the only connect driving a wire or register, and the driven value
// is a port (wires only) or a constant in the same block, every use of the
// declaration is replaced by the source value and both the declaration and the
// connect are removed. Returns success() if the rewrite happened and failure()
// if any precondition does not hold.
static LogicalResult canonicalizeSingleSetConnect(StrictConnectOp op,
PatternRewriter &rewriter) {
// While we can do this for nearly all wires, we currently limit it to simple
// things.
Operation *connectedDecl = op.getDest().getDefiningOp();
// A destination without a defining op (e.g. a block argument) is not handled.
if (!connectedDecl)
return failure();
// Only support wire and reg for now.
if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
return failure();
// Keep declarations that are marked don't-touch, carry annotations that cannot
// be deleted, have a name that may not be dropped, or are forceable.
if (hasDontTouch(connectedDecl) ||
!AnnotationSet(connectedDecl).canBeDeleted() ||
!hasDroppableName(connectedDecl) ||
cast<Forceable>(connectedDecl).isForceable())
return failure();
// Only forward if the types exactly match and there is one connect.
if (getSingleConnectUserOf(op.getDest()) != op)
return failure();
// Only forward if there is more than one use
// (presumably a single use is this connect itself, leaving no reader to
// forward to -- TODO confirm).
if (connectedDecl->hasOneUse())
return failure();
// Only do this if the connectee and the declaration are in the same block.
auto *declBlock = connectedDecl->getBlock();
auto *srcValueOp = op.getSrc().getDefiningOp();
if (!srcValueOp) {
// Ports are ok for wires but not registers.
if (!isa<WireOp>(connectedDecl))
return failure();
} else {
// Constants/invalids in the same block are ok to forward, even through
// reg's since the clocking doesn't matter for constants.
if (!isa<ConstantOp>(srcValueOp))
return failure();
if (srcValueOp->getBlock() != declBlock)
return failure();
}
// Ok, we know we are doing the transformation.
auto replacement = op.getSrc();
// This will be replaced with the constant source. First, make sure the
// constant dominates all users.
// NOTE(review): the move below is performed directly on the operation rather
// than through `rewriter` -- confirm the pattern driver tolerates this.
if (srcValueOp && srcValueOp != &declBlock->front())
srcValueOp->moveBefore(&declBlock->front());
// Replace all things *using* the decl with the constant/port, and
// remove the declaration.
replaceOpAndCopyName(rewriter, connectedDecl, replacement);
// Remove the connect
rewriter.eraseOp(op);
return success();
}
/// Register the canonicalization patterns rooted at `connect`.
void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<patterns::ConnectExtension, patterns::ConnectSameType>(context);
}
/// Canonicalize a strict connect.
LogicalResult StrictConnectOp::canonicalize(StrictConnectOp op,
                                            PatternRewriter &rewriter) {
  // TODO: Canonicalize towards explicit extensions and flips here.

  // If there is a simple value connected to a foldable decl like a wire or reg,
  // see if we can eliminate the decl. That is currently the only rewrite, so
  // its result is the result of this canonicalization.
  return canonicalizeSingleSetConnect(op, rewriter);
}
//===----------------------------------------------------------------------===//
// Statements
//===----------------------------------------------------------------------===//
/// If the specified value has an AttachOp user, other than `dominatedAttach`,
/// that is located before `dominatedAttach` in the same block, return it.
/// Otherwise return null.
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
  for (auto *user : value.getUsers()) {
    auto attach = dyn_cast<AttachOp>(user);
    if (!attach || attach == dominatedAttach)
      continue;
    // `isBeforeInBlock` is only defined for operations that share a parent
    // block; conservatively ignore attaches that live in another block.
    if (attach->getBlock() != dominatedAttach->getBlock())
      continue;
    if (attach->isBeforeInBlock(dominatedAttach))
      return attach;
  }
  return {};
}
/// Canonicalize `attach`: drop attaches with at most one operand, merge
/// attaches that share an operand, and remove wires that exist only to be
/// attached.
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto earlier = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> merged(op.getOperands());
      for (auto other : earlier.getOperands())
        if (other != operand) // Don't add operand twice.
          merged.push_back(other);
      rewriter.create<AttachOp>(op->getLoc(), merged);
      rewriter.eraseOp(earlier);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp());
    if (!wire)
      continue;
    if (hasDontTouch(wire.getOperation()) || !wire->hasOneUse() ||
        wire.isForceable())
      continue;

    // Rebuild the attach without the wire, then drop the old attach and the
    // wire itself.
    SmallVector<Value> remaining;
    for (auto other : op.getOperands())
      if (other != operand)
        remaining.push_back(other);
    rewriter.create<AttachOp>(op->getLoc(), remaining);
    rewriter.eraseOp(op);
    rewriter.eraseOp(wire);
    return success();
  }
  return failure();
}
/// Inlines the contents of the given single-block region immediately before
/// `op`. The operation itself is not erased here; that is left to the caller.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region) {
  // The region must consist of exactly one block, which is then spliced in
  // front of `op` without any block arguments to remap.
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  rewriter.inlineBlockBefore(&region.front(), op, {});
}
/// Canonicalize `when`: resolve constant conditions, drop empty else blocks,
/// and erase whens that contain nothing.
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // Constant condition: inline the branch that is taken (if it exists) and
  // remove the `when`.
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    Region *taken = nullptr;
    if (constant.getValue().isAllOnes())
      taken = &op.getThenRegion();
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      taken = &op.getElseRegion();
    if (taken)
      replaceOpWithRegion(rewriter, op, *taken);
    rewriter.eraseOp(op);
    return success();
  }

  // If there is stuff in the then block, the only thing left to clean up is
  // an else block that exists but is empty.
  if (!op.getThenBlock().empty()) {
    if (op.hasElseRegion() && op.getElseBlock().empty()) {
      rewriter.eraseBlock(&op.getElseBlock());
      return success();
    }
    return failure();
  }

  // The then block is empty. Keep the operation only if the else block has
  // content; otherwise this operation is just useless.
  if (op.hasElseRegion() && !op.getElseBlock().empty())
    return failure();
  rewriter.eraseOp(op);
  return success();
}
namespace {
// Remove private nodes. If a node has an interesting name, move the name to
// the operation defining its input.
struct FoldNodeName : public mlir::RewritePattern {
  FoldNodeName(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    auto name = node.getNameAttr();

    // Nodes that must keep their name, carry an inner symbol or annotations
    // that cannot be deleted, or are forceable have to stay.
    if (!node.hasDroppableName())
      return failure();
    if (node.getInnerSym())
      return failure();
    if (!AnnotationSet(node).canBeDeleted())
      return failure();
    if (node.isForceable())
      return failure();

    // Best effort, do not rename InstanceOp.
    auto *definingOp = node.getInput().getDefiningOp();
    if (definingOp && !isa<InstanceOp>(definingOp))
      updateName(rewriter, definingOp, name);

    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
// Bypass nodes: redirect every use of a node's result to the node's input
// while leaving the node itself in place.
struct NodeBypass : public mlir::RewritePattern {
  NodeBypass(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);

    // Leave alone nodes with an inner symbol, with annotations that cannot be
    // deleted, without any uses, or that are forceable.
    if (node.getInnerSym())
      return failure();
    if (!AnnotationSet(node).canBeDeleted())
      return failure();
    if (node.use_empty())
      return failure();
    if (node.isForceable())
      return failure();

    rewriter.startOpModification(node);
    node.getResult().replaceAllUsesWith(node.getInput());
    rewriter.finalizeOpModification(node);
    return success();
  }
};
} // namespace
/// If `op` is forceable but its data reference has no uses, replace it with a
/// non-forceable version.
template <typename OpTy>
static LogicalResult demoteForceableIfUnused(OpTy op,
                                             PatternRewriter &rewriter) {
  // Nothing to demote if the op is not forceable in the first place.
  if (!op.isForceable())
    return failure();
  // The reference is still in use, so forceability has to be kept.
  if (!op.getDataRef().use_empty())
    return failure();

  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
  return success();
}
// Fold a node to its constant input. Interesting names, symbols, don't touch,
// non-deletable annotations, and forceability all force the node to stick
// around.
LogicalResult NodeOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  bool hasPinnedAnnotations =
      getAnnotationsAttr() &&
      !AnnotationSet(getAnnotationsAttr()).canBeDeleted();
  bool mustKeep = !hasDroppableName() ||
                  hasDontTouch(getResult()) || // handles inner symbols
                  hasPinnedAnnotations || isForceable();
  if (mustKeep)
    return failure();

  // Only constant inputs can be folded.
  auto constantInput = adaptor.getInput();
  if (!constantInput)
    return failure();

  results.push_back(constantInput);
  return success();
}
/// Register the canonicalization patterns rooted at `node`.
void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldNodeName>(context);
  results.add(demoteForceableIfUnused<NodeOp>);
}
namespace {
// For a lhs, find all the writers of fields of the aggregate type.  If there
// is exactly one writer for each field, merge the writes into a single
// connect of a bundle/vector create op.
struct AggOneShot : public mlir::RewritePattern {
  /// `name` is the name of the root operation to match and `weight` is the
  /// pattern benefit forwarded to the rewrite driver.
  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
      : RewritePattern(name, weight, context) {}

  /// Collect the values written to each element of the aggregate produced by
  /// `lhs` (its result 0).  Returns the sources in element order when every
  /// element is written exactly once through a strict connect to a
  /// subfield/subindex in the same parent op as `lhs`.  Returns an empty
  /// vector in every other case: non-aggregate type, a whole-value connect to
  /// `lhs`, a duplicate or missing element write, a subfield/subindex user
  /// that is not a strict connect, or any other kind of user of `lhs`.
  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
    auto lhsTy = lhs->getResult(0).getType();
    if (!type_isa<BundleType, FVectorType>(lhsTy))
      return {};

    // Element index -> value written to that element.
    DenseMap<uint32_t, Value> fields;
    for (Operation *user : lhs->getResult(0).getUsers()) {
      if (user->getParentOp() != lhs->getParentOp())
        return {};
      if (auto aConnect = dyn_cast<StrictConnectOp>(user)) {
        // A connect that writes the whole aggregate defeats the merge.
        if (aConnect.getDest() == lhs->getResult(0))
          return {};
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser : subField.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<StrictConnectOp>(subuser)) {
            if (aConnect.getDest() == subField) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subField.getFieldIndex())) // duplicate write
                return {};
              fields[subField.getFieldIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect user of the field access blocks the rewrite.
          return {};
        }
      } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser : subIndex.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<StrictConnectOp>(subuser)) {
            if (aConnect.getDest() == subIndex) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subIndex.getIndex())) // duplicate write
                return {};
              fields[subIndex.getIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect user of the element access blocks the rewrite.
          return {};
        }
      } else {
        return {};
      }
    }

    // Every element must have been written; emit the sources in order.
    SmallVector<Value> values;
    uint32_t total = type_isa<BundleType>(lhsTy)
                         ? type_cast<BundleType>(lhsTy).getNumElements()
                         : type_cast<FVectorType>(lhsTy).getNumElements();
    for (uint32_t i = 0; i < total; ++i) {
      if (!fields.count(i))
        return {};
      values.push_back(fields[i]);
    }
    return values;
  }

  /// When `op`'s aggregate result is completely written element by element
  /// and its type is passive, append a create op plus a single strict connect
  /// at the end of the block and erase the per-element connects.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto values = getCompleteWrite(op);
    if (values.empty())
      return failure();
    rewriter.setInsertionPointToEnd(op->getBlock());
    auto dest = op->getResult(0);
    auto destType = dest.getType();

    // If not passive, cannot strictconnect.
    if (!type_cast<FIRRTLBaseType>(destType).isPassive())
      return failure();

    Value newVal = type_isa<BundleType>(destType)
                       ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
                                                               destType, values)
                       : rewriter.createOrFold<VectorCreateOp>(
                             op->getLoc(), destType, values);
    rewriter.createOrFold<StrictConnectOp>(op->getLoc(), dest, newVal);

    // Remove the element-wise writes that the aggregate connect replaces.
    for (Operation *user : dest.getUsers()) {
      if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subIndex.getResult().getUsers()))
          if (auto aConnect = dyn_cast<StrictConnectOp>(subuser))
            if (aConnect.getDest() == subIndex)
              rewriter.eraseOp(aConnect);
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subField.getResult().getUsers()))
          if (auto aConnect = dyn_cast<StrictConnectOp>(subuser))
            if (aConnect.getDest() == subField)
              rewriter.eraseOp(aConnect);
      }
    }
    return success();
  }
};
/// AggOneShot rooted at WireOp.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), /*weight=*/0, context) {}
};
/// AggOneShot rooted at SubindexOp.
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), /*weight=*/0, context) {}
};
/// AggOneShot rooted at SubfieldOp.
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), /*weight=*/0, context) {}
};
} // namespace
/// Register the canonicalization patterns for WireOp: merging of element-wise
/// aggregate writes and demotion of forceable wires whose reference is unused.
void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<WireAggOneShot>(context);
  results.add(demoteForceableIfUnused<WireOp>);
}
/// Register the canonicalization patterns for SubindexOp.
void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SubindexAggOneShot>(context);
}
/// Fold an index into a constant aggregate: when the input operand is a known
/// ArrayAttr, return the element at this op's index.
OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
  if (auto elements = adaptor.getInput().dyn_cast_or_null<ArrayAttr>())
    return elements[getIndex()];
  return {};
}
/// Fold a field access of a constant aggregate: when the input operand is a
/// known ArrayAttr, return the element at this op's field index.
OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
  if (auto elements = adaptor.getInput().dyn_cast_or_null<ArrayAttr>())
    return elements[getFieldIndex()];
  return {};
}
/// Register the canonicalization patterns for SubfieldOp.
void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SubfieldAggOneShot>(context);
}
/// Pack the constant operands of an aggregate create op into an ArrayAttr.
/// Returns a null attribute if any operand is not a known constant.
static Attribute collectFields(MLIRContext *context,
                               ArrayRef<Attribute> operands) {
  const bool anyUnknown =
      llvm::any_of(operands, [](Attribute operand) { return !operand; });
  if (anyUnknown)
    return {};
  return ArrayAttr::get(context, operands);
}
/// Fold a bundle create.  Two cases are handled:
///   * bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
///     bundle<a:..., b:...>, i.e. every operand is the matching field of the
///     same source value, in order.
///   * all operands constant -> an ArrayAttr of the operand constants.
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // Returns the common source value when the operands merely re-assemble it.
  auto reassembledSource = [&]() -> Value {
    if (getNumOperands() == 0)
      return {};
    SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>();
    if (!first)
      return {};
    if (first.getFieldIndex() != 0 || first.getInput().getType() != getType())
      return {};

    // Operand i (i >= 1) must be field i of the very same source.
    auto isMatchingField = [&](auto elem) {
      auto field = elem.value().template getDefiningOp<SubfieldOp>();
      return field && field.getInput() == first.getInput() &&
             field.getFieldIndex() == elem.index();
    };
    if (!llvm::all_of(llvm::drop_begin(llvm::enumerate(getOperands())),
                      isMatchingField))
      return {};
    return first.getInput();
  };

  if (Value source = reassembledSource())
    return source;
  return collectFields(getContext(), adaptor.getOperands());
}
/// Fold a vector create.  Two cases are handled:
///   * vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
///     vector<..., 2>, i.e. every operand is the matching element of the same
///     source value, in order.
///   * all operands constant -> an ArrayAttr of the operand constants.
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // Returns the common source value when the operands merely re-assemble it.
  auto reassembledSource = [&]() -> Value {
    if (getNumOperands() == 0)
      return {};
    SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>();
    if (!first)
      return {};
    if (first.getIndex() != 0 || first.getInput().getType() != getType())
      return {};

    // Operand i (i >= 1) must be element i of the very same source.
    auto isMatchingElement = [&](auto elem) {
      auto element = elem.value().template getDefiningOp<SubindexOp>();
      return element && element.getInput() == first.getInput() &&
             element.getIndex() == elem.index();
    };
    if (!llvm::all_of(llvm::drop_begin(llvm::enumerate(getOperands())),
                      isMatchingElement))
      return {};
    return first.getInput();
  };

  if (Value source = reassembledSource())
    return source;
  return collectFields(getContext(), adaptor.getOperands());
}
/// A cast whose operand already has the result type folds to its operand.
OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
  if (getOperand().getType() != getType())
    return {};
  return getOperand();
}
namespace {
// A register with a constant reset value whose single connect is a mux
// between itself and that same constant can only ever hold the constant, so
// it is replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);

    // The reset value has to be a constant op.
    auto resetConst =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    if (!resetConst)
      return failure();

    // Registers that must be preserved are left alone.
    if (hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).canBeDeleted() || reg.isForceable())
      return failure();

    // Find the one true connect, or bail.
    auto connect = getSingleConnectUserOf(reg.getResult());
    if (!connect)
      return failure();

    // The connected value must come straight out of a mux.
    auto mux = dyn_cast_or_null<MuxPrimOp>(connect.getSrc().getDefiningOp());
    if (!mux)
      return failure();

    // Accept mux(_, const, reg) and mux(_, reg, const).
    Operation *highDef = mux.getHigh().getDefiningOp();
    Operation *lowDef = mux.getLow().getDefiningOp();
    auto muxConst = dyn_cast_or_null<ConstantOp>(highDef);
    if (muxConst && lowDef != reg)
      return failure();
    if (auto lowConst = dyn_cast_or_null<ConstantOp>(lowDef);
        lowConst && highDef == reg)
      muxConst = lowConst;

    // The mux constant must be identical to the reset constant.
    if (!muxConst)
      return failure();
    if (muxConst.getType() != resetConst.getType() ||
        muxConst.getValue() != resetConst.getValue())
      return failure();

    // Check all types should be typed by now.
    auto regTy = reg.getResult().getType();
    const bool typesAgree = connect.getDest().getType() == regTy &&
                            connect.getSrc().getType() == regTy &&
                            mux.getHigh().getType() == regTy &&
                            mux.getLow().getType() == regTy;
    if (!typesAgree || regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users by hoisting it to the front
    // of the connect's block.
    Operation *blockFront = &connect->getBlock()->front();
    if (muxConst != blockFront)
      muxConst->moveBefore(blockFront);

    // Replace the register with the constant, then drop the connect.
    replaceOpAndCopyName(rewriter, reg, muxConst.getResult());
    rewriter.eraseOp(connect);
    return success();
  }
};
} // namespace
/// Returns true if `v` is defined by a ConstantOp whose value is one, or by a
/// SpecialConstantOp whose value is true.
static bool isDefinedByOneConstantOp(Value v) {
  auto constant = v.getDefiningOp<ConstantOp>();
  if (constant)
    return constant.getValue().isOne();

  auto special = v.getDefiningOp<SpecialConstantOp>();
  if (special)
    return special.getValue();

  return false;
}
/// A register whose reset signal is the constant one always holds its reset
/// value: drop the connects to it and replace it with a node of the reset
/// value that keeps the register's name, annotations, symbol and forceability.
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  const bool alwaysInReset = isDefinedByOneConstantOp(reg.getResetSignal());
  if (!alwaysInReset)
    return failure();

  // Only the erasure of writes is wanted here; the passthrough is ignored.
  (void)dropWrite(rewriter, reg->getResult(0), {});

  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, reg.getResetValue(), reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
/// Register the canonicalization patterns for RegResetOp.
void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Pattern classes first, then the function-style patterns.
  results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);

  results.add(canonicalizeRegResetWithOneReset);
  results.add(demoteForceableIfUnused<RegResetOp>);
}
// Returns the value connected to the field `name` of a memory port, if there
// is exactly one such connection.  Returns a null value when the field is
// never accessed, when an access of the field does not have a single connect
// user, or when more than one access of the field is connected.  Every user
// of the port is expected to be a SubfieldOp.
static Value getPortFieldValue(Value port, StringRef name) {
  auto bundleTy = type_cast<BundleType>(port.getType());
  auto wantedIndex = bundleTy.getElementIndex(name);
  assert(wantedIndex && "missing field on memory port");

  Value driver;
  for (Operation *user : port.getUsers()) {
    auto access = cast<SubfieldOp>(user);
    if (wantedIndex == access.getFieldIndex()) {
      auto connect = getSingleConnectUserOf(access);
      // Bail on an ambiguous access or on a second driver for the field.
      if (!connect || driver)
        return {};
      driver = connect.getSrc();
    }
  }
  return driver;
}
// Returns true if the "en" field of a port is driven by a single connect
// whose source is a constant zero.
static bool isPortDisabled(Value port) {
  if (Value enable = getPortFieldValue(port, "en"))
    if (auto enableConst = enable.getDefiningOp<ConstantOp>())
      return enableConst.getValue().isZero();
  return false;
}
// Returns true if the data output of a port is unused, i.e. no access of the
// field named `data` has any users.  Every user of the port is expected to be
// a SubfieldOp.
static bool isPortUnused(Value port, StringRef data) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(data);
  // The field looked up here is the caller-selected data field, not the
  // enable flag.
  assert(fieldIndex && "missing data field on memory port");

  for (auto *op : port.getUsers()) {
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    if (!portAccess.use_empty())
      return false;
  }
  return true;
}
// Replaces every access of the field `name` of a memory port with `value` and
// erases the access ops.  Every user of the port is expected to be a
// SubfieldOp.
static void replacePortField(PatternRewriter &rewriter, Value port,
                             StringRef name, Value value) {
  auto bundleTy = type_cast<BundleType>(port.getType());
  auto wantedIndex = bundleTy.getElementIndex(name);
  assert(wantedIndex && "missing field on memory port");

  // Accesses are erased while walking the use list, hence the early-inc range.
  for (Operation *user : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(user);
    if (wantedIndex == access.getFieldIndex()) {
      rewriter.replaceAllUsesWith(access, value);
      rewriter.eraseOp(access);
    }
  }
}
// Remove all accesses to a memory port that is being eliminated, so that the
// port value ends up without uses.  Values that are still read from the port
// are replaced with dummy registers that are never written.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to lazily create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = rewriter.create<SpecialConstantOp>(
          port.getLoc(), ClockType::get(rewriter.getContext()), false);
    return clock;
  };

  // Determine whether the port is accessed only through its subfields or as
  // a whole. If the port is used in its entirety, replace it with a dummy
  // register of the port type. Otherwise, eliminate individual subfields and
  // replace them with reasonable defaults below.
  for (auto *op : port.getUsers()) {
    if (isa<SubfieldOp>(op))
      continue;
    auto ty = port.getType();
    auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
    // Go through the rewriter so the driver is notified about every user
    // whose operand changes.
    rewriter.replaceAllUsesWith(port, reg.getResult());
    return;
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access)
        rewriter.eraseOp(user);
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
namespace {
// If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  /// Replaces every field access of every port with a wire of the field type
  /// (fields whose name ends in "data" are additionally driven with a
  /// zero-width zero constant) and erases the memory.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only integer data of width exactly zero qualifies.
    const bool isZeroWidthInt =
        firrtl::type_isa<IntType>(mem.getDataType()) &&
        mem.getDataType().getBitWidthOrSentinel() == 0;
    if (!isZeroWidthInt)
      return failure();

    // Make sure all users are safe to replace: ports may only be accessed
    // through subfields.
    auto isFieldAccess = [](Operation *user) { return isa<SubfieldOp>(user); };
    for (auto port : mem.getResults())
      if (!llvm::all_of(port.getUsers(), isFieldAccess))
        return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.  Replace field by field instead.
    for (auto port : op->getResults()) {
      for (Operation *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp access = cast<SubfieldOp>(user);
        StringRef fieldName = access.getFieldName();
        auto replacement = replaceOpWithNewOpAndCopyName<WireOp>(
                               rewriter, access, access.getResult().getType())
                               .getResult();
        if (!fieldName.ends_with("data"))
          continue;

        // Make sure to write data ports.
        auto zero = rewriter.create<firrtl::ConstantOp>(
            replacement.getLoc(),
            firrtl::type_cast<IntType>(replacement.getType()),
            APInt::getZero(0));
        rewriter.create<StrictConnectOp>(replacement.getLoc(), replacement,
                                         zero);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
// Eliminate memories that are only read or only written.  A read-only memory
// is kept when it has a file initialization.
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Classify the ports.  Debug and read-write ports disqualify the memory,
    // as does having both a read and a write port.
    bool sawRead = false, sawWrite = false;
    for (unsigned portIdx = 0; portIdx < mem.getNumResults(); ++portIdx) {
      switch (mem.getPortKind(portIdx)) {
      case MemOp::PortKind::Read:
        sawRead = true;
        break;
      case MemOp::PortKind::Write:
        sawWrite = true;
        break;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      if (sawRead && sawWrite)
        return failure();
    }
    assert((!sawWrite || !sawRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it.  A write only memory with file initialization is okay to
    // remove.
    if (sawRead && mem.getInit())
      return failure();

    for (auto port : mem.getResults())
      erasePort(rewriter, port);
    rewriter.eraseOp(op);
    return success();
  }
};
// Eliminate the dead ports of memories.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  /// A port is dead when it is disabled (enable tied to constant zero) or,
  /// for read ports, when its "data" field is unused.  Annotated ports and
  /// debug ports are never removed.  The memory is rebuilt with the live
  /// ports only; if no port survives, the memory disappears entirely.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // When every port is dead no replacement memory is created; `newOp`
    // stays null and is never dereferenced below.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = rewriter.create<MemOp>(
          mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Replace the dead ports with dummy values and rewire the live ones to
    // the new memory.  The rewiring goes through the rewriter so the driver
    // is notified about every user whose operand changes.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
// Rewrite write-only read-write ports to write ports.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  /// Un-annotated read-write ports whose "rdata" field is unused are turned
  /// into write ports on a rebuilt memory; all other ports are carried over
  /// unchanged.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Compute the port list of the rebuilt memory: demoted ports get the
    // write port type, everything else keeps its type.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = rewriter.create<MemOp>(
        mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
        mem.getName(), mem.getNameKind(), mem.getAnnotations(),
        rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
        mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          assert(fromFieldIndex && "missing field on memory port");

          auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
                                                     newResult, toName);
          for (auto *user : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(user);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *user : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(user);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }

        // Erase any leftover accesses of the read data.  isPortUnused has
        // established that none of them has users, but the access ops
        // themselves still use the old port and would otherwise keep the old
        // memory alive when it is erased below.
        auto rdataFieldIndex = resultPortTy.getElementIndex("rdata");
        for (auto *user : llvm::make_early_inc_range(result.getUsers())) {
          auto rdataField = cast<SubfieldOp>(user);
          if (rdataFieldIndex != rdataField.getFieldIndex())
            continue;
          rewriter.eraseOp(rdataField);
        }
      } else {
        // Go through the rewriter so the driver is notified about every user
        // whose operand changes.
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
// Eliminate the unused data bits of memories: when the read data of a memory
// is only ever observed through bit selects that leave some bits untouched,
// rebuild the memory with a narrower data type holding only the used bits.
struct FoldUnusedBits : public mlir::RewritePattern {
FoldUnusedBits(MLIRContext *context)
: RewritePattern(MemOp::getOperationName(), 0, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
MemOp mem = cast<MemOp>(op);
if (hasDontTouch(mem))
return failure();
// Only apply the transformation if the memory is neither masked nor
// sequential.
const auto &summary = mem.getSummary();
if (summary.isMasked || summary.isSeqMem())
return failure();
// Only integer data of known, non-zero width is handled.
auto type = type_dyn_cast<IntType>(mem.getDataType());
if (!type)
return failure();
auto width = type.getBitWidthOrSentinel();
if (width <= 0)
return failure();
// `usedBits` has one flag per data bit that is observed by some reader.
// `mapping` is keyed by the low bit index of each collected bit select; its
// value is filled in later with the position of that bit in the compacted
// data.
llvm::SmallBitVector usedBits(width);
DenseMap<unsigned, unsigned> mapping;
// Find which bits are used out of the users of a read port. This detects
// ports whose data/rdata field is used only through bit select ops. The
// bit selects are then used to build a bit-mask. The ops are collected.
// Any user that is not a bit select marks every bit as used.
SmallVector<BitsPrimOp> readOps;
auto findReadUsers = [&](Value port, StringRef field) {
auto portTy = type_cast<BundleType>(port.getType());
auto fieldIndex = portTy.getElementIndex(field);
assert(fieldIndex && "missing data port");
for (auto *op : port.getUsers()) {
auto portAccess = cast<SubfieldOp>(op);
if (fieldIndex != portAccess.getFieldIndex())
continue;
for (auto *user : op->getUsers()) {
auto bits = dyn_cast<BitsPrimOp>(user);
if (!bits) {
usedBits.set();
continue;
}
usedBits.set(bits.getLo(), bits.getHi() + 1);
mapping[bits.getLo()] = 0;
readOps.push_back(bits);
}
}
};
// Finds the users of write ports. This expects all the data/wdata fields
// of the ports to be used solely as the destination of strict connects.
// If a memory has ports with other uses, it is excluded from optimisation.
SmallVector<StrictConnectOp> writeOps;
auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
auto portTy = type_cast<BundleType>(port.getType());
auto fieldIndex = portTy.getElementIndex(field);
assert(fieldIndex && "missing data port");
for (auto *op : port.getUsers()) {
auto portAccess = cast<SubfieldOp>(op);
if (fieldIndex != portAccess.getFieldIndex())
continue;
auto conn = getSingleConnectUserOf(portAccess);
if (!conn)
return failure();
writeOps.push_back(conn);
}
return success();
};
// Traverse all ports and find the read and used data fields.
for (auto [i, port] : llvm::enumerate(mem.getResults())) {
// Do not simplify memories with annotated ports.
if (!mem.getPortAnnotation(i).empty())
return failure();
switch (mem.getPortKind(i)) {
case MemOp::PortKind::Debug:
// Memories with debug ports are not rewritten.
return failure();
case MemOp::PortKind::Write:
if (failed(findWriteUsers(port, "data")))
return failure();
continue;
case MemOp::PortKind::Read:
findReadUsers(port, "data");
continue;
case MemOp::PortKind::ReadWrite:
if (failed(findWriteUsers(port, "wdata")))
return failure();
findReadUsers(port, "rdata");
continue;
}
llvm_unreachable("unknown port kind");
}
// Perform the transformation if there are some bits missing. Unused
// memories are handled in a different canonicalizer.
if (usedBits.all() || usedBits.none())
return failure();
// Build a mapping of existing indices to compacted ones. Each entry of
// `ranges` is a maximal run of used bits as an inclusive (first, last) pair;
// `newWidth` ends up as the total number of used bits.
SmallVector<std::pair<unsigned, unsigned>> ranges;
unsigned newWidth = 0;
for (int i = usedBits.find_first(); 0 <= i && i < width;) {
int e = usedBits.find_next_unset(i);
if (e < 0)
e = width;
for (int idx = i; idx < e; ++idx, ++newWidth) {
if (auto it = mapping.find(idx); it != mapping.end()) {
it->second = newWidth;
}
}
ranges.emplace_back(i, e - 1);
i = e != width ? usedBits.find_next(e) : e;
}
// Create the new op with the new port types.
auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
SmallVector<Type> portTypes;
for (auto [i, port] : llvm::enumerate(mem.getResults())) {
portTypes.push_back(
MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
}
auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
// Rewrite bundle users to the new data type: a fresh access of `field` is
// created right after the new memory and every pre-existing access of that
// field is replaced by it.
auto rewriteSubfield = [&](Value port, StringRef field) {
auto portTy = type_cast<BundleType>(port.getType());
auto fieldIndex = portTy.getElementIndex(field);
assert(fieldIndex && "missing data port");
rewriter.setInsertionPointAfter(newMem);
auto newPortAccess =
rewriter.create<SubfieldOp>(port.getLoc(), port, field);
for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
auto portAccess = cast<SubfieldOp>(op);
// Skip the access that was just created and accesses of other fields.
if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
continue;
rewriter.replaceOp(portAccess, newPortAccess.getResult());
}
};
// Rewrite the field accesses.
for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
switch (newMem.getPortKind(i)) {
case MemOp::PortKind::Debug:
llvm_unreachable("cannot rewrite debug port");
case MemOp::PortKind::Write:
rewriteSubfield(port, "data");
continue;
case MemOp::PortKind::Read:
rewriteSubfield(port, "data");
continue;
case MemOp::PortKind::ReadWrite:
rewriteSubfield(port, "rdata");
rewriteSubfield(port, "wdata");
continue;
}
llvm_unreachable("unknown port kind");
}
// Rewrite the reads to the new ranges, compacting them: each bit select
// keeps its width but starts at the compacted position of its low bit.
for (auto readOp : readOps) {
rewriter.setInsertionPointAfter(readOp);
auto it = mapping.find(readOp.getLo());
assert(it != mapping.end() && "bit op mapping not found");
rewriter.replaceOpWithNewOp<BitsPrimOp>(
readOp, readOp.getInput(),
readOp.getHi() - readOp.getLo() + it->second, it->second);
}
// Rewrite the writes into a concatenation of slices: one slice of the
// original source per used range, with later ranges placed in front of the
// earlier ones in the concatenation.
for (auto writeOp : writeOps) {
Value source = writeOp.getSrc();
rewriter.setInsertionPoint(writeOp);
Value catOfSlices;
for (auto &[start, end] : ranges) {
Value slice =
rewriter.create<BitsPrimOp>(writeOp.getLoc(), source, end, start);
if (catOfSlices) {
catOfSlices =
rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
} else {
catOfSlices = slice;
}
}
rewriter.replaceOpWithNewOp<StrictConnectOp>(writeOp, writeOp.getDest(),
catOfSlices);
}
return success();
}
};
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  /// Applies to memories of depth 1 without debug or annotated ports whose
  /// write-capable ports all share one clock.  The storage becomes a single
  /// register; reads observe the register directly and writes are pipelined,
  /// masked and muxed into the register's next value in port order.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto memModule = mem->getParentOfType<FModuleOp>();

    // Find the clock of the register-to-be, all write ports should share it.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not rewrite memories with annotated ports.  Such a port cannot
      // simply be skipped: the rewrite below processes every port and erases
      // the memory, so accesses that were not collected here would be left
      // dangling.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      // Collects the accesses of the given control/write fields together
      // with the connects that use them.  Fails if any such access has a
      // user that is not a connect.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        // Read ports do not constrain the register clock.
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }

    // Without any write or read-write port no clock was found, so there is
    // nothing to clock the register with (and no value to anchor the
    // insertion point on).
    if (!clock)
      return failure();

    // Create a new register to store the data.
    auto ty = mem.getDataType();
    rewriter.setInsertionPointAfterValue(clock);
    auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
                   .getResult();

    // Helper to insert a given number of pipeline stages through registers.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }

        auto reg = rewriter
                       .create<RegOp>(mem.getLoc(), value.getType(), clock,
                                      rewriter.getStringAttr(regName))
                       .getResult();
        rewriter.create<StrictConnectOp>(value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    // Each entry of `writes` is a (data, enable, mask) triple.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        rewriter.setInsertionPointAfterValue(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Memories with debug ports were rejected above.
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "data", reg);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "rdata", reg);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");
        rewriter.setInsertionPointToEnd(memModule.getBodyBlock());

        // The port writes only when it is enabled and in write mode.
        auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
    Value next = reg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);

        if (masked) {
          masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
        } else {
          masked = chunk;
        }
      }

      next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
    }
    rewriter.create<StrictConnectOp>(reg.getLoc(), reg, next);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
} // namespace
/// Register the canonicalization patterns for MemOp.
void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
              FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits,
              FoldRegMems>(context);
}
//===----------------------------------------------------------------------===//
// Declarations
//===----------------------------------------------------------------------===//
// Simplify a register whose single connect is driven by a mux.
//
// When the mux selects between a constant and some other value, and the select
// is a block argument (module port), the mux is turned into a register reset:
//   reg ; connect(reg, mux(port, const, val)) ->
//   reg.reset(port, const); connect(reg, val)
//
// When the mux only ever chooses between one constant and the register itself,
// the register is driven directly by that constant instead:
//   r = mux(cond, r, 3) --> r = 3
//   r = mux(cond, 3, r) --> r = 3
//
// Returns failure() without modifying the IR if the pattern does not apply.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // The register needs exactly one connect driving it.
  auto connect = getSingleConnectUserOf(reg.getResult());
  if (!connect)
    return failure();

  // The value being connected has to come from a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(connect.getSrc().getDefiningOp());
  if (!mux)
    return failure();

  Operation *highDef = mux.getHigh().getDefiningOp();
  Operation *lowDef = mux.getLow().getDefiningOp();
  auto highConst = dyn_cast_or_null<ConstantOp>(highDef);
  auto lowConst = dyn_cast_or_null<ConstantOp>(lowDef);

  // By default the candidate reset value is the constant on the high side of
  // the mux. If the opposite mux input is the register itself, the register
  // can only ever hold that constant; in that case a constant on the low side
  // is accepted as well.
  ConstantOp constant = highConst;
  bool isConstReg = false;
  if (highConst && lowDef == reg.getOperation()) {
    isConstReg = true;
  } else if (lowConst && highDef == reg.getOperation()) {
    isConstReg = true;
    constant = lowConst;
  }
  if (!constant)
    return failure();

  // For a non-constant register the select must be a block argument. This is a
  // heuristic to limit the rewrite to intended reset lines. Constant registers
  // are replaced regardless of where the select comes from.
  if (!isConstReg && !isa<BlockArgument>(mux.getSel()))
    return failure();

  // Every value involved has to have exactly the register's type, and that
  // type must report a non-negative bit width.
  auto regTy = reg.getResult().getType();
  const bool typesSettled =
      connect.getDest().getType() == regTy &&
      connect.getSrc().getType() == regTy &&
      mux.getHigh().getType() == regTy && mux.getLow().getType() == regTy &&
      regTy.getBitWidthOrSentinel() >= 0;
  if (!typesSettled)
    return failure();

  // From here on the transformation is committed.

  // Hoist the constant to the very front of the connect's block so that it is
  // defined before every operation that will use it.
  Operation *blockFront = &connect->getBlock()->front();
  if (constant != blockFront)
    constant->moveBefore(blockFront);

  if (!isConstReg) {
    // Swap the plain register for a register with reset, carrying over the
    // name, annotations, inner symbol, forceable flag, and dialect attributes.
    SmallVector<NamedAttribute, 2> dialectAttrs(reg->getDialectAttrs());
    auto resetReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    resetReg->setDialectAttrs(dialectAttrs);
  }

  // Rebuild the connect in place of the old one, now driven either by the
  // constant or by the low side of the mux. The insertion point is restored
  // afterwards.
  auto savedInsertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(connect);
  Value newSource = mux.getLow();
  if (isConstReg)
    newSource = constant.getResult();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, connect,
                                           connect.getDest(), newSource);
  rewriter.restoreInsertionPoint(savedInsertionPoint);
  return success();
}
// Canonicalize a `RegOp`. First try `foldHiddenReset`, but only for registers
// that are neither marked "don't touch" nor forceable. If that does not
// apply, fall back to `demoteForceableIfUnused`.
LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
  const bool mayFoldReset =
      !hasDontTouch(op.getOperation()) && !op.isForceable();
  if (mayFoldReset && succeeded(foldHiddenReset(op, rewriter)))
    return success();

  // The overall result is whatever the forceable demotion reports.
  return success(succeeded(demoteForceableIfUnused(op, rewriter)));
}
//===----------------------------------------------------------------------===//
// Verification Ops.
//===----------------------------------------------------------------------===//
// Erase the verification operation `op` when constants show it can never have
// an effect:
//  - `enable` is defined by a constant zero, so the op is never enabled; or
//  - `predicate` is defined by a constant whose "is zero" status equals
//    `eraseIfZero`, so the op is never triggered.
// Returns success() if `op` was erased and failure() otherwise.
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
                                          Value enable,
                                          PatternRewriter &rewriter,
                                          bool eraseIfZero) {
  // Never enabled?
  auto enableConst = enable.getDefiningOp<firrtl::ConstantOp>();
  bool isDead = enableConst && enableConst.getValue().isZero();

  // Never triggered? Only checked when the enable did not already decide.
  if (!isDead) {
    auto predicateConst = predicate.getDefiningOp<firrtl::ConstantOp>();
    isDead =
        predicateConst && predicateConst.getValue().isZero() == eraseIfZero;
  }

  if (!isDead)
    return failure();
  rewriter.eraseOp(op);
  return success();
}
// Canonicalization entry point shared by the immediate verification ops.
// `Op` must provide `getPredicate()` and `getEnable()`. `EraseIfZero` is
// forwarded to `eraseIfZeroOrNotZero` and selects whether a constant-zero
// predicate (true) or a constant non-zero predicate (false) erases the op.
template <class Op, bool EraseIfZero = false>
static LogicalResult canonicalizeImmediateVerifOp(Op op,
                                                  PatternRewriter &rewriter) {
  Value predicate = op.getPredicate();
  Value enable = op.getEnable();
  return eraseIfZeroOrNotZero(op, predicate, enable, rewriter, EraseIfZero);
}
// Register the canonicalizer for `AssertOp`: the shared verification-op
// pattern with `EraseIfZero` left at false.
void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  LogicalResult (*pattern)(AssertOp, PatternRewriter &) =
      canonicalizeImmediateVerifOp<AssertOp, /* EraseIfZero = */ false>;
  results.add(pattern);
}
// Register the canonicalizer for `AssumeOp`: the shared verification-op
// pattern with `EraseIfZero` left at false.
void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  LogicalResult (*pattern)(AssumeOp, PatternRewriter &) =
      canonicalizeImmediateVerifOp<AssumeOp, /* EraseIfZero = */ false>;
  results.add(pattern);
}
// Register the canonicalizer for `CoverOp`: the shared verification-op
// pattern with `EraseIfZero` set, so a constant-zero predicate erases the op.
void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  LogicalResult (*pattern)(CoverOp, PatternRewriter &) =
      canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>;
  results.add(pattern);
}
//===----------------------------------------------------------------------===//
// InvalidValueOp
//===----------------------------------------------------------------------===//
// Erase an `InvalidValueOp` that has no uses. An op that is still used is
// left alone and failure() is returned.
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  if (!op.use_empty())
    return failure();

  rewriter.eraseOp(op);
  return success();
}
//===----------------------------------------------------------------------===//
// ClockGateIntrinsicOp
//===----------------------------------------------------------------------===//
// Fold a clock gate whose behavior is fixed by constant operands:
//  - an always-true enable or test enable forwards the input clock;
//  - an always-false enable (with the test enable absent or also always
//    false) yields a constant zero clock;
//  - a constant zero input clock yields a constant zero clock.
// Returns an empty result when none of these apply.
OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
  const auto enable = adaptor.getEnable();
  const auto testEnable = adaptor.getTestEnable();

  // Either enable being constant one means the gate is permanently open.
  if (isConstantOne(enable) || isConstantOne(testEnable))
    return getInput();

  // Both enables permanently off: the gated clock never toggles.
  if (isConstantZero(enable)) {
    const bool testEnableOff = !getTestEnable() || isConstantZero(testEnable);
    if (testEnableOff)
      return BoolAttr::get(getContext(), false);
  }

  // Gating a constant zero clock gives a constant zero clock.
  if (isConstantZero(adaptor.getInput()))
    return BoolAttr::get(getContext(), false);

  return {};
}
// Drop the optional test enable operand when it is defined by a constant
// zero. The operand list is modified in place; failure() is returned when the
// operand is absent or is not a constant zero.
LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
                                                 PatternRewriter &rewriter) {
  auto testEnable = op.getTestEnable();
  if (!testEnable)
    return failure();

  auto constOp = testEnable.getDefiningOp<ConstantOp>();
  if (!constOp || !constOp.getValue().isZero())
    return failure();

  rewriter.modifyOpInPlace(op, [&] { op.getTestEnableMutable().clear(); });
  return success();
}
//===----------------------------------------------------------------------===//
// Reference Ops.
//===----------------------------------------------------------------------===//
// refresolve(forceable.ref) -> forceable.data
//
// Applies only when the resolved reference is exactly the data reference of a
// forceable declaration and the resolved type equals that declaration's data
// type. All uses of the resolve are then redirected to the declaration's data
// value.
static LogicalResult
canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
  auto ref = op.getRef();
  auto forceable = ref.getDefiningOp<Forceable>();
  if (!forceable || !forceable.isForceable())
    return failure();
  if (ref != forceable.getDataRef())
    return failure();
  if (op.getType() != forceable.getDataType())
    return failure();

  rewriter.replaceAllUsesWith(op, forceable.getData());
  return success();
}
// Register the `RefResolveOp` canonicalizers: the `RefResolveOfRefSend`
// pattern first, followed by the forceable-data forwarding function.
void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.insert<patterns::RefResolveOfRefSend>(context)
      .insert(canonicalizeRefResolveOfForceable);
}
// A RefCast whose input already has the result type is a no-op: fold it to
// its input. Otherwise nothing is folded.
OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
  auto input = getInput();
  if (input.getType() != getType())
    return {};
  return input;
}
// Returns true if `operand` is defined by a `ConstantOp` whose value is zero.
static bool isConstantZero(Value operand) {
  if (auto constOp = operand.getDefiningOp<ConstantOp>())
    return constOp.getValue().isZero();
  return false;
}
// Erase `op` when its predicate is a constant zero. `Op` must provide
// `getPredicate()`. Returns success() if the op was erased, failure()
// otherwise.
template <typename Op>
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
  if (!isConstantZero(op.getPredicate()))
    return failure();

  rewriter.eraseOp(op);
  return success();
}
// Register the canonicalizer for `RefForceOp`: erase the op when its
// predicate is a constant zero.
void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  LogicalResult (*pattern)(RefForceOp, PatternRewriter &) =
      eraseIfPredFalse<RefForceOp>;
  results.add(pattern);
}
// Register the canonicalizer for `RefForceInitialOp`: erase the op when its
// predicate is a constant zero.
void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  LogicalResult (*pattern)(RefForceInitialOp, PatternRewriter &) =
      eraseIfPredFalse<RefForceInitialOp>;
  results.add(pattern);
}
// Register the canonicalizer for `RefReleaseOp`: erase the op when its
// predicate is a constant zero.
void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  LogicalResult (*pattern)(RefReleaseOp, PatternRewriter &) =
      eraseIfPredFalse<RefReleaseOp>;
  results.add(pattern);
}
// Register the canonicalizer for `RefReleaseInitialOp`: erase the op when its
// predicate is a constant zero.
void RefReleaseInitialOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  LogicalResult (*pattern)(RefReleaseInitialOp, PatternRewriter &) =
      eraseIfPredFalse<RefReleaseInitialOp>;
  results.add(pattern);
}
//===----------------------------------------------------------------------===//
// HasBeenResetIntrinsicOp
//===----------------------------------------------------------------------===//
// Fold `HasBeenResetIntrinsicOp` to a constant zero when it can never become
// true. The folds in here should reflect the ones for `verif::HasBeenResetOp`.
OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
  // Two situations make the op permanently false:
  //  1. The reset is a constant. The op is then either permanently in reset
  //     or never resets; both mean that the reset never finishes.
  //  2. The clock is a constant and the reset is synchronous (its type is a
  //     1-bit UInt). The reset will then never be started.
  const bool neverTrue =
      adaptor.getReset() ||
      (isUInt1(getReset().getType()) && adaptor.getClock());
  if (!neverTrue)
    return {};

  return getIntZerosAttr(UIntType::get(getContext(), 1));
}