mirror of https://github.com/llvm/circt.git
3210 lines
120 KiB
C++
3210 lines
120 KiB
C++
//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Dialect/Comb/CombOps.h"
|
|
#include "circt/Dialect/HW/HWAttributes.h"
|
|
#include "circt/Dialect/HW/HWOps.h"
|
|
#include "circt/Support/Naming.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/KnownBits.h"
|
|
|
|
using namespace mlir;
|
|
using namespace circt;
|
|
using namespace comb;
|
|
using namespace matchers;
|
|
|
|
/// In comb, we assume no knowledge of the semantics of cross-block dataflow. As
|
|
/// such, cross-block dataflow is interpreted as a canonicalization barrier.
|
|
/// This is a conservative approach which:
|
|
/// 1. still allows for efficient canonicalization for the common CIRCT usecase
|
|
/// of comb (comb logic nested inside single-block hw.module's)
|
|
/// 2. allows comb operations to be used in non-HW container ops - that may use
|
|
/// MLIR blocks and regions to represent various forms of hierarchical
|
|
/// abstractions, thus allowing comb to compose with other dialects.
|
|
static bool hasOperandsOutsideOfBlock(Operation *op) {
|
|
Block *thisBlock = op->getBlock();
|
|
return llvm::any_of(op->getOperands(), [&](Value operand) {
|
|
return operand.getParentBlock() != thisBlock;
|
|
});
|
|
}
|
|
|
|
/// Create a new instance of a generic operation that only has value operands,
|
|
/// and has a single result value whose type matches the first operand.
|
|
///
|
|
/// This should not be used to create instances of ops with attributes or with
|
|
/// more complicated type signatures.
|
|
static Value createGenericOp(Location loc, OperationName name,
|
|
ArrayRef<Value> operands, OpBuilder &builder) {
|
|
OperationState state(loc, name);
|
|
state.addOperands(operands);
|
|
state.addTypes(operands[0].getType());
|
|
return builder.create(state)->getResult(0);
|
|
}
|
|
|
|
/// Wrap `value` in an IntegerAttr whose signless integer type has the same
/// bit width as the APInt.
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
  auto attrType = IntegerType::get(context, value.getBitWidth());
  return IntegerAttr::get(attrType, value);
}
|
|
|
|
/// Flatten concat and mux operands into a vector.
|
|
static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
|
|
if (auto concat = v.getDefiningOp<ConcatOp>()) {
|
|
for (auto op : concat.getOperands())
|
|
getConcatOperands(op, result);
|
|
} else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
|
|
for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
|
|
getConcatOperands(repl.getOperand(), result);
|
|
} else {
|
|
result.push_back(v);
|
|
}
|
|
}
|
|
|
|
// Return true if the op has SV attributes. Note that we cannot use a helper
|
|
// function `hasSVAttributes` defined under SV dialect because of a cyclic
|
|
// dependency.
|
|
static bool hasSVAttributes(Operation *op) {
|
|
return op->hasAttr("sv.attributes");
|
|
}
|
|
|
|
namespace {
|
|
template <typename SubType>
|
|
struct ComplementMatcher {
|
|
SubType lhs;
|
|
ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
|
|
bool match(Operation *op) {
|
|
auto xorOp = dyn_cast<XorOp>(op);
|
|
return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
|
|
}
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
template <typename SubType>
|
|
static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
|
|
return ComplementMatcher<SubType>(subExpr);
|
|
}
|
|
|
|
/// Return true if the op will be flattened afterwards. Op will be flattend if
|
|
/// it has a single user which has a same op type. User must be in same block.
|
|
static bool shouldBeFlattened(Operation *op) {
|
|
assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
|
|
"must be commutative operations"));
|
|
if (op->hasOneUse()) {
|
|
auto *user = *op->getUsers().begin();
|
|
return user->getName() == op->getName() &&
|
|
op->getAttrOfType<UnitAttr>("twoState") ==
|
|
user->getAttrOfType<UnitAttr>("twoState") &&
|
|
op->getBlock() == user->getBlock();
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
|
|
/// as an Op. Returns true if successful, and false otherwise.
|
|
///
|
|
/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
|
|
///
|
|
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
|
|
// Skip if the operation should be flattened by another operation.
|
|
if (shouldBeFlattened(op))
|
|
return false;
|
|
|
|
auto inputs = op->getOperands();
|
|
|
|
SmallVector<Value, 4> newOperands;
|
|
SmallVector<Location, 4> newLocations{op->getLoc()};
|
|
newOperands.reserve(inputs.size());
|
|
struct Element {
|
|
decltype(inputs.begin()) current, end;
|
|
};
|
|
|
|
SmallVector<Element> worklist;
|
|
worklist.push_back({inputs.begin(), inputs.end()});
|
|
bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
|
|
bool changed = false;
|
|
while (!worklist.empty()) {
|
|
auto &element = worklist.back(); // Do not pop. Take ref.
|
|
|
|
// Pop when we finished traversing the current operand range.
|
|
if (element.current == element.end) {
|
|
worklist.pop_back();
|
|
continue;
|
|
}
|
|
|
|
Value value = *element.current++;
|
|
auto *flattenOp = value.getDefiningOp();
|
|
// If not defined by a compatible operation of the same kind and
|
|
// from the same block, keep this as-is.
|
|
if (!flattenOp || flattenOp->getName() != op->getName() ||
|
|
flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
|
|
flattenOp->getBlock() != op->getBlock()) {
|
|
newOperands.push_back(value);
|
|
continue;
|
|
}
|
|
|
|
// Don't duplicate logic when it has multiple uses.
|
|
if (!value.hasOneUse()) {
|
|
// We can fold a multi-use binary operation into this one if this allows a
|
|
// constant to fold though. For example, fold
|
|
// (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
|
|
// since the constants will both fold and we end up with the equiv cost.
|
|
//
|
|
// We don't do this for add/mul because the hardware won't be shared
|
|
// between the two ops if duplicated.
|
|
if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
|
|
!flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
|
|
!inputs.back().getDefiningOp<hw::ConstantOp>()) {
|
|
newOperands.push_back(value);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
changed = true;
|
|
|
|
// Otherwise, push operands into worklist.
|
|
auto flattenOpInputs = flattenOp->getOperands();
|
|
worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
|
|
newLocations.push_back(flattenOp->getLoc());
|
|
}
|
|
|
|
if (!changed)
|
|
return false;
|
|
|
|
Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
|
|
op->getName(), newOperands, rewriter);
|
|
if (binFlag)
|
|
result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
|
|
|
|
replaceOpAndCopyNamehint(rewriter, op, result);
|
|
return true;
|
|
}
|
|
|
|
// Given a range of uses of an operation, find the lowest and highest bits
|
|
// inclusive that are ever referenced. The range of uses must not be empty.
|
|
static std::pair<size_t, size_t>
|
|
getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
|
|
size_t originalOpWidth) {
|
|
auto users = op->getUsers();
|
|
assert(!users.empty() &&
|
|
"getLowestBitAndHighestBitRequired cannot operate on "
|
|
"a empty list of uses.");
|
|
|
|
// when we don't want to narrowTrailingBits (namely in arithmetic
|
|
// operations), forcing lowestBitRequired = 0
|
|
size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
|
|
size_t highestBitRequired = 0;
|
|
|
|
for (auto *user : users) {
|
|
if (auto extractOp = dyn_cast<ExtractOp>(user)) {
|
|
size_t lowBit = extractOp.getLowBit();
|
|
size_t highBit =
|
|
cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
|
|
highestBitRequired = std::max(highestBitRequired, highBit);
|
|
lowestBitRequired = std::min(lowestBitRequired, lowBit);
|
|
continue;
|
|
}
|
|
|
|
highestBitRequired = originalOpWidth - 1;
|
|
lowestBitRequired = 0;
|
|
break;
|
|
}
|
|
|
|
return {lowestBitRequired, highestBitRequired};
|
|
}
|
|
|
|
template <class OpTy>
|
|
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
|
|
PatternRewriter &rewriter) {
|
|
IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
|
|
if (!opType)
|
|
return false;
|
|
|
|
auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
|
|
opType.getWidth());
|
|
if (range.second + 1 == opType.getWidth() && range.first == 0)
|
|
return false;
|
|
|
|
SmallVector<Value> args;
|
|
auto newType = rewriter.getIntegerType(range.second - range.first + 1);
|
|
for (auto inop : op.getOperands()) {
|
|
// deal with muxes here
|
|
if (inop.getType() != op.getType())
|
|
args.push_back(inop);
|
|
else
|
|
args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
|
|
inop, range.first));
|
|
}
|
|
auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
|
|
newop->setDialectAttrs(op->getDialectAttrs());
|
|
if (op.getTwoState())
|
|
newop.setTwoState(true);
|
|
|
|
Value newResult = newop.getResult();
|
|
if (range.first)
|
|
newResult = rewriter.createOrFold<ConcatOp>(
|
|
op.getLoc(), newResult,
|
|
rewriter.create<hw::ConstantOp>(op.getLoc(),
|
|
APInt::getZero(range.first)));
|
|
if (range.second + 1 < opType.getWidth())
|
|
newResult = rewriter.createOrFold<ConcatOp>(
|
|
op.getLoc(),
|
|
rewriter.create<hw::ConstantOp>(
|
|
op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
|
|
newResult);
|
|
rewriter.replaceOp(op, newResult);
|
|
return true;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Unary Operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Fold a replicate: a replicate whose result is as wide as its input folds
/// to the input, and a replicate of an integer constant folds to the repeated
/// constant.
OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  const unsigned resultWidth = cast<IntegerType>(getType()).getWidth();

  // Replicate one time -> noop.
  if (resultWidth == getInput().getType().getIntOrFloatBitWidth())
    return getInput();

  // Everything below is constant folding.
  auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput());
  if (!input)
    return {};

  APInt element = input.getValue();

  // A single replicated bit is either all zeros or all ones.
  if (element.getBitWidth() == 1) {
    APInt splat = element.isZero() ? APInt::getZero(resultWidth)
                                   : APInt::getAllOnes(resultWidth);
    return getIntAttr(splat, getContext());
  }

  // General case: concatenate the element with itself `multiple` times.
  APInt repeated = APInt::getZeroWidth();
  for (auto remaining = getMultiple(); remaining != 0; --remaining)
    repeated = repeated.concat(element);
  return getIntAttr(repeated, getContext());
}
|
|
|
|
/// Fold the parity of an integer constant to a 1-bit constant holding the low
/// bit of the input's population count.
OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput());
  if (!input)
    return {};

  // Constant fold.
  APInt parityBit(1, input.getValue().popcount() & 1);
  return getIntAttr(parityBit, getContext());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Binary Operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Performs constant folding `calculate` with element-wise behavior on the two
|
|
/// attributes in `operands` and returns the result if possible.
|
|
static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
|
|
hw::PEO paramOpcode) {
|
|
assert(operands.size() == 2 && "binary op takes two operands");
|
|
if (!operands[0] || !operands[1])
|
|
return {};
|
|
|
|
// Fold constants with ParamExprAttr::get which handles simple constants as
|
|
// well as parameter expressions.
|
|
return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
|
|
cast<TypedAttr>(operands[1]));
|
|
}
|
|
|
|
/// Fold a left shift:
///   - shl(x, 0) -> x
///   - shl(x, c) -> 0 when c >= bit width of the result
///   - otherwise constant-fold when both operands are constant.
OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
    // Inspect the shift amount as an APInt. Converting it to `unsigned` first
    // would truncate large amounts (e.g. 1 << 32 would look like a zero
    // shift) and getZExtValue() asserts on values needing more than 64 bits.
    APInt amount = rhs.getValue();
    if (amount.isZero())
      return getOperand(0);

    unsigned width = getType().getIntOrFloatBitWidth();
    if (amount.uge(width))
      return getIntAttr(APInt::getZero(width), getContext());
  }

  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
}
|
|
|
|
/// Canonicalize a left shift by a constant amount into bit manipulation:
///   ShlOp(x, cst) -> Concat(Extract(x), zeros)
/// Zero shifts and shifts by at least the operand width are left to fold.
LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  // ShlOp(x, cst) -> Concat(Extract(x), zeros)
  APInt value;
  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
    return failure();

  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();

  // This case is handled by fold. The comparison is done on the APInt so a
  // huge shift amount is not truncated into a small, seemingly valid one.
  if (value.isZero() || value.uge(width))
    return failure();

  // Safe: 0 < value < width, so it fits in `unsigned`.
  unsigned shift = value.getZExtValue();

  auto zeros =
      rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));

  // Remove the high bits which would be removed by the Shl.
  auto extract =
      rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);

  replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
  return success();
}
|
|
|
|
/// Fold a logical right shift:
///   - shru(x, 0) -> x
///   - shru(x, c) -> 0 when c >= bit width of the result
///   - otherwise constant-fold when both operands are constant.
OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
    // Inspect the shift amount as an APInt. Converting it to `unsigned` first
    // would truncate large amounts and getZExtValue() asserts on values
    // needing more than 64 bits.
    APInt amount = rhs.getValue();
    if (amount.isZero())
      return getOperand(0);

    unsigned width = getType().getIntOrFloatBitWidth();
    if (amount.uge(width))
      return getIntAttr(APInt::getZero(width), getContext());
  }
  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
}
|
|
|
|
/// Canonicalize a logical right shift by a constant amount into bit
/// manipulation:
///   ShrUOp(x, cst) -> Concat(zeros, Extract(x))
/// Zero shifts and shifts by at least the operand width are left to fold.
LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
  APInt value;
  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
    return failure();

  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();

  // This case is handled by fold. The comparison is done on the APInt so a
  // huge shift amount is not truncated into a small, seemingly valid one.
  if (value.isZero() || value.uge(width))
    return failure();

  // Safe: 0 < value < width, so it fits in `unsigned`.
  unsigned shift = value.getZExtValue();

  auto zeros =
      rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));

  // Remove the low bits which would be removed by the Shr.
  auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
                                            width - shift);

  replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
  return success();
}
|
|
|
|
/// Fold an arithmetic right shift: shrs(x, 0) -> x, otherwise constant-fold
/// when both operands are constant.
OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
    // Test for zero on the APInt itself; getZExtValue() asserts when the
    // constant needs more than 64 bits.
    if (rhs.getValue().isZero())
      return getOperand(0);
  }
  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
}
|
|
|
|
/// Canonicalize an arithmetic right shift by a constant amount into bit
/// manipulation:
///   ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)), extract(x))
/// When the shift amount is at least the operand width the result is just the
/// sign bit replicated across the full width.
LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
  APInt value;
  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
    return failure();

  // A zero shift is the identity and is handled by fold. Bail out before
  // creating any operation so we never build a zero-width replicate.
  if (value.isZero())
    return failure();

  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();

  // Clamp the shift to the operand width. Shifting by `width` or more yields
  // the same value (all sign bits), and clamping keeps the replicate below
  // exactly `width` bits wide in that case so the replacement has the
  // original result type. It also avoids truncating or asserting on huge
  // shift constants.
  unsigned shift = value.getLimitedValue(width);

  auto topbit =
      rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
  auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);

  if (width <= shift) {
    replaceOpAndCopyNamehint(rewriter, op, {sext});
    return success();
  }

  // Keep the bits that survive the shift and put the sign extension on top.
  auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
                                            width - shift);

  replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Other Operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Fold an extract: extracting a value of the same type as the input folds to
/// the input, and extracting from an integer constant folds to the selected
/// bits.
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  // If we are extracting the entire input, then return it.
  if (getInput().getType() == getType())
    return getInput();

  auto cstInput = dyn_cast_or_null<IntegerAttr>(adaptor.getInput());
  if (!cstInput)
    return {};

  // Constant fold: shift the requested bits down to position 0, then drop
  // everything above the result width.
  unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
  APInt shifted = cstInput.getValue().lshr(getLowBit());
  return getIntAttr(shifted.trunc(dstWidth), getContext());
}
|
|
|
|
// Transforms extract(lo, cat(a, b, c, d, e)) into
|
|
// cat(extract(lo1, b), c, extract(lo2, d)).
|
|
// innerCat must be the argument of the provided ExtractOp.
|
|
static LogicalResult extractConcatToConcatExtract(ExtractOp op,
|
|
ConcatOp innerCat,
|
|
PatternRewriter &rewriter) {
|
|
auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
|
|
size_t beginOfFirstRelevantElement = 0;
|
|
auto it = reversedConcatArgs.begin();
|
|
size_t lowBit = op.getLowBit();
|
|
|
|
// This loop finds the first concatArg that is covered by the ExtractOp
|
|
for (; it != reversedConcatArgs.end(); it++) {
|
|
assert(beginOfFirstRelevantElement <= lowBit &&
|
|
"incorrectly moved past an element that lowBit has coverage over");
|
|
auto operand = *it;
|
|
|
|
size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
|
|
if (lowBit < beginOfFirstRelevantElement + operandWidth) {
|
|
// A bit other than the first bit will be used in this element.
|
|
// ...... ........ ...
|
|
// ^---lowBit
|
|
// ^---beginOfFirstRelevantElement
|
|
//
|
|
// Edge-case close to the end of the range.
|
|
// ...... ........ ...
|
|
// ^---(position + operandWidth)
|
|
// ^---lowBit
|
|
// ^---beginOfFirstRelevantElement
|
|
//
|
|
// Edge-case close to the beginning of the rang
|
|
// ...... ........ ...
|
|
// ^---lowBit
|
|
// ^---beginOfFirstRelevantElement
|
|
//
|
|
break;
|
|
}
|
|
|
|
// extraction discards this element.
|
|
// ...... ........ ...
|
|
// | ^---lowBit
|
|
// ^---beginOfFirstRelevantElement
|
|
beginOfFirstRelevantElement += operandWidth;
|
|
}
|
|
assert(it != reversedConcatArgs.end() &&
|
|
"incorrectly failed to find an element which contains coverage of "
|
|
"lowBit");
|
|
|
|
SmallVector<Value> reverseConcatArgs;
|
|
size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
|
|
size_t extractLo = lowBit - beginOfFirstRelevantElement;
|
|
|
|
// Transform individual arguments of innerCat(..., a, b, c,) into
|
|
// [ extract(a), b, extract(c) ], skipping an extract operation where
|
|
// possible.
|
|
for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
|
|
auto concatArg = *it;
|
|
size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
|
|
size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
|
|
|
|
if (widthToConsume == operandWidth && extractLo == 0) {
|
|
reverseConcatArgs.push_back(concatArg);
|
|
} else {
|
|
auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
|
|
reverseConcatArgs.push_back(
|
|
rewriter.create<ExtractOp>(op.getLoc(), resultType, *it, extractLo));
|
|
}
|
|
|
|
widthRemaining -= widthToConsume;
|
|
|
|
// Beyond the first element, all elements are extracted from position 0.
|
|
extractLo = 0;
|
|
}
|
|
|
|
if (reverseConcatArgs.size() == 1) {
|
|
replaceOpAndCopyNamehint(rewriter, op, reverseConcatArgs[0]);
|
|
} else {
|
|
replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
|
|
rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
|
|
}
|
|
return success();
|
|
}
|
|
|
|
// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
|
|
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
|
|
PatternRewriter &rewriter) {
|
|
auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
|
|
auto replicateEltWidth =
|
|
replicate.getOperand().getType().getIntOrFloatBitWidth();
|
|
|
|
// If the extract starts at the base of an element and is an even multiple,
|
|
// we can replace the extract with a smaller replicate.
|
|
if (op.getLowBit() % replicateEltWidth == 0 &&
|
|
extractResultWidth % replicateEltWidth == 0) {
|
|
replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
|
|
replicate.getOperand());
|
|
return true;
|
|
}
|
|
|
|
// If the extract is completely contained in one element, extract from the
|
|
// element.
|
|
if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
|
|
replicateEltWidth) {
|
|
replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
|
|
rewriter, op, op.getType(), replicate.getOperand(),
|
|
op.getLowBit() % replicateEltWidth);
|
|
return true;
|
|
}
|
|
|
|
// We don't currently handle the case of extracting from non-whole elements,
|
|
// e.g. `extract (replicate 2-bit-thing, N), 1`.
|
|
return false;
|
|
}
|
|
|
|
/// Canonicalization patterns for ExtractOp, tried in this order:
///   1. extract(extract(x))          -> a single extract of x
///   2. extract(concat(...))         -> concat of partial operands
///   3. extract(replicate(a))        -> smaller replicate / extract of a
///   4. extract(and/or/xor(a, cst))  -> extract(a) or concat with zeros when
///                                      the constant does not affect the bits
///   5. 1-bit extract(shl(1, x))     -> icmp eq x, lowBit
/// Returns failure if none of them applies.
LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto *inputOp = op.getInput().getDefiningOp();

  // This turns out to be incredibly expensive.  Disable until performance is
  // addressed.
#if 0
  // If the extracted bits are all known, then return the result.
  auto knownBits = computeKnownBits(op.getInput())
                       .extractBits(cast<IntegerType>(op.getType()).getWidth(),
                                    op.getLowBit());
  if (knownBits.isConstant()) {
    replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                      knownBits.getConstant());
    return success();
  }
#endif

  // Every pattern below inspects the operation defining the input, so there
  // is nothing to do when the input has no defining op.
  if (!inputOp)
    return failure();

  // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
  if (auto innerExtract = dyn_cast<ExtractOp>(inputOp)) {
    replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
        rewriter, op, op.getType(), innerExtract.getInput(),
        innerExtract.getLowBit() + op.getLowBit());
    return success();
  }

  // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
  if (auto innerCat = dyn_cast<ConcatOp>(inputOp))
    return extractConcatToConcatExtract(op, innerCat, rewriter);

  // extract(lo, replicate(a))
  if (auto replicate = dyn_cast<ReplicateOp>(inputOp))
    if (extractFromReplicate(op, replicate, rewriter))
      return success();

  // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
  // and/or/xor are not modifying the extracted bits.
  bool isBinaryLogicalOp =
      inputOp->getNumOperands() == 2 && isa<AndOp, OrOp, XorOp>(inputOp);
  if (isBinaryLogicalOp) {
    if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
      // The slice of the constant that lines up with the extracted bits.
      auto extractedCst = cstRHS.getValue().extractBits(
          cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());

      // or/xor with zero leaves the extracted bits untouched.
      if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
        replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
            rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
        return success();
      }

      // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
      // extract to represent the result.  Turning it into a pile of extracts is
      // always fine by our cost model, but we don't want to explode things into
      // a ton of bits because it will bloat the IR and generated Verilog.
      if (isa<AndOp>(inputOp)) {
        // For our cost model, we only do this if the bit pattern is a
        // contiguous series of ones.
        unsigned lz = extractedCst.countLeadingZeros();
        unsigned tz = extractedCst.countTrailingZeros();
        unsigned pop = extractedCst.popcount();
        bool onesAreContiguous = extractedCst.getBitWidth() - lz - tz == pop;
        if (onesAreContiguous) {
          auto resultTy = rewriter.getIntegerType(pop);
          SmallVector<Value> resultElts;

          // Masked-off high bits become zeros.
          if (lz)
            resultElts.push_back(rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(lz)));

          // The run of ones passes the corresponding bits of `a` through.
          resultElts.push_back(rewriter.createOrFold<ExtractOp>(
              op.getLoc(), resultTy, inputOp->getOperand(0),
              op.getLowBit() + tz));

          // Masked-off low bits become zeros.
          if (tz)
            resultElts.push_back(rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(tz)));

          replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
          return success();
        }
      }
    }
  }

  // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
  // extracted.
  if (cast<IntegerType>(op.getType()).getWidth() != 1)
    return failure();

  auto shlOp = dyn_cast<ShlOp>(inputOp);
  // Don't canonicalize if the shift is multiply used.
  if (!shlOp || !shlOp->hasOneUse())
    return failure();

  auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>();
  if (!lhsCst || !lhsCst.getValue().isOne())
    return failure();

  // Compare the shift amount against the extracted bit index, expressed as a
  // constant of the same width as the shifted value.
  auto newCst = rewriter.create<hw::ConstantOp>(
      shlOp.getLoc(), APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
  replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
                                            shlOp->getOperand(1), newCst,
                                            false);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Associative Variadic operations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Reduce all operands to a single value (either integer constant or parameter
|
|
// expression) if all the operands are constants.
|
|
static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
|
|
hw::PEO paramOpcode) {
|
|
assert(operands.size() > 1 && "caller should handle one-operand case");
|
|
// We can only fold anything in the case where all operands are known to be
|
|
// constants. Check the least common one first for an early out.
|
|
if (!operands[1] || !operands[0])
|
|
return {};
|
|
|
|
// This will fold to a simple constant if all operands are constant.
|
|
if (llvm::all_of(operands.drop_front(2),
|
|
[&](Attribute in) { return !!in; })) {
|
|
SmallVector<mlir::TypedAttr> typedOperands;
|
|
typedOperands.reserve(operands.size());
|
|
for (auto operand : operands) {
|
|
if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
|
|
typedOperands.push_back(typedOperand);
|
|
else
|
|
break;
|
|
}
|
|
if (typedOperands.size() == operands.size())
|
|
return hw::ParamExprAttr::get(paramOpcode, typedOperands);
|
|
}
|
|
|
|
return {};
|
|
}
|
|
|
|
/// When we find a logical operation (and, or, xor) with a constant e.g.
|
|
/// `X & 42`, we want to push the constant into the computation of X if it leads
|
|
/// to simplification.
|
|
///
|
|
/// This function handles the case where the logical operation has a concat
|
|
/// operand. We check to see if we can simplify the concat, e.g. when it has
|
|
/// constant operands.
|
|
///
|
|
/// This returns true when a simplification happens.
|
|
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
|
|
size_t concatIdx, const APInt &cst,
|
|
PatternRewriter &rewriter) {
|
|
auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
|
|
assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
|
|
|
|
// Check to see if any operands can be simplified by pushing the logical op
|
|
// into all parts of the concat.
|
|
bool canSimplify =
|
|
llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
|
|
auto *operandOp = operand.getDefiningOp();
|
|
if (!operandOp)
|
|
return false;
|
|
|
|
// If the concat has a constant operand then we can transform this.
|
|
if (isa<hw::ConstantOp>(operandOp))
|
|
return true;
|
|
// If the concat has the same logical operation and that operation has
|
|
// a constant operation than we can fold it into that suboperation.
|
|
return operandOp->getName() == logicalOp->getName() &&
|
|
operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
|
|
operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
|
|
});
|
|
|
|
if (!canSimplify)
|
|
return false;
|
|
|
|
// Create a new instance of the logical operation. We have to do this the
|
|
// hard way since we're generic across a family of different ops.
|
|
auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
|
|
return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
|
|
rewriter);
|
|
};
|
|
|
|
// Ok, let's do the transformation. We do this by slicing up the constant
|
|
// for each unit of the concat and duplicate the operation into the
|
|
// sub-operand.
|
|
SmallVector<Value> newConcatOperands;
|
|
newConcatOperands.reserve(concatOp->getNumOperands());
|
|
|
|
// Work from MSB to LSB.
|
|
size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
|
|
for (Value operand : concatOp->getOperands()) {
|
|
size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
|
|
nextOperandBit -= operandWidth;
|
|
// Take a slice of the constant.
|
|
auto eltCst = rewriter.create<hw::ConstantOp>(
|
|
logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
|
|
|
|
newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
|
|
}
|
|
|
|
// Create the concat, and the rest of the logical op if we need it.
|
|
Value newResult =
|
|
rewriter.create<ConcatOp>(concatOp.getLoc(), newConcatOperands);
|
|
|
|
// If we had a variadic logical op on the top level, then recreate it with the
|
|
// new concat and without the constant operand.
|
|
if (logicalOp->getNumOperands() > 2) {
|
|
auto origOperands = logicalOp->getOperands();
|
|
SmallVector<Value> operands;
|
|
// Take any stuff before the concat.
|
|
operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
|
|
// Take any stuff after the concat but before the constant.
|
|
operands.append(origOperands.begin() + concatIdx + 1,
|
|
origOperands.begin() + (origOperands.size() - 1));
|
|
// Include the new concat.
|
|
operands.push_back(newResult);
|
|
newResult = createLogicalOp(operands);
|
|
}
|
|
|
|
replaceOpAndCopyNamehint(rewriter, logicalOp, newResult);
|
|
return true;
|
|
}
|
|
|
|
// Determines whether the inputs to a logical element are of opposite
|
|
// comparisons and can lowered into a constant.
|
|
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
|
|
llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
|
|
|
|
for (auto op : operands) {
|
|
if (auto icmpOp = op.getDefiningOp<ICmpOp>();
|
|
icmpOp && icmpOp.getTwoState()) {
|
|
auto predicate = icmpOp.getPredicate();
|
|
auto lhs = icmpOp.getLhs();
|
|
auto rhs = icmpOp.getRhs();
|
|
if (seenPredicates.contains(
|
|
{ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
|
|
return true;
|
|
|
|
seenPredicates.insert({predicate, lhs, rhs});
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Fold an `and`:
///   - and(x, 01, 10)            -> 00 (integer constants annul each other)
///   - and(x, -1)                -> x
///   - and(x, x, x) / and(x)     -> x
///   - and(..., x, ..., ~x, ...) -> 0
///   - and of two opposite two-state icmps on the same operands -> 0
///   - otherwise constant-fold when all operands are constant.
OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
  // Operands defined in another block are a canonicalization barrier, just as
  // in the other comb folders.
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());

  auto inputs = adaptor.getInputs();

  // and(x, 01, 10) -> 00 -- annulment.
  for (auto operand : inputs) {
    // Only plain integer constants take part. Non-constant operands are null,
    // and other constant attributes (e.g. parameter expressions) are left to
    // constFoldAssociativeOp below rather than being force-cast.
    auto intOperand = dyn_cast_or_null<IntegerAttr>(operand);
    if (!intOperand)
      continue;
    value &= intOperand.getValue();
    if (value.isZero())
      return getIntAttr(value, getContext());
  }

  // and(x, -1) -> x.
  if (inputs.size() == 2)
    if (auto rhs = dyn_cast_or_null<IntegerAttr>(inputs[1]))
      if (rhs.getValue().isAllOnes())
        return getInputs()[0];

  // and(x, x, x) -> x.  This also handles and(x) -> x.
  if (llvm::all_of(getInputs(),
                   [&](auto in) { return in == this->getInputs()[0]; }))
    return getInputs()[0];

  // and(..., x, ..., ~x, ...) -> 0
  for (Value arg : getInputs()) {
    Value subExpr;
    if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
      for (Value arg2 : getInputs())
        if (arg2 == subExpr)
          return getIntAttr(
              APInt::getZero(cast<IntegerType>(getType()).getWidth()),
              getContext());
    }
  }

  // x0 = icmp(pred, x, y)
  // x1 = icmp(!pred, x, y)
  // and(x0, x1) -> 0
  if (canCombineOppositeBinCmpIntoConstant(getInputs()))
    return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
                      getContext());

  // Constant fold
  return constFoldAssociativeOp(inputs, hw::PEO::And);
}
|
|
|
|
/// Returns a single common operand that all inputs of the operation `op` can
|
|
/// be traced back to, or an empty `Value` if no such operand exists.
|
|
///
|
|
/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
|
|
/// (assuming the bit-width of `a` is `n`).
|
|
template <typename Op>
|
|
static Value getCommonOperand(Op op) {
|
|
if (!op.getType().isInteger(1))
|
|
return Value();
|
|
|
|
auto inputs = op.getInputs();
|
|
size_t size = inputs.size();
|
|
|
|
auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
|
|
if (!sourceOp)
|
|
return Value();
|
|
Value source = sourceOp.getOperand();
|
|
|
|
// Fast path: the input size is not equal to the width of the source.
|
|
if (size != source.getType().getIntOrFloatBitWidth())
|
|
return Value();
|
|
|
|
// Tracks the bits that were encountered.
|
|
llvm::BitVector bits(size);
|
|
bits.set(sourceOp.getLowBit());
|
|
|
|
for (size_t i = 1; i != size; ++i) {
|
|
auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
|
|
if (!extractOp || extractOp.getOperand() != source)
|
|
return Value();
|
|
bits.set(extractOp.getLowBit());
|
|
}
|
|
|
|
return bits.all() ? source : Value();
|
|
}
|
|
|
|
/// Canonicalize an idempotent operation `op` so that only one input of any kind
|
|
/// occurs.
|
|
///
|
|
/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
|
|
template <typename Op>
|
|
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
|
|
// Depth limit to search, in operations. Chosen arbitrarily, keep small.
|
|
constexpr unsigned limit = 3;
|
|
auto inputs = op.getInputs();
|
|
|
|
llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
|
|
llvm::SmallDenseSet<Op, 8> checked;
|
|
checked.insert(op);
|
|
|
|
struct OpWithDepth {
|
|
Op op;
|
|
unsigned depth;
|
|
};
|
|
llvm::SmallVector<OpWithDepth, 8> worklist;
|
|
|
|
auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
|
|
// Add to worklist if within depth limit, is defined in the same block by
|
|
// the same kind of operation, has same two-state-ness, and not enqueued
|
|
// previously.
|
|
if (depth < limit && input.getParentBlock() == op->getBlock()) {
|
|
auto inputOp = input.template getDefiningOp<Op>();
|
|
if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
|
|
checked.insert(inputOp).second)
|
|
worklist.push_back({inputOp, depth + 1});
|
|
}
|
|
};
|
|
|
|
for (auto input : uniqueInputs)
|
|
enqueue(input, 0);
|
|
|
|
while (!worklist.empty()) {
|
|
auto item = worklist.pop_back_val();
|
|
|
|
for (auto input : item.op.getInputs()) {
|
|
uniqueInputs.remove(input);
|
|
enqueue(input, item.depth);
|
|
}
|
|
}
|
|
|
|
if (uniqueInputs.size() < inputs.size()) {
|
|
replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
|
|
uniqueInputs.getArrayRef(),
|
|
op.getTwoState());
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Canonicalization patterns for `and`.
///
/// Tries, in order: flattening nested ands, removing idempotent inputs,
/// simplifications that depend on a trailing constant operand (identity,
/// merging two trailing constants, single-bit mask of a replicate, shrinking
/// a masked extract, pushing the constant into a concat), narrowing the
/// operation to the bits that are actually used, and rewriting a bitwise
/// reduction over all bits of a value into an `icmp eq`.
///
/// Returns success if `op` was rewritten, failure otherwise.
LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
  // Comb treats cross-block dataflow as a canonicalization barrier, so do not
  // rewrite when any operand is defined in a different block. This matches
  // the guard applied by the other comb canonicalizers in this file.
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto inputs = op.getInputs();
  auto size = inputs.size();

  // and(x, and(...)) -> and(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
  // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
  // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
    return success();

  assert(size > 1 && "expected 2 or more operands, `fold` should handle this");

  // Patterns for and with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // and(..., '1) -> and(...) -- identity
    if (value.isAllOnes()) {
      replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
                                               inputs.drop_back(), false);
      return success();
    }

    // TODO: Combine multiple constants together even if they aren't at the
    // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
    // folding
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value & value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
                                               newOperands, false);
      return success();
    }

    // Handle 'and' with a single bit constant on the RHS.
    if (size == 2 && value.isPowerOf2()) {
      // If the LHS is a replicate from a single bit, we can 'concat' it
      // into place.  e.g.:
      //   `replicate(x) & 4` -> `concat(zeros, x, zeros)`
      // TODO: Generalize this for non-single-bit operands.
      if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
        auto replicateOperand = replicate.getOperand();
        if (replicateOperand.getType().isInteger(1)) {
          unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
          auto trailingZeros = value.countTrailingZeros();

          // Don't add zero bit constants unnecessarily.
          SmallVector<Value, 3> concatOperands;
          if (trailingZeros != resultWidth - 1) {
            auto highZeros = rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
            concatOperands.push_back(highZeros);
          }
          concatOperands.push_back(replicateOperand);
          if (trailingZeros != 0) {
            auto lowZeros = rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(trailingZeros));
            concatOperands.push_back(lowZeros);
          }
          replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
              rewriter, op, op.getType(), concatOperands);
          return success();
        }
      }
    }

    // If this is an and from an extract op, try shrinking the extract.
    if (auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
      if (size == 2 &&
          // We can shrink it if the mask has leading or trailing zeros.
          (value.countLeadingZeros() || value.countTrailingZeros())) {
        unsigned lz = value.countLeadingZeros();
        unsigned tz = value.countTrailingZeros();

        // Start by extracting the smaller number of bits.
        auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
        Value smallElt = rewriter.createOrFold<ExtractOp>(
            extractOp.getLoc(), smallTy, extractOp->getOperand(0),
            extractOp.getLowBit() + tz);
        // Apply the 'and' mask if needed.
        APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
        if (!smallMask.isAllOnes()) {
          auto loc = inputs.back().getLoc();
          smallElt = rewriter.createOrFold<AndOp>(
              loc, smallElt, rewriter.create<hw::ConstantOp>(loc, smallMask),
              false);
        }

        // The final replacement will be a concat of the leading/trailing zeros
        // along with the smaller extracted value.
        SmallVector<Value> resultElts;
        if (lz)
          resultElts.push_back(
              rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
        resultElts.push_back(smallElt);
        if (tz)
          resultElts.push_back(
              rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
        replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
        return success();
      }
    }

    // and(concat(x, cst1), a, b, c, cst2)
    //    ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
    // We do this for even more multi-use concats since they are "just wiring".
    for (size_t i = 0; i < size - 1; ++i) {
      if (inputs[i].getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();
    }
  }

  // extracts only of and(...) -> and(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // and(a[0], a[1], ..., a[n-1]) -> icmp eq(a, -1)
  if (auto source = getCommonOperand(op)) {
    auto cmpAgainst =
        rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
                                              source, cmpAgainst);
    return success();
  }

  /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
  return failure();
}
|
|
|
|
/// Folds an `or` operation into a constant or into one of its own operands.
///
/// Handled cases, in order:
///  - the constant operands or together to all-ones (annulment),
///  - `or(x, 0)` -> `x`,
///  - all operands are the same value (also covers the unary case),
///  - an operand and its complement both occur -> all-ones,
///  - two two-state icmps with opposite predicates on the same operands
///    -> all-ones,
///  - generic constant folding of the constant operands.
///
/// Returns a null result when nothing applies.
OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
  // Comb treats cross-block dataflow as a canonicalization barrier, so refuse
  // to fold when any operand is defined in a different block. This matches
  // the guard applied by the other comb folders in this file.
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
  auto inputs = adaptor.getInputs();

  // or(x, 10, 01) -> 11
  // Accumulate all constant operands; non-constant operands are null here.
  for (auto operand : inputs) {
    if (!operand)
      continue;
    value |= cast<IntegerAttr>(operand).getValue();
    if (value.isAllOnes())
      return getIntAttr(value, getContext());
  }

  // or(x, 0) -> x
  if (inputs.size() == 2 && inputs[1] &&
      cast<IntegerAttr>(inputs[1]).getValue().isZero())
    return getInputs()[0];

  // or(x, x, x) -> x.  This also handles or(x) -> x
  if (llvm::all_of(getInputs(),
                   [&](auto in) { return in == this->getInputs()[0]; }))
    return getInputs()[0];

  // or(..., x, ..., ~x, ...) -> -1
  for (Value arg : getInputs()) {
    Value subExpr;
    if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
      for (Value arg2 : getInputs())
        if (arg2 == subExpr)
          return getIntAttr(
              APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
              getContext());
    }
  }

  // x0 = icmp(pred, x, y)
  // x1 = icmp(!pred, x, y)
  // or(x0, x1) -> 1
  if (canCombineOppositeBinCmpIntoConstant(getInputs()))
    return getIntAttr(
        APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
        getContext());

  // Constant fold
  return constFoldAssociativeOp(inputs, hw::PEO::Or);
}
|
|
|
|
/// Canonicalization patterns for `or`.
///
/// Tries, in order: flattening nested ors, removing idempotent inputs,
/// simplifications that depend on a trailing constant operand (identity,
/// merging two trailing constants, pushing the constant into a concat),
/// narrowing the operation to the bits that are actually used, rewriting a
/// bitwise reduction over all bits of a value into an `icmp ne`, and merging
/// an or of two-state muxes that share the same true value and a zero false
/// value into a single mux.
///
/// Returns success if `op` was rewritten, failure otherwise.
LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
  // Comb treats cross-block dataflow as a canonicalization barrier, so do not
  // rewrite when any operand is defined in a different block. This matches
  // the guard applied by the other comb canonicalizers in this file.
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto inputs = op.getInputs();
  auto size = inputs.size();

  // or(x, or(...)) -> or(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
  // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
  // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
    return success();

  assert(size > 1 && "expected 2 or more operands");

  // Patterns for or with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // or(..., '0) -> or(...) -- identity
    if (value.isZero()) {
      replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
                                              inputs.drop_back());
      return success();
    }

    // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value | value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
                                              newOperands);
      return success();
    }

    // or(concat(x, cst1), a, b, c, cst2)
    //    ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
    // We do this for even more multi-use concats since they are "just wiring".
    for (size_t i = 0; i < size - 1; ++i) {
      if (inputs[i].getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();
    }
  }

  // extracts only of or(...) -> or(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // or(a[0], a[1], ..., a[n-1]) -> icmp ne(a, 0)
  if (auto source = getCommonOperand(op)) {
    auto cmpAgainst =
        rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
                                              source, cmpAgainst);
    return success();
  }

  // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
  // .., c_n), a, 0)
  if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
    APInt muxFalseCst;
    if (op.getTwoState() && firstMux.getTwoState() &&
        matchPattern(firstMux.getFalseValue(), m_ConstantInt(&muxFalseCst)) &&
        muxFalseCst.isZero()) {
      SmallVector<Value> conditions{firstMux.getCond()};
      // Checks that `v` is a two-state mux with the same true/false values as
      // the first mux, collecting its condition along the way.
      auto check = [&](Value v) {
        auto mux = v.getDefiningOp<comb::MuxOp>();
        if (!mux)
          return false;
        conditions.push_back(mux.getCond());
        return mux.getTwoState() &&
               firstMux.getTrueValue() == mux.getTrueValue() &&
               firstMux.getFalseValue() == mux.getFalseValue();
      };
      if (llvm::all_of(op.getOperands().drop_front(), check)) {
        auto cond = rewriter.create<comb::OrOp>(op.getLoc(), conditions, true);
        replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
            rewriter, op, cond, firstMux.getTrueValue(),
            firstMux.getFalseValue(), true);
        return success();
      }
    }
  }

  /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
  return failure();
}
|
|
|
|
/// Folds an `xor` operation into a constant or into an existing value.
///
/// Handled cases, in order:
///  - `xor(x)` -> `x`,
///  - `xor(x, x)` -> 0,
///  - `xor(x, 0)` -> `x`,
///  - double complement `xor(xor(x, -1), -1)` -> `x`, unless `x` is this
///    operation's own result,
///  - generic constant folding of the constant operands.
///
/// Returns a null result when nothing applies.
OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
  // Comb treats cross-block dataflow as a canonicalization barrier, so refuse
  // to fold when any operand is defined in a different block. This matches
  // the guard applied by the other comb folders in this file.
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  auto size = getInputs().size();
  auto inputs = adaptor.getInputs();

  // xor(x) -> x -- noop
  if (size == 1)
    return getInputs()[0];

  // xor(x, x) -> 0 -- a value cancels itself
  if (size == 2 && getInputs()[0] == getInputs()[1])
    return IntegerAttr::get(getType(), 0);

  // xor(x, 0) -> x
  if (inputs.size() == 2 && inputs[1] &&
      cast<IntegerAttr>(inputs[1]).getValue().isZero())
    return getInputs()[0];

  // xor(xor(x,1),1) -> x
  // but not self loop
  if (isBinaryNot()) {
    Value subExpr;
    if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
        subExpr != getResult())
      return subExpr;
  }

  // Constant fold
  return constFoldAssociativeOp(inputs, hw::PEO::Xor);
}
|
|
|
|
// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
|
|
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
|
|
PatternRewriter &rewriter) {
|
|
auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
|
|
auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
|
|
|
|
Value result =
|
|
rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
|
|
icmp.getOperand(1), icmp.getTwoState());
|
|
|
|
// If the xor had other operands, rebuild it.
|
|
if (op.getNumOperands() > 2) {
|
|
SmallVector<Value, 4> newOperands(op.getOperands());
|
|
newOperands.pop_back();
|
|
newOperands.erase(newOperands.begin() + icmpOperand);
|
|
newOperands.push_back(result);
|
|
result = rewriter.create<XorOp>(op.getLoc(), newOperands, op.getTwoState());
|
|
}
|
|
|
|
replaceOpAndCopyNamehint(rewriter, op, result);
|
|
}
|
|
|
|
/// Canonicalization patterns for `xor`.
///
/// Tries, in order: cancelling two identical trailing operands,
/// simplifications that depend on a trailing constant operand (identity,
/// merging two trailing constants, pushing the constant into a concat,
/// absorbing a single-bit `1` into a single-use icmp by negating its
/// predicate), flattening nested xors, narrowing the operation to the bits
/// that are actually used, and rewriting a bitwise reduction over all bits of
/// a value into a parity operation.
///
/// Returns success if `op` was rewritten, failure otherwise.
LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
  // Comb treats cross-block dataflow as a canonicalization barrier, so do not
  // rewrite when any operand is defined in a different block. This matches
  // the guard applied by the other comb canonicalizers in this file.
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  // xor(..., x, x) -> xor (...) -- the two trailing copies cancel out
  if (inputs[size - 1] == inputs[size - 2]) {
    assert(size > 2 &&
           "expected idempotent case for 2 elements handled already.");
    replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
                                             inputs.drop_back(/*n=*/2), false);
    return success();
  }

  // Patterns for xor with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // xor(..., 0) -> xor(...) -- identity
    if (value.isZero()) {
      replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
                                               inputs.drop_back(), false);
      return success();
    }

    // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value ^ value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
                                               newOperands, false);
      return success();
    }

    bool isSingleBit = value.getBitWidth() == 1;

    // Check for subexpressions that we can simplify.
    for (size_t i = 0; i < size - 1; ++i) {
      Value operand = inputs[i];

      // xor(concat(x, cst1), a, b, c, cst2)
      //    ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
      // We do this for even more multi-use concats since they are "just
      // wiring".
      if (operand.getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();

      // xor(icmp, a, b, 1) -> xor(a, b, !icmp) if icmp has one user.
      if (isSingleBit && operand.hasOneUse()) {
        assert(value == 1 && "single bit constant has to be one if not zero");
        if (operand.getDefiningOp<ICmpOp>()) {
          canonicalizeXorIcmpTrue(op, i, rewriter);
          return success();
        }
      }
    }
  }

  // xor(x, xor(...)) -> xor(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of xor(...) -> xor(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // xor(a[0], a[1], ..., a[n-1]) -> parity(a)
  if (auto source = getCommonOperand(op)) {
    replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
    return success();
  }

  return failure();
}
|
|
|
|
/// Folds a `sub` operation.
///
/// Handled cases:
///  - `sub(x, x)` -> 0,
///  - both operands constant -> the parameter expression `lhs + (rhs * -1)`,
///  - `sub(x, 0)` -> `x`.
///
/// Nothing is folded when an operand is defined in a different block.
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  // Evaluated lazily: the width is only needed by some of the folds below.
  auto operandWidth = [this] {
    return getLhs().getType().getIntOrFloatBitWidth();
  };

  // sub(x - x) -> 0
  if (getLhs() == getRhs())
    return getIntAttr(APInt::getZero(operandWidth()), getContext());

  // Every remaining fold needs a constant right-hand side.
  Attribute rhsAttr = adaptor.getRhs();
  if (!rhsAttr)
    return {};

  // Both constant: fold (c1 - c2) as (c1 + -1*c2).
  if (Attribute lhsAttr = adaptor.getLhs()) {
    auto minusOne = getIntAttr(APInt::getAllOnes(operandWidth()), getContext());
    auto negatedRhs = hw::ParamExprAttr::get(
        hw::PEO::Mul, cast<TypedAttr>(rhsAttr), minusOne);
    return hw::ParamExprAttr::get(hw::PEO::Add, cast<TypedAttr>(lhsAttr),
                                  negatedRhs);
  }

  // sub(x - 0) -> x
  auto rhsInt = dyn_cast<IntegerAttr>(rhsAttr);
  if (rhsInt && rhsInt.getValue().isZero())
    return getLhs();

  return {};
}
|
|
|
|
/// Canonicalization patterns for `sub`.
///
/// A subtraction of a constant becomes an addition of the negated constant;
/// otherwise the operation is narrowed to the bits that are actually used,
/// if possible. Nothing is rewritten when an operand is defined in a
/// different block.
LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(op.getOperation()))
    return failure();

  // sub(x, cst) -> add(x, -cst)
  APInt subtrahend;
  if (matchPattern(op.getRhs(), m_ConstantInt(&subtrahend))) {
    auto negated = rewriter.create<hw::ConstantOp>(op.getLoc(), -subtrahend);
    replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negated,
                                             false);
    return success();
  }

  // extracts only of sub(...) -> sub(extract()...)
  bool narrowed = narrowOperationWidth(op, false, rewriter);
  return narrowed ? success() : failure();
}
|
|
|
|
/// Folds an `add` operation: a unary add folds to its operand, everything
/// else is delegated to associative constant folding. Nothing is folded when
/// an operand is defined in a different block.
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  // add(x) -> x -- noop
  auto addends = getInputs();
  if (addends.size() == 1u)
    return addends[0];

  // Constant fold constant operands.
  return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
}
|
|
|
|
/// Canonicalization patterns for `add`.
///
/// Most patterns inspect the two trailing operands and merge them into one:
/// a trailing zero is dropped, two constants are summed, `x + x` becomes a
/// shift, and `x + shl(x, c)` / `x + mul(x, c)` become a single multiply.
/// After that, nested adds are flattened, the operation is narrowed to the
/// bits that are actually used, and `add(add(x, c1), c2)` is re-associated.
///
/// Returns success if `op` was rewritten, failure otherwise (including when
/// an operand is defined in a different block).
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(op.getOperation()))
    return failure();

  auto operands = op.getInputs();
  auto numOperands = operands.size();
  assert(numOperands > 1 && "expected 2 or more operands");

  Value last = operands[numOperands - 1];
  Value secondToLast = operands[numOperands - 2];

  // Rebuilds the add with its two trailing operands replaced by `merged`.
  auto replaceLastTwoWith = [&](Value merged) -> LogicalResult {
    SmallVector<Value, 4> newOperands(operands.drop_back(/*n=*/2));
    newOperands.push_back(merged);
    replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
                                             newOperands, false);
    return success();
  };

  APInt lastCst, prevCst;
  bool lastIsCst = matchPattern(last, m_ConstantInt(&lastCst));

  // add(..., 0) -> add(...) -- identity
  if (lastIsCst && lastCst.isZero()) {
    replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
                                             operands.drop_back(), false);
    return success();
  }

  // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
  if (lastIsCst && matchPattern(secondToLast, m_ConstantInt(&prevCst))) {
    auto sum =
        rewriter.create<hw::ConstantOp>(op.getLoc(), lastCst + prevCst);
    return replaceLastTwoWith(sum);
  }

  // add(..., x, x) -> add(..., shl(x, 1))
  if (last == secondToLast) {
    auto shiftByOne =
        rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
    auto doubled =
        rewriter.create<comb::ShlOp>(op.getLoc(), last, shiftByOne, false);
    return replaceLastTwoWith(doubled);
  }

  // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
  if (auto shl = last.getDefiningOp<comb::ShlOp>()) {
    APInt shiftAmount;
    if (shl.getLhs() == secondToLast &&
        matchPattern(shl.getRhs(), m_ConstantInt(&shiftAmount))) {
      APInt unit(/*numBits=*/shiftAmount.getBitWidth(), 1, /*isSigned=*/false);
      auto factor = rewriter.create<hw::ConstantOp>(
          op.getLoc(), (unit << shiftAmount) + unit);

      std::array<Value, 2> factors = {shl.getLhs(), factor};
      auto product = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);
      return replaceLastTwoWith(product);
    }
  }

  // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
  if (auto mul = last.getDefiningOp<comb::MulOp>()) {
    APInt coefficient;
    if (mul.getInputs().size() == 2 && mul.getInputs()[0] == secondToLast &&
        matchPattern(mul.getInputs()[1], m_ConstantInt(&coefficient))) {
      APInt unit(/*numBits=*/coefficient.getBitWidth(), 1, /*isSigned=*/false);
      auto factor =
          rewriter.create<hw::ConstantOp>(op.getLoc(), coefficient + unit);

      std::array<Value, 2> factors = {mul.getInputs()[0], factor};
      auto product = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);
      return replaceLastTwoWith(product);
    }
  }

  // add(a, add(...)) -> add(a, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of add(...) -> add(extract()...)
  if (narrowOperationWidth(op, false, rewriter))
    return success();

  // add(add(x, c1), c2) -> add(x, c1 + c2)
  auto innerAdd = operands[0].getDefiningOp<comb::AddOp>();
  APInt innerCst;
  if (innerAdd && innerAdd.getInputs().size() == 2 &&
      matchPattern(innerAdd.getInputs()[1], m_ConstantInt(&innerCst)) &&
      numOperands == 2 && lastIsCst) {
    auto combined =
        rewriter.create<hw::ConstantOp>(op.getLoc(), lastCst + innerCst);
    replaceOpWithNewOpAndCopyNamehint<AddOp>(
        rewriter, op, op.getType(),
        ArrayRef<Value>{innerAdd.getInputs()[0], combined},
        /*twoState=*/op.getTwoState() && innerAdd.getTwoState());
    return success();
  }

  return failure();
}
|
|
|
|
/// Folds a `mul` operation.
///
/// A unary multiply folds to its operand, a zero-width multiply folds to the
/// zero-width zero, a product of the constant operands that is zero folds to
/// zero, and everything else is delegated to associative constant folding.
/// Nothing is folded when an operand is defined in a different block.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  // mul(x) -> x -- noop
  if (getInputs().size() == 1u)
    return getInputs()[0];

  auto resultWidth = cast<IntegerType>(getType()).getWidth();
  if (resultWidth == 0)
    return getIntAttr(APInt::getZero(0), getContext());

  auto constants = adaptor.getInputs();

  // mul(x, 0, 1) -> 0 -- annulment
  // Multiply up the constant operands; non-constant operands are null here.
  APInt product(/*numBits=*/resultWidth, 1, /*isSigned=*/false);
  for (Attribute constant : constants) {
    if (!constant)
      continue;
    product *= cast<IntegerAttr>(constant).getValue();
    if (product.isZero())
      return getIntAttr(product, getContext());
  }

  // Constant fold
  return constFoldAssociativeOp(constants, hw::PEO::Mul);
}
|
|
|
|
/// Canonicalization patterns for `mul`.
///
/// A binary multiply by a power of two becomes a shift, a trailing one is
/// dropped, two trailing constants are multiplied together, nested multiplies
/// are flattened, and finally the operation is narrowed to the bits that are
/// actually used.
///
/// Returns success if `op` was rewritten, failure otherwise (including when
/// an operand is defined in a different block).
LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(op.getOperation()))
    return failure();

  auto operands = op.getInputs();
  auto numOperands = operands.size();
  assert(numOperands > 1 && "expected 2 or more operands");

  APInt lastCst, prevCst;
  bool lastIsCst = matchPattern(operands.back(), m_ConstantInt(&lastCst));

  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
  if (numOperands == 2 && lastIsCst && lastCst.isPowerOf2()) {
    auto shiftAmount = rewriter.create<hw::ConstantOp>(
        op.getLoc(), op.getType(), lastCst.exactLogBase2());
    auto shifted = rewriter.create<comb::ShlOp>(op.getLoc(), operands[0],
                                                shiftAmount, false);

    replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
                                             ArrayRef<Value>(shifted), false);
    return success();
  }

  // mul(..., 1) -> mul(...) -- identity
  if (lastIsCst && lastCst.isOne()) {
    replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
                                             operands.drop_back());
    return success();
  }

  // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
  if (lastIsCst &&
      matchPattern(operands[numOperands - 2], m_ConstantInt(&prevCst))) {
    auto merged =
        rewriter.create<hw::ConstantOp>(op.getLoc(), lastCst * prevCst);
    SmallVector<Value, 4> newOperands(operands.drop_back(/*n=*/2));
    newOperands.push_back(merged);
    replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
                                             newOperands);
    return success();
  }

  // mul(a, mul(...)) -> mul(a, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of mul(...) -> mul(extract()...)
  bool narrowed = narrowOperationWidth(op, false, rewriter);
  return narrowed ? success() : failure();
}
|
|
|
|
template <class Op, bool isSigned>
|
|
static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
|
|
if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
|
|
// divu(x, 1) -> x, divs(x, 1) -> x
|
|
if (rhsValue.getValue() == 1)
|
|
return op.getLhs();
|
|
|
|
// If the divisor is zero, do not fold for now.
|
|
if (rhsValue.getValue().isZero())
|
|
return {};
|
|
}
|
|
|
|
return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
|
|
}
|
|
|
|
/// Folds an unsigned division via the shared division folder, unless an
/// operand is defined in a different block.
OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
  bool crossesBlock = hasOperandsOutsideOfBlock(getOperation());
  return crossesBlock ? OpFoldResult()
                      : foldDiv<DivUOp, /*isSigned=*/false>(
                            *this, adaptor.getOperands());
}
|
|
|
|
/// Folds a signed division via the shared division folder, unless an operand
/// is defined in a different block.
OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
  bool crossesBlock = hasOperandsOutsideOfBlock(getOperation());
  return crossesBlock ? OpFoldResult()
                      : foldDiv<DivSOp, /*isSigned=*/true>(
                            *this, adaptor.getOperands());
}
|
|
|
|
template <class Op, bool isSigned>
|
|
static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
|
|
if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
|
|
// modu(x, 1) -> 0, mods(x, 1) -> 0
|
|
if (rhsValue.getValue() == 1)
|
|
return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
|
|
op.getContext());
|
|
|
|
// If the divisor is zero, do not fold for now.
|
|
if (rhsValue.getValue().isZero())
|
|
return {};
|
|
}
|
|
|
|
if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
|
|
// modu(0, x) -> 0, mods(0, x) -> 0
|
|
if (lhsValue.getValue().isZero())
|
|
return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
|
|
op.getContext());
|
|
}
|
|
|
|
return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
|
|
}
|
|
|
|
/// Folds an unsigned modulo via the shared modulo folder, unless an operand
/// is defined in a different block.
OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
  bool crossesBlock = hasOperandsOutsideOfBlock(getOperation());
  return crossesBlock ? OpFoldResult()
                      : foldMod<ModUOp, /*isSigned=*/false>(
                            *this, adaptor.getOperands());
}
|
|
|
|
/// Folds a signed modulo via the shared modulo folder, unless an operand is
/// defined in a different block.
OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
  bool crossesBlock = hasOperandsOutsideOfBlock(getOperation());
  return crossesBlock ? OpFoldResult()
                      : foldMod<ModSOp, /*isSigned=*/true>(
                            *this, adaptor.getOperands());
}
|
|
//===----------------------------------------------------------------------===//
|
|
// ConcatOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Constant folding
|
|
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
|
|
if (hasOperandsOutsideOfBlock(getOperation()))
|
|
return {};
|
|
|
|
if (getNumOperands() == 1)
|
|
return getOperand(0);
|
|
|
|
// If all the operands are constant, we can fold.
|
|
for (auto attr : adaptor.getInputs())
|
|
if (!attr || !isa<IntegerAttr>(attr))
|
|
return {};
|
|
|
|
// If we got here, we can constant fold.
|
|
unsigned resultWidth = getType().getIntOrFloatBitWidth();
|
|
APInt result(resultWidth, 0);
|
|
|
|
unsigned nextInsertion = resultWidth;
|
|
// Insert each chunk into the result.
|
|
for (auto attr : adaptor.getInputs()) {
|
|
auto chunk = cast<IntegerAttr>(attr).getValue();
|
|
nextInsertion -= chunk.getBitWidth();
|
|
result.insertBits(chunk, nextInsertion);
|
|
}
|
|
|
|
return getIntAttr(result, getContext());
|
|
}
|
|
|
|
LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
|
|
if (hasOperandsOutsideOfBlock(&*op))
|
|
return failure();
|
|
|
|
auto inputs = op.getInputs();
|
|
auto size = inputs.size();
|
|
assert(size > 1 && "expected 2 or more operands");
|
|
|
|
// This function is used when we flatten neighboring operands of a
|
|
// (variadic) concat into a new vesion of the concat. first/last indices
|
|
// are inclusive.
|
|
auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
|
|
ValueRange replacements) -> LogicalResult {
|
|
SmallVector<Value, 4> newOperands;
|
|
newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
|
|
newOperands.append(replacements.begin(), replacements.end());
|
|
newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
|
|
if (newOperands.size() == 1)
|
|
replaceOpAndCopyNamehint(rewriter, op, newOperands[0]);
|
|
else
|
|
replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
|
|
newOperands);
|
|
return success();
|
|
};
|
|
|
|
Value commonOperand = inputs[0];
|
|
for (size_t i = 0; i != size; ++i) {
|
|
// Check to see if all operands are the same.
|
|
if (inputs[i] != commonOperand)
|
|
commonOperand = Value();
|
|
|
|
// If an operand to the concat is itself a concat, then we can fold them
|
|
// together.
|
|
if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
|
|
return flattenConcat(i, i, subConcat->getOperands());
|
|
|
|
// Check for canonicalization due to neighboring operands.
|
|
if (i != 0) {
|
|
// Merge neighboring constants.
|
|
if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
|
|
if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
|
|
unsigned prevWidth = prevCst.getValue().getBitWidth();
|
|
unsigned thisWidth = cst.getValue().getBitWidth();
|
|
auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
|
|
resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
|
|
<< thisWidth;
|
|
Value replacement =
|
|
rewriter.create<hw::ConstantOp>(op.getLoc(), resultCst);
|
|
return flattenConcat(i - 1, i, replacement);
|
|
}
|
|
}
|
|
|
|
// If the two operands are the same, turn them into a replicate.
|
|
if (inputs[i] == inputs[i - 1]) {
|
|
Value replacement =
|
|
rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
|
|
return flattenConcat(i - 1, i, replacement);
|
|
}
|
|
|
|
// If this input is a replicate, see if we can fold it with the previous
|
|
// one.
|
|
if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
|
|
// ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
|
|
if (repl.getOperand() == inputs[i - 1]) {
|
|
Value replacement = rewriter.createOrFold<ReplicateOp>(
|
|
op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
|
|
return flattenConcat(i - 1, i, replacement);
|
|
}
|
|
// ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
|
|
if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
|
|
if (prevRepl.getOperand() == repl.getOperand()) {
|
|
Value replacement = rewriter.createOrFold<ReplicateOp>(
|
|
op.getLoc(), repl.getOperand(),
|
|
repl.getMultiple() + prevRepl.getMultiple());
|
|
return flattenConcat(i - 1, i, replacement);
|
|
}
|
|
}
|
|
}
|
|
|
|
// ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
|
|
if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
|
|
if (repl.getOperand() == inputs[i]) {
|
|
Value replacement = rewriter.createOrFold<ReplicateOp>(
|
|
op.getLoc(), inputs[i], repl.getMultiple() + 1);
|
|
return flattenConcat(i - 1, i, replacement);
|
|
}
|
|
}
|
|
|
|
// Merge neighboring extracts of neighboring inputs, e.g.
|
|
// {A[3], A[2]} -> A[3:2]
|
|
if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
|
|
if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
|
|
if (extract.getInput() == prevExtract.getInput()) {
|
|
auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
|
|
if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
|
|
auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
|
|
auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
|
|
Value replacement = rewriter.create<ExtractOp>(
|
|
op.getLoc(), resType, extract.getInput(),
|
|
extract.getLowBit());
|
|
return flattenConcat(i - 1, i, replacement);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Merge neighboring array extracts of neighboring inputs, e.g.
|
|
// {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])
|
|
|
|
// This represents a slice of an array.
|
|
struct ArraySlice {
|
|
Value input;
|
|
Value index;
|
|
size_t width;
|
|
static std::optional<ArraySlice> get(Value value) {
|
|
assert(isa<IntegerType>(value.getType()) && "expected integer type");
|
|
if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
|
|
return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
|
|
// array slice op is wrapped with bitcast.
|
|
if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
|
|
if (auto arraySlice =
|
|
bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
|
|
return ArraySlice{
|
|
arraySlice.getInput(), arraySlice.getLowIndex(),
|
|
hw::type_cast<hw::ArrayType>(arraySlice.getType())
|
|
.getNumElements()};
|
|
return std::nullopt;
|
|
}
|
|
};
|
|
if (auto extractOpt = ArraySlice::get(inputs[i])) {
|
|
if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
|
|
// Check that two array slices are mergable.
|
|
if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
|
|
prevExtractOpt->input == extractOpt->input &&
|
|
hw::isOffset(extractOpt->index, prevExtractOpt->index,
|
|
extractOpt->width)) {
|
|
auto resType = hw::ArrayType::get(
|
|
hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
|
|
.getElementType(),
|
|
extractOpt->width + prevExtractOpt->width);
|
|
auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
|
|
Value replacement = rewriter.create<hw::BitcastOp>(
|
|
op.getLoc(), resIntType,
|
|
rewriter.create<hw::ArraySliceOp>(op.getLoc(), resType,
|
|
prevExtractOpt->input,
|
|
extractOpt->index));
|
|
return flattenConcat(i - 1, i, replacement);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// If all operands were the same, then this is a replicate.
|
|
if (commonOperand) {
|
|
replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
|
|
commonOperand);
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MuxOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Fold a mux whose result is already available as an existing value or a
/// constant attribute.
///
/// Returns the replacement value/attribute, or a null OpFoldResult when no
/// fold applies. A fold never returns the mux's own result: MLIR interprets
/// "fold returned the op's own result" as an in-place update, so a
/// self-referential mux (possible in graph regions) must be left alone instead
/// of being reported as changed on every invocation.
OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  // mux (c, b, b) -> b
  if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
    return getTrueValue();
  if (auto tv = adaptor.getTrueValue())
    if (tv == adaptor.getFalseValue())
      return tv;

  // mux(0, a, b) -> b
  // mux(1, a, b) -> a
  if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
    Value selected = getTrueValue();
    if (pred.getValue().isZero())
      selected = getFalseValue();
    // Same self-reference guard as the `mux(c, b, b)` case above.
    if (selected != getResult())
      return selected;
  }

  // mux(cond, 1, 0) -> cond
  if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
    if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
      if (tv.getValue().isOne() && fv.getValue().isZero() &&
          hw::getBitWidth(getType()) == 1)
        return getCond();

  return {};
}
|
|
|
|
/// Check to see if the condition to the specified mux is an equality
|
|
/// comparison `indexValue` and one or more constants. If so, put the
|
|
/// constants in the constants vector and return true, otherwise return false.
|
|
///
|
|
/// This is part of foldMuxChain.
|
|
///
|
|
static bool
|
|
getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
|
|
std::function<void(hw::ConstantOp)> constantFn) {
|
|
// Handle `idx == 42` and `idx != 42`.
|
|
if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
|
|
// TODO: We could handle things like "x < 2" as two entries.
|
|
auto requiredPredicate =
|
|
(isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
|
|
if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
|
|
if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
|
|
constantFn(cst);
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Handle mux(`idx == 1 || idx == 3`, value, muxchain).
|
|
if (auto orOp = cond.getDefiningOp<OrOp>()) {
|
|
if (!isInverted)
|
|
return false;
|
|
for (auto operand : orOp.getOperands())
|
|
if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
|
|
return false;
|
|
return true;
|
|
}
|
|
|
|
// Handle mux(`idx != 1 && idx != 3`, muxchain, value).
|
|
if (auto andOp = cond.getDefiningOp<AndOp>()) {
|
|
if (isInverted)
|
|
return false;
|
|
for (auto operand : andOp.getOperands())
|
|
if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
|
|
return false;
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Given a mux, check to see if the "on true" value (or "on false" value if
|
|
/// isFalseSide=true) is a mux tree with the same condition. This allows us
|
|
/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
|
|
/// `array_get (array_create(A, B, C), VAL)` which is far more compact and
|
|
/// allows synthesis tools to do more interesting optimizations.
|
|
///
|
|
/// This returns false if we cannot form the mux tree (or do not want to) and
|
|
/// returns true if the mux was replaced.
|
|
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
|
|
PatternRewriter &rewriter) {
|
|
// Get the index value being compared. Later we check to see if it is
|
|
// compared to a constant with the right predicate.
|
|
auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
|
|
if (!rootCmp)
|
|
return false;
|
|
Value indexValue = rootCmp.getLhs();
|
|
|
|
// Return the value to use if the equality match succeeds.
|
|
auto getCaseValue = [&](MuxOp mux) -> Value {
|
|
return mux.getOperand(1 + unsigned(!isFalseSide));
|
|
};
|
|
|
|
// Return the value to use if the equality match fails. This is the next
|
|
// mux in the sequence or the "otherwise" value.
|
|
auto getTreeValue = [&](MuxOp mux) -> Value {
|
|
return mux.getOperand(1 + unsigned(isFalseSide));
|
|
};
|
|
|
|
// Start scanning the mux tree to see what we've got. Keep track of the
|
|
// constant comparison value and the SSA value to use when equal to it.
|
|
SmallVector<Location> locationsFound;
|
|
SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
|
|
|
|
/// Extract constants and values into `valuesFound` and return true if this is
|
|
/// part of the mux tree, otherwise return false.
|
|
auto collectConstantValues = [&](MuxOp mux) -> bool {
|
|
return getMuxChainCondConstant(
|
|
mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
|
|
valuesFound.push_back({cst, getCaseValue(mux)});
|
|
locationsFound.push_back(mux.getCond().getLoc());
|
|
locationsFound.push_back(mux->getLoc());
|
|
});
|
|
};
|
|
|
|
// Make sure the root is a correct comparison with a constant.
|
|
if (!collectConstantValues(rootMux))
|
|
return false;
|
|
|
|
// Make sure that we're not looking at the intermediate node in a mux tree.
|
|
if (rootMux->hasOneUse()) {
|
|
if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
|
|
if (getTreeValue(userMux) == rootMux.getResult() &&
|
|
getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
|
|
[&](hw::ConstantOp cst) {}))
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Scan up the tree linearly.
|
|
auto nextTreeValue = getTreeValue(rootMux);
|
|
while (1) {
|
|
auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
|
|
if (!nextMux || !nextMux->hasOneUse())
|
|
break;
|
|
if (!collectConstantValues(nextMux))
|
|
break;
|
|
nextTreeValue = getTreeValue(nextMux);
|
|
}
|
|
|
|
// We need to have more than three values to create an array. This is an
|
|
// arbitrary threshold which is saying that one or two muxes together is ok,
|
|
// but three should be folded.
|
|
if (valuesFound.size() < 3)
|
|
return false;
|
|
|
|
// If the array is greater that 9 bits, it will take over 512 elements and
|
|
// it will be too large for a single expression.
|
|
auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
|
|
if (indexWidth >= 9)
|
|
return false;
|
|
|
|
// Next we need to see if the values are dense-ish. We don't want to have
|
|
// a tremendous number of replicated entries in the array. Some sparsity is
|
|
// ok though, so we require the table to be at least 5/8 utilized.
|
|
uint64_t tableSize = 1ULL << indexWidth;
|
|
if (valuesFound.size() < (tableSize * 5) / 8)
|
|
return false; // Not dense enough.
|
|
|
|
// Ok, we're going to do the transformation, start by building the table
|
|
// filled with the "otherwise" value.
|
|
SmallVector<Value, 8> table(tableSize, nextTreeValue);
|
|
|
|
// Fill in entries in the table from the leaf to the root of the expression.
|
|
// This ensures that any duplicate matches end up with the ultimate value,
|
|
// which is the one closer to the root.
|
|
for (auto &elt : llvm::reverse(valuesFound)) {
|
|
uint64_t idx = elt.first.getValue().getZExtValue();
|
|
assert(idx < table.size() && "constant should be same bitwidth as index");
|
|
table[idx] = elt.second;
|
|
}
|
|
|
|
// The hw.array_create operation has the operand list in unintuitive order
|
|
// with a[0] stored as the last element, not the first.
|
|
std::reverse(table.begin(), table.end());
|
|
|
|
// Build the array_create and the array_get.
|
|
auto fusedLoc = rewriter.getFusedLoc(locationsFound);
|
|
auto array = rewriter.create<hw::ArrayCreateOp>(fusedLoc, table);
|
|
replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
|
|
indexValue);
|
|
return true;
|
|
}
|
|
|
|
/// Given a fully associative variadic operation like (a+b+c+d), break the
|
|
/// expression into two parts, one without the specified operand (e.g.
|
|
/// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
|
|
/// `tmp+c`), and return the inner expression.
|
|
///
|
|
/// NOTE: This mutates the operation in place if it only has a single user,
|
|
/// which assumes that user will be removed.
|
|
///
|
|
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
|
|
size_t operandNo,
|
|
PatternRewriter &rewriter) {
|
|
assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
|
|
assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
|
|
|
|
// If this expression already has two operands (the common case) no splitting
|
|
// is necessary.
|
|
if (fullyAssoc->getNumOperands() == 2)
|
|
return fullyAssoc->getOperand(operandNo ^ 1);
|
|
|
|
// If the operation has a single use, mutate it in place.
|
|
if (fullyAssoc->hasOneUse()) {
|
|
rewriter.modifyOpInPlace(fullyAssoc,
|
|
[&]() { fullyAssoc->eraseOperand(operandNo); });
|
|
return fullyAssoc->getResult(0);
|
|
}
|
|
|
|
// Form the new operation with the operands that remain.
|
|
SmallVector<Value> operands;
|
|
operands.append(fullyAssoc->getOperands().begin(),
|
|
fullyAssoc->getOperands().begin() + operandNo);
|
|
operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
|
|
fullyAssoc->getOperands().end());
|
|
Value opWithoutExcluded = createGenericOp(
|
|
fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
|
|
Value excluded = fullyAssoc->getOperand(operandNo);
|
|
|
|
Value fullResult =
|
|
createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
|
|
ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
|
|
replaceOpAndCopyNamehint(rewriter, fullyAssoc, fullResult);
|
|
return opWithoutExcluded;
|
|
}
|
|
|
|
/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
|
|
/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
|
|
/// is true. Return true on successful transformation, false if not.
|
|
///
|
|
/// These are various forms of "predicated ops" that can be handled with a
|
|
/// replicate/and combination.
|
|
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
|
|
PatternRewriter &rewriter) {
|
|
// Check to see the operand in question is an operation. If it is a port,
|
|
// we can't simplify it.
|
|
Operation *subExpr =
|
|
(isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
|
|
if (!subExpr || subExpr->getNumOperands() < 2)
|
|
return false;
|
|
|
|
// If this isn't an operation we can handle, don't spend energy on it.
|
|
if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
|
|
return false;
|
|
|
|
// Check to see if the common value occurs in the operand list for the
|
|
// subexpression op. If so, then we can simplify it.
|
|
Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
|
|
size_t opNo = 0, e = subExpr->getNumOperands();
|
|
while (opNo != e && subExpr->getOperand(opNo) != commonValue)
|
|
++opNo;
|
|
if (opNo == e)
|
|
return false;
|
|
|
|
// If we got a hit, then go ahead and simplify it!
|
|
Value cond = op.getCond();
|
|
|
|
// `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
|
|
// `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
|
|
// `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
|
|
// `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
|
|
if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
|
|
if (subMux == op)
|
|
return false;
|
|
|
|
Value otherValue;
|
|
Value subCond = subMux.getCond();
|
|
|
|
// Invert th subCond if needed and dig out the 'b' value.
|
|
if (subMux.getTrueValue() == commonValue)
|
|
otherValue = subMux.getFalseValue();
|
|
else if (subMux.getFalseValue() == commonValue) {
|
|
otherValue = subMux.getTrueValue();
|
|
subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
|
|
} else {
|
|
// We can't fold `mux(cond, a, mux(a, x, y))`.
|
|
return false;
|
|
}
|
|
|
|
// Invert the outer cond if needed, and combine the mux conditions.
|
|
if (!isTrueOperand)
|
|
cond = createOrFoldNot(op.getLoc(), cond, rewriter);
|
|
cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
|
|
replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
|
|
otherValue, op.getTwoState());
|
|
return true;
|
|
}
|
|
|
|
// Invert the condition if needed. Or/Xor invert when dealing with
|
|
// TrueOperand, And inverts for False operand.
|
|
bool isaAndOp = isa<AndOp>(subExpr);
|
|
if (isTrueOperand ^ isaAndOp)
|
|
cond = createOrFoldNot(op.getLoc(), cond, rewriter);
|
|
|
|
auto extendedCond =
|
|
rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
|
|
|
|
// Cache this information before subExpr is erased by extraction below.
|
|
bool isaXorOp = isa<XorOp>(subExpr);
|
|
bool isaOrOp = isa<OrOp>(subExpr);
|
|
|
|
// Handle the fully associative ops, start by pulling out the subexpression
|
|
// from a many operand version of the op.
|
|
auto restOfAssoc =
|
|
extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
|
|
|
|
// `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
|
|
// `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
|
|
if (isaOrOp || isaXorOp) {
|
|
auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
|
|
restOfAssoc, false);
|
|
if (isaXorOp)
|
|
replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
|
|
commonValue, false);
|
|
else
|
|
replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
|
|
false);
|
|
return true;
|
|
}
|
|
|
|
// `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
|
|
assert(isaAndOp && "unexpected operation here");
|
|
auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
|
|
restOfAssoc, false);
|
|
replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
|
|
false);
|
|
return true;
|
|
}
|
|
|
|
/// This function is invoke when we find a mux with true/false operations that
|
|
/// have the same opcode. Check to see if we can strength reduce the mux by
|
|
/// applying it to less data by applying this transformation:
|
|
/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
|
|
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
|
|
Operation *falseOp,
|
|
PatternRewriter &rewriter) {
|
|
// Right now we only apply to concat.
|
|
// TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
|
|
if (!isa<ConcatOp>(trueOp))
|
|
return false;
|
|
|
|
// Decode the operands, looking through recursive concats and replicates.
|
|
SmallVector<Value> trueOperands, falseOperands;
|
|
getConcatOperands(trueOp->getResult(0), trueOperands);
|
|
getConcatOperands(falseOp->getResult(0), falseOperands);
|
|
|
|
size_t numTrueOperands = trueOperands.size();
|
|
size_t numFalseOperands = falseOperands.size();
|
|
|
|
if (!numTrueOperands || !numFalseOperands ||
|
|
(trueOperands.front() != falseOperands.front() &&
|
|
trueOperands.back() != falseOperands.back()))
|
|
return false;
|
|
|
|
// Pull all leading shared operands out into their own op if any are common.
|
|
if (trueOperands.front() == falseOperands.front()) {
|
|
SmallVector<Value> operands;
|
|
size_t i;
|
|
for (i = 0; i < numTrueOperands; ++i) {
|
|
Value trueOperand = trueOperands[i];
|
|
if (trueOperand == falseOperands[i])
|
|
operands.push_back(trueOperand);
|
|
else
|
|
break;
|
|
}
|
|
if (i == numTrueOperands) {
|
|
// Selecting between distinct, but lexically identical, concats.
|
|
replaceOpAndCopyNamehint(rewriter, mux, trueOp->getResult(0));
|
|
return true;
|
|
}
|
|
|
|
Value sharedMSB;
|
|
if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
|
|
sharedMSB = rewriter.createOrFold<ReplicateOp>(
|
|
mux->getLoc(), operands.front(), operands.size());
|
|
else
|
|
sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
|
|
operands.clear();
|
|
|
|
// Get a concat of the LSB's on each side.
|
|
operands.append(trueOperands.begin() + i, trueOperands.end());
|
|
Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
|
|
operands.clear();
|
|
operands.append(falseOperands.begin() + i, falseOperands.end());
|
|
Value falseLSB =
|
|
rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
|
|
// Merge the LSBs with a new mux and concat the MSB with the LSB to be
|
|
// done.
|
|
Value lsb = rewriter.createOrFold<MuxOp>(
|
|
mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
|
|
replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
|
|
return true;
|
|
}
|
|
|
|
// If trailing operands match, try to commonize them.
|
|
if (trueOperands.back() == falseOperands.back()) {
|
|
SmallVector<Value> operands;
|
|
size_t i;
|
|
for (i = 0;; ++i) {
|
|
Value trueOperand = trueOperands[numTrueOperands - i - 1];
|
|
if (trueOperand == falseOperands[numFalseOperands - i - 1])
|
|
operands.push_back(trueOperand);
|
|
else
|
|
break;
|
|
}
|
|
std::reverse(operands.begin(), operands.end());
|
|
Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
|
|
operands.clear();
|
|
|
|
// Get a concat of the MSB's on each side.
|
|
operands.append(trueOperands.begin(), trueOperands.end() - i);
|
|
Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
|
|
operands.clear();
|
|
operands.append(falseOperands.begin(), falseOperands.end() - i);
|
|
Value falseMSB =
|
|
rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
|
|
// Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
|
|
Value msb = rewriter.createOrFold<MuxOp>(
|
|
mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
|
|
replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// If both arguments of the mux are arrays with the same elements, sink the
|
|
// mux and return a uniform array initializing all elements to it.
|
|
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
|
|
auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
|
|
auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
|
|
if (!trueVec || !falseVec)
|
|
return false;
|
|
if (!trueVec.isUniform() || !falseVec.isUniform())
|
|
return false;
|
|
|
|
auto mux = rewriter.create<MuxOp>(
|
|
op.getLoc(), op.getCond(), trueVec.getUniformElement(),
|
|
falseVec.getUniformElement(), op.getTwoState());
|
|
|
|
SmallVector<Value> values(trueVec.getInputs().size(), mux);
|
|
rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
|
|
return true;
|
|
}
|
|
|
|
namespace {
|
|
/// Rewrite pattern holding the canonicalizations for comb.mux; the rewrite
/// logic itself lives in the out-of-line matchAndRewrite definition.
struct MuxRewriter : mlir::OpRewritePattern<MuxOp> {
  using mlir::OpRewritePattern<MuxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MuxOp op,
                                PatternRewriter &rewriter) const override;
};
|
|
|
|
/// Canonicalize a comb.mux. The rewrites are tried in a fixed priority order
/// and the first one that applies wins:
///   1. muxes with constant true/false values (single-bit and/or forms,
///      constants differing in one bit, all-ones/zero replicate);
///   2. muxes with an inverted condition (direct or via De Morgan);
///   3. nested muxes sharing conditions or values, and mux chains that can
///      become an array lookup;
///   4. common-value, common-operation, narrowing and uniform-array folds.
/// Returns success() if `op` was rewritten, failure() otherwise.
LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
                                           PatternRewriter &rewriter) const {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  // If the op has a SV attribute, don't optimize it.
  if (hasSVAttributes(op))
    return failure();

  //===--- Constant true value -------------------------------------------===//
  APInt trueConst;
  if (matchPattern(op.getTrueValue(), m_ConstantInt(&trueConst))) {
    if (trueConst.getBitWidth() == 1) {
      if (!trueConst.isZero()) {
        // mux(a, 1, b) -> or(a, b) for single-bit values.
        replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
                                                op.getFalseValue(), false);
        return success();
      }

      // mux(a, 0, b) -> and(~a, b) for single-bit values.
      auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
      replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
                                               op.getFalseValue(), false);
      return success();
    }

    // Mux of two (multi-bit) constants.
    APInt falseConst;
    if (matchPattern(op.getFalseValue(), m_ConstantInt(&falseConst))) {
      // When the constants differ in exactly one bit, split the mux into up
      // to three contiguous chunks: one for the differing bit and up to two
      // for the identical bits.
      // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
      APInt differingBits = trueConst ^ falseConst;
      if (differingBits.isPowerOf2()) {
        unsigned numHighBits = differingBits.countLeadingZeros();
        unsigned numLowBits = trueConst.getBitWidth() - numHighBits - 1;
        SmallVector<Value, 3> chunks;

        // Concat operands go from MSB to LSB, so chunks are emitted from the
        // highest bits down. Identical chunks can be extracted from either
        // input; the true value is picked arbitrarily.
        if (numHighBits > 0)
          chunks.push_back(rewriter.createOrFold<ExtractOp>(
              op.getLoc(), op.getTrueValue(), numLowBits + 1, numHighBits));

        // The differing bit, which should simplify into cond or ~cond.
        auto trueBit = rewriter.createOrFold<ExtractOp>(
            op.getLoc(), op.getTrueValue(), numLowBits, 1);
        auto falseBit = rewriter.createOrFold<ExtractOp>(
            op.getLoc(), op.getFalseValue(), numLowBits, 1);
        chunks.push_back(rewriter.createOrFold<MuxOp>(
            op.getLoc(), op.getCond(), trueBit, falseBit, false));

        if (numLowBits > 0)
          chunks.push_back(rewriter.createOrFold<ExtractOp>(
              op.getLoc(), op.getTrueValue(), 0, numLowBits));

        replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
                                                    chunks);
        return success();
      }

      // mux(cond, all-ones, 0) is just the condition replicated.
      if (trueConst.isAllOnes() && falseConst.isZero()) {
        replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
            rewriter, op, op.getType(), op.getCond());
        return success();
      }
    }
  }

  //===--- Single-bit constant false value -------------------------------===//
  APInt falseBitConst;
  if (matchPattern(op.getFalseValue(), m_ConstantInt(&falseBitConst)) &&
      falseBitConst.getBitWidth() == 1) {
    // mux(a, b, 0) -> and(a, b) for single-bit values.
    if (falseBitConst.isZero()) {
      replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
                                               op.getTrueValue(), false);
      return success();
    }

    // mux(a, b, 1) -> or(~a, b) for single-bit values.
    // The false value is known to be a single-bit 1, so it doubles as the 1
    // in the xor-based representation of ~.
    auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
                                                op.getFalseValue(), false);
    replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
                                            op.getTrueValue(), false);
    return success();
  }

  //===--- Inverted conditions -------------------------------------------===//
  // mux(!a, b, c) -> mux(a, c, b)
  Operation *condOp = op.getCond().getDefiningOp();
  Value uninvertedCond;
  if (condOp && matchPattern(condOp, m_Complement(m_Any(&uninvertedCond))) &&
      op.getTwoState()) {
    replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
                                             uninvertedCond, op.getFalseValue(),
                                             op.getTrueValue(), true);
    return success();
  }

  // Same but with Demorgan's law.
  // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
  // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
  if (condOp && condOp->hasOneUse()) {
    const bool condIsAnd = isa<AndOp>(condOp);
    const bool condIsOr = isa<OrOp>(condOp);
    if (condIsAnd || condIsOr) {
      // Collect the un-complemented operands; give up at the first operand
      // that is not a complement.
      SmallVector<Value> uninvertedOperands;
      bool allComplemented =
          llvm::all_of(condOp->getOperands(), [&](Value operand) {
            Value inner;
            if (!matchPattern(operand, m_Complement(m_Any(&inner))))
              return false;
            uninvertedOperands.push_back(inner);
            return true;
          });

      if (allComplemented) {
        Value newCond;
        if (condIsAnd)
          newCond = rewriter.createOrFold<OrOp>(op.getLoc(),
                                                uninvertedOperands, false);
        else
          newCond = rewriter.createOrFold<AndOp>(op.getLoc(),
                                                 uninvertedOperands, false);
        replaceOpWithNewOpAndCopyNamehint<MuxOp>(
            rewriter, op, newCond, op.getFalseValue(), op.getTrueValue(),
            op.getTwoState());
        return success();
      }
    }
  }

  //===--- Nested muxes --------------------------------------------------===//
  auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
  auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
  const bool hasFalseMux = falseMux && falseMux != op;
  const bool hasTrueMux = trueMux && trueMux != op;

  if (hasFalseMux) {
    // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
    if (op.getCond() == falseMux.getCond()) {
      replaceOpWithNewOpAndCopyNamehint<MuxOp>(
          rewriter, op, op.getCond(), op.getTrueValue(),
          falseMux.getFalseValue(), op.getTwoStateAttr());
      return success();
    }

    // Check to see if we can fold a mux tree into an array_create/get pair.
    if (foldMuxChain(op, /*isFalseSide=*/true, rewriter))
      return success();
  }

  if (hasTrueMux) {
    // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
    if (op.getCond() == trueMux.getCond()) {
      replaceOpWithNewOpAndCopyNamehint<MuxOp>(
          rewriter, op, op.getCond(), trueMux.getTrueValue(),
          op.getFalseValue(), op.getTwoStateAttr());
      return success();
    }

    // Check to see if we can fold a mux tree into an array_create/get pair.
    if (foldMuxChain(op, /*isFalseSide=*/false, rewriter))
      return success();
  }

  // Both sides are (other) muxes: look for shared conditions/values.
  if (hasTrueMux && hasFalseMux) {
    const bool sameCond = trueMux.getCond() == falseMux.getCond();
    const bool sameTrue = trueMux.getTrueValue() == falseMux.getTrueValue();
    const bool sameFalse = trueMux.getFalseValue() == falseMux.getFalseValue();

    // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
    if (sameCond && sameTrue) {
      auto subMux = rewriter.create<MuxOp>(
          rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
          op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
      replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
                                               trueMux.getTrueValue(), subMux,
                                               op.getTwoStateAttr());
      return success();
    }

    // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
    if (sameCond && sameFalse) {
      auto subMux = rewriter.create<MuxOp>(
          rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
          op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
      replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
                                               subMux, trueMux.getFalseValue(),
                                               op.getTwoStateAttr());
      return success();
    }

    // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
    if (sameTrue && sameFalse) {
      auto subMux = rewriter.create<MuxOp>(
          rewriter.getFusedLoc(
              {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
          op.getCond(), trueMux.getCond(), falseMux.getCond());
      replaceOpWithNewOpAndCopyNamehint<MuxOp>(
          rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
          op.getTwoStateAttr());
      return success();
    }
  }

  //===--- Remaining folds -----------------------------------------------===//
  // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
  if (foldCommonMuxValue(op, false, rewriter))
    return success();
  // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
  if (foldCommonMuxValue(op, true, rewriter))
    return success();

  // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
  Operation *trueOp = op.getTrueValue().getDefiningOp();
  Operation *falseOp = trueOp ? op.getFalseValue().getDefiningOp() : nullptr;
  if (trueOp && falseOp && trueOp->getName() == falseOp->getName() &&
      foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
    return success();

  // extracts only of mux(...) -> mux(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
  if (foldMuxOfUniformArrays(op, rewriter))
    return success();

  return failure();
}
|
|
|
|
/// Hoist a shared mux out of the elements of an array:
///
///   array_create(mux(c, t0, f0), mux(c, t1, f1), ...)
///     -> mux(c, array_create(t0, t1, ...), array_create(f0, f1, ...))
///
/// The fold only applies when every element is produced by a `comb.mux` on the
/// same condition and none of those muxes carries SV attributes. The resulting
/// mux is marked two-state only if *every* input mux is two-state.
///
/// \param op        The array_create to rewrite.
/// \param rewriter  Rewriter used to create the new ops and replace `op`.
/// \returns true if `op` was replaced, false if the pattern did not match (in
///          which case the IR is left untouched).
static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
  // Do not fold uniform or singleton arrays to avoid duplicating muxes. The
  // size check comes first so `isUniform` is never asked about an empty array.
  auto inputs = op.getInputs();
  if (inputs.size() <= 1 || op.isUniform())
    return false;

  // The first element fixes the condition every other element must share.
  auto first = inputs[0].getDefiningOp<comb::MuxOp>();
  if (!first)
    return false;
  Value cond = first.getCond();

  // Validate every element and collect its branches in a single pass. Nothing
  // is created before the loop finishes, so bailing out leaves the IR intact.
  SmallVector<Value> trues, falses;
  SmallVector<Location> locs;
  trues.reserve(inputs.size());
  falses.reserve(inputs.size());
  locs.reserve(inputs.size());
  bool isTwoState = true;
  for (Value input : inputs) {
    auto mux = input.getDefiningOp<comb::MuxOp>();
    // The SV-attribute guard is applied to all muxes, not just the first one:
    // any of them may become dead after the rewrite.
    if (!mux || mux.getCond() != cond || hasSVAttributes(mux))
      return false;
    trues.push_back(mux.getTrueValue());
    falses.push_back(mux.getFalseValue());
    locs.push_back(mux->getLoc());
    // Includes the first mux: a single non-two-state input must clear the
    // flag on the merged mux.
    isTwoState = isTwoState && mux.getTwoState();
  }

  // Define the location of the new ops as the aggregate of all muxes.
  auto loc = FusedLoc::get(op.getContext(), locs);

  // Build one array per branch and select between them with a single mux.
  auto arrayTy = op.getType();
  auto trueValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, trues);
  auto falseValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, falses);
  rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, cond, trueValues,
                                           falseValues, isTwoState);
  return true;
}
|
|
|
|
/// Canonicalization pattern rooted at `hw.array_create`. It delegates to
/// foldArrayOfMuxes and reports success exactly when that fold rewrote the op.
struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
                                PatternRewriter &rewriter) const override {
    // Cross-block dataflow is treated as a canonicalization barrier; the fold
    // is not even attempted in that case.
    if (hasOperandsOutsideOfBlock(&*op))
      return failure();

    const bool folded = foldArrayOfMuxes(op, rewriter);
    return folded ? success() : failure();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Populate `results` with the canonicalization patterns registered for
/// `comb.mux`: the mux rewriter itself and the array-of-muxes rewriter.
void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // Registered one at a time, in the same order as before.
  results.insert<MuxRewriter>(context);
  results.insert<ArrayRewriter>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ICmpOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Calculate the result of a comparison when the LHS and RHS are both
|
|
// constants.
|
|
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
|
|
const APInt &rhs) {
|
|
switch (predicate) {
|
|
case ICmpPredicate::eq:
|
|
return lhs.eq(rhs);
|
|
case ICmpPredicate::ne:
|
|
return lhs.ne(rhs);
|
|
case ICmpPredicate::slt:
|
|
return lhs.slt(rhs);
|
|
case ICmpPredicate::sle:
|
|
return lhs.sle(rhs);
|
|
case ICmpPredicate::sgt:
|
|
return lhs.sgt(rhs);
|
|
case ICmpPredicate::sge:
|
|
return lhs.sge(rhs);
|
|
case ICmpPredicate::ult:
|
|
return lhs.ult(rhs);
|
|
case ICmpPredicate::ule:
|
|
return lhs.ule(rhs);
|
|
case ICmpPredicate::ugt:
|
|
return lhs.ugt(rhs);
|
|
case ICmpPredicate::uge:
|
|
return lhs.uge(rhs);
|
|
case ICmpPredicate::ceq:
|
|
return lhs.eq(rhs);
|
|
case ICmpPredicate::cne:
|
|
return lhs.ne(rhs);
|
|
case ICmpPredicate::weq:
|
|
return lhs.eq(rhs);
|
|
case ICmpPredicate::wne:
|
|
return lhs.ne(rhs);
|
|
}
|
|
llvm_unreachable("unknown comparison predicate");
|
|
}
|
|
|
|
// Returns the result of applying the predicate when the LHS and RHS are the
|
|
// exact same value.
|
|
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
|
|
switch (predicate) {
|
|
case ICmpPredicate::eq:
|
|
case ICmpPredicate::sle:
|
|
case ICmpPredicate::sge:
|
|
case ICmpPredicate::ule:
|
|
case ICmpPredicate::uge:
|
|
case ICmpPredicate::ceq:
|
|
case ICmpPredicate::weq:
|
|
return true;
|
|
case ICmpPredicate::ne:
|
|
case ICmpPredicate::slt:
|
|
case ICmpPredicate::sgt:
|
|
case ICmpPredicate::ult:
|
|
case ICmpPredicate::ugt:
|
|
case ICmpPredicate::cne:
|
|
case ICmpPredicate::wne:
|
|
return false;
|
|
}
|
|
llvm_unreachable("unknown comparison predicate");
|
|
}
|
|
|
|
/// Constant-fold an icmp. Two cases are handled: both operands are the same
/// SSA value, or both operands are integer constants. Returns an empty result
/// when neither applies.
OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
  // Cross-block dataflow is a canonicalization barrier.
  if (hasOperandsOutsideOfBlock(getOperation()))
    return {};

  // Same value on both sides, e.g. `gt a, a -> false`, `gte a, a -> true`.
  if (getLhs() == getRhs())
    return IntegerAttr::get(getType(),
                            applyCmpPredicateToEqualOperands(getPredicate()));

  // Two constants, e.g. `gt 1, 2 -> false`.
  auto lhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
  if (!lhsCst)
    return {};
  auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
  if (!rhsCst)
    return {};

  bool folded =
      applyCmpPredicate(getPredicate(), lhsCst.getValue(), rhsCst.getValue());
  return IntegerAttr::get(getType(), folded);
}
|
|
|
|
// Count how many leading elements of `a` and `b` compare equal position by
// position, stopping at the first mismatch or at the end of the shorter range.
// This does not perform cross-element matching. Pass reversed ranges to obtain
// the common suffix length instead.
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  auto itA = a.begin();
  auto itB = b.begin();
  const auto endA = a.end();
  const auto endB = b.end();

  size_t matched = 0;
  while (itA != endA && itB != endB) {
    if (*itA != *itB)
      return matched;
    ++itA;
    ++itB;
    ++matched;
  }
  return matched;
}
|
|
|
|
/// Return the sum of the bit widths of all `operands` (0 for an empty list).
///
/// The original comment states the callers only pass concat operands, which
/// are integers, so querying the bit width is expected to be valid for every
/// element.
static size_t getTotalWidth(ArrayRef<Value> operands) {
  size_t totalWidth = 0;
  // The width is accumulated directly as an unsigned quantity; there is no
  // need to round-trip it through a signed type and assert non-negativity.
  for (Value operand : operands)
    totalWidth += operand.getType().getIntOrFloatBitWidth();
  return totalWidth;
}
|
|
|
|
/// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise
|
|
/// comparison on common prefix and suffixes. Returns success() if a rewriting
|
|
/// happens. This handles both concat and replicate.
|
|
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
|
|
Operation *rhs,
|
|
PatternRewriter &rewriter) {
|
|
// It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
|
|
// all elements have non-zero length. Both these invariants are verified
|
|
// by the ConcatOp verifier.
|
|
SmallVector<Value> lhsOperands, rhsOperands;
|
|
getConcatOperands(lhs->getResult(0), lhsOperands);
|
|
getConcatOperands(rhs->getResult(0), rhsOperands);
|
|
ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
|
|
|
|
auto formCatOrReplicate = [&](Location loc,
|
|
ArrayRef<Value> operands) -> Value {
|
|
assert(!operands.empty());
|
|
Value sameElement = operands[0];
|
|
for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
|
|
if (sameElement != operands[i])
|
|
sameElement = Value();
|
|
if (sameElement)
|
|
return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
|
|
operands.size());
|
|
return rewriter.createOrFold<ConcatOp>(loc, operands);
|
|
};
|
|
|
|
auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
|
|
Value rhs) -> LogicalResult {
|
|
replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
|
|
op.getTwoState());
|
|
return success();
|
|
};
|
|
|
|
size_t commonPrefixLength =
|
|
computeCommonPrefixLength(lhsOperands, rhsOperands);
|
|
if (commonPrefixLength == lhsOperands.size()) {
|
|
// cat(a, b, c) == cat(a, b, c) -> 1
|
|
bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
|
|
replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
|
|
APInt(1, result));
|
|
return success();
|
|
}
|
|
|
|
size_t commonSuffixLength = computeCommonPrefixLength(
|
|
llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
|
|
|
|
size_t commonPrefixTotalWidth =
|
|
getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
|
|
size_t commonSuffixTotalWidth =
|
|
getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
|
|
auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
|
|
.drop_back(commonSuffixLength);
|
|
auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
|
|
.drop_back(commonSuffixLength);
|
|
|
|
auto replaceWithoutReplicatingSignBit = [&]() {
|
|
auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
|
|
auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
|
|
return replaceWith(op.getPredicate(), newLhs, newRhs);
|
|
};
|
|
|
|
auto replaceWithReplicatingSignBit = [&]() {
|
|
auto firstNonEmptyValue = lhsOperands[0];
|
|
auto firstNonEmptyElemWidth =
|
|
firstNonEmptyValue.getType().getIntOrFloatBitWidth();
|
|
Value signBit = rewriter.createOrFold<ExtractOp>(
|
|
op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
|
|
|
|
auto newLhs = rewriter.create<ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
|
|
auto newRhs = rewriter.create<ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
|
|
return replaceWith(op.getPredicate(), newLhs, newRhs);
|
|
};
|
|
|
|
if (ICmpOp::isPredicateSigned(op.getPredicate())) {
|
|
// scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
|
|
if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
|
|
return replaceWithoutReplicatingSignBit();
|
|
|
|
// scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
|
|
// cat(sgn(b), ..y)) Note that we cannot perform this optimization if
|
|
// [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
|
|
// bit. Doing the rewrite can result in an infinite loop.
|
|
if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
|
|
return replaceWithReplicatingSignBit();
|
|
|
|
} else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
|
|
// ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
|
|
return replaceWithoutReplicatingSignBit();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
/// Given an equality comparison with a constant value and some operand that has
|
|
/// known bits, simplify the comparison to check only the unknown bits of the
|
|
/// input.
|
|
///
|
|
/// One simple example of this is that `concat(0, stuff) == 0` can be simplified
|
|
/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
|
|
/// `extract x[1:0] == 0`
|
|
static void combineEqualityICmpWithKnownBitsAndConstant(
|
|
ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
|
|
PatternRewriter &rewriter) {
|
|
|
|
// If any of the known bits disagree with any of the comparison bits, then
|
|
// we can constant fold this comparison right away.
|
|
APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
|
|
if ((bitsKnown & rhsCst) != bitAnalysis.One) {
|
|
// If we discover a mismatch then we know an "eq" comparison is false
|
|
// and a "ne" comparison is true!
|
|
bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
|
|
replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
|
|
APInt(1, result));
|
|
return;
|
|
}
|
|
|
|
// Check to see if we can prove the result entirely of the comparison (in
|
|
// which we bail out early), otherwise build a list of values to concat and a
|
|
// smaller constant to compare against.
|
|
SmallVector<Value> newConcatOperands;
|
|
auto newConstant = APInt::getZeroWidth();
|
|
|
|
// Ok, some (maybe all) bits are known and some others may be unknown.
|
|
// Extract out segments of the operand and compare against the
|
|
// corresponding bits.
|
|
unsigned knownMSB = bitsKnown.countLeadingOnes();
|
|
|
|
Value operand = cmpOp.getLhs();
|
|
|
|
// Ok, some bits are known but others are not. Extract out sequences of
|
|
// bits that are unknown and compare just those bits. We work from MSB to
|
|
// LSB.
|
|
while (knownMSB != bitsKnown.getBitWidth()) {
|
|
// Drop any high bits that are known.
|
|
if (knownMSB)
|
|
bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
|
|
|
|
// Find the span of unknown bits, and extract it.
|
|
unsigned unknownBits = bitsKnown.countLeadingZeros();
|
|
unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
|
|
auto spanOperand = rewriter.createOrFold<ExtractOp>(
|
|
operand.getLoc(), operand, /*lowBit=*/lowBit,
|
|
/*bitWidth=*/unknownBits);
|
|
auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
|
|
|
|
// Add this info to the concat we're generating.
|
|
newConcatOperands.push_back(spanOperand);
|
|
// FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
|
|
// newConstant = newConstant.concat(spanConstant);
|
|
if (newConstant.getBitWidth() != 0)
|
|
newConstant = newConstant.concat(spanConstant);
|
|
else
|
|
newConstant = spanConstant;
|
|
|
|
// Drop the unknown bits in prep for the next chunk.
|
|
unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
|
|
bitsKnown = bitsKnown.trunc(newWidth);
|
|
knownMSB = bitsKnown.countLeadingOnes();
|
|
}
|
|
|
|
// If all the operands to the concat are foldable then we have an identity
|
|
// situation where all the sub-elements equal each other. This implies that
|
|
// the overall result is foldable.
|
|
if (newConcatOperands.empty()) {
|
|
bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
|
|
replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
|
|
APInt(1, result));
|
|
return;
|
|
}
|
|
|
|
// If we have a single operand remaining, use it, otherwise form a concat.
|
|
Value concatResult =
|
|
rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
|
|
|
|
// Form the comparison against the smaller constant.
|
|
auto newConstantOp = rewriter.create<hw::ConstantOp>(
|
|
cmpOp.getOperand(1).getLoc(), newConstant);
|
|
|
|
replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
|
|
cmpOp.getPredicate(), concatResult,
|
|
newConstantOp, cmpOp.getTwoState());
|
|
}
|
|
|
|
// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
|
|
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
|
|
const APInt &rhs,
|
|
PatternRewriter &rewriter) {
|
|
auto ip = rewriter.saveInsertionPoint();
|
|
rewriter.setInsertionPoint(xorOp);
|
|
|
|
auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
|
|
auto newRHS = rewriter.create<hw::ConstantOp>(xorRHS->getLoc(),
|
|
xorRHS.getValue() ^ rhs);
|
|
Value newLHS;
|
|
switch (xorOp.getNumOperands()) {
|
|
case 1:
|
|
// This isn't common but is defined so we need to handle it.
|
|
newLHS = rewriter.create<hw::ConstantOp>(xorOp.getLoc(),
|
|
APInt::getZero(rhs.getBitWidth()));
|
|
break;
|
|
case 2:
|
|
// The binary case is the most common.
|
|
newLHS = xorOp.getOperand(0);
|
|
break;
|
|
default:
|
|
// The general case forces us to form a new xor with the remaining operands.
|
|
SmallVector<Value> newOperands(xorOp.getOperands());
|
|
newOperands.pop_back();
|
|
newLHS = rewriter.create<XorOp>(xorOp.getLoc(), newOperands, false);
|
|
break;
|
|
}
|
|
|
|
bool xorMultipleUses = !xorOp->hasOneUse();
|
|
|
|
// If the xor has multiple uses (not just the compare, then we need/want to
|
|
// replace them as well.
|
|
if (xorMultipleUses)
|
|
replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
|
|
false);
|
|
|
|
// Replace the comparison.
|
|
rewriter.restoreInsertionPoint(ip);
|
|
replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
|
|
rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS, false);
|
|
}
|
|
|
|
/// Canonicalize an icmp. In order:
///  1. move a constant LHS to the RHS (flipping the predicate);
///  2. simplify comparisons against a constant RHS (range boundaries,
///     non-strict -> strict predicates, 1-bit equality, known bits, xor with a
///     constant, replicate);
///  3. strip common prefixes/suffixes from compared concats/replicates.
/// Returns success() if `op` was rewritten.
LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
  // Cross-block dataflow is a canonicalization barrier.
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  APInt lhs, rhs;

  // icmp 1, x -> icmp x, 1
  if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
    assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
           "Should be folded");
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
        rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
        op.getRhs(), op.getLhs(), op.getTwoState());
    return success();
  }

  // Canonicalize with RHS constant
  if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
    const ICmpPredicate predicate = op.getPredicate();

    // Materialize an hw.constant at the comparison's location.
    auto makeConstant = [&](APInt value) -> Value {
      return rewriter.create<hw::ConstantOp>(op.getLoc(), std::move(value));
    };

    // Replace `op` with another icmp, keeping its two-state flag.
    auto rewriteAs = [&](ICmpPredicate newPredicate, Value newLhs,
                         Value newRhs) -> LogicalResult {
      replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, newPredicate,
                                                newLhs, newRhs,
                                                op.getTwoState());
      return success();
    };

    // Replace `op` with an i1 constant.
    auto foldToI1 = [&](bool value) -> LogicalResult {
      replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                        APInt(1, value));
      return success();
    };

    switch (predicate) {
    case ICmpPredicate::slt:
      // x < max -> x != max
      if (rhs.isMaxSignedValue())
        return rewriteAs(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x < min -> false
      if (rhs.isMinSignedValue())
        return foldToI1(false);
      // x < min+1 -> x == min
      if ((rhs - 1).isMinSignedValue())
        return rewriteAs(ICmpPredicate::eq, op.getLhs(),
                         makeConstant(rhs - 1));
      break;

    case ICmpPredicate::sgt:
      // x > min -> x != min
      if (rhs.isMinSignedValue())
        return rewriteAs(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x > max -> false
      if (rhs.isMaxSignedValue())
        return foldToI1(false);
      // x > max-1 -> x == max
      if ((rhs + 1).isMaxSignedValue())
        return rewriteAs(ICmpPredicate::eq, op.getLhs(),
                         makeConstant(rhs + 1));
      break;

    case ICmpPredicate::ult:
      // x < max -> x != max
      if (rhs.isAllOnes())
        return rewriteAs(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x < min -> false
      if (rhs.isZero())
        return foldToI1(false);
      // x < min+1 -> x == min
      if ((rhs - 1).isZero())
        return rewriteAs(ICmpPredicate::eq, op.getLhs(),
                         makeConstant(rhs - 1));

      // The constant is a run of ones followed only by zeros:
      // x < 0xE0 -> extract(x, 5..7) != 0b111
      if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
          rhs.getBitWidth()) {
        auto numOnes = rhs.countLeadingOnes();
        Value highBits = rewriter.create<ExtractOp>(
            op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
        return rewriteAs(ICmpPredicate::ne, highBits,
                         makeConstant(APInt::getAllOnes(numOnes)));
      }
      break;

    case ICmpPredicate::ugt:
      // x > min -> x != min
      if (rhs.isZero())
        return rewriteAs(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x > max -> false
      if (rhs.isAllOnes())
        return foldToI1(false);
      // x > max-1 -> x == max
      if ((rhs + 1).isAllOnes())
        return rewriteAs(ICmpPredicate::eq, op.getLhs(),
                         makeConstant(rhs + 1));

      // The constant is a run of trailing ones:
      // x > 0x07 -> extract(x, 3..7) != 0b00000
      if ((rhs + 1).isPowerOf2()) {
        auto numOnes = rhs.countTrailingOnes();
        auto newWidth = rhs.getBitWidth() - numOnes;
        Value highBits = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(),
                                                    numOnes, newWidth);
        return rewriteAs(ICmpPredicate::ne, highBits,
                         makeConstant(APInt::getZero(newWidth)));
      }
      break;

    case ICmpPredicate::sle:
      // x <= max -> true
      if (rhs.isMaxSignedValue())
        return foldToI1(true);
      // x <= c -> x < (c+1)
      return rewriteAs(ICmpPredicate::slt, op.getLhs(), makeConstant(rhs + 1));

    case ICmpPredicate::sge:
      // x >= min -> true
      if (rhs.isMinSignedValue())
        return foldToI1(true);
      // x >= c -> x > (c-1)
      return rewriteAs(ICmpPredicate::sgt, op.getLhs(), makeConstant(rhs - 1));

    case ICmpPredicate::ule:
      // x <= max -> true
      if (rhs.isAllOnes())
        return foldToI1(true);
      // x <= c -> x < (c+1)
      return rewriteAs(ICmpPredicate::ult, op.getLhs(), makeConstant(rhs + 1));

    case ICmpPredicate::uge:
      // x >= min -> true
      if (rhs.isZero())
        return foldToI1(true);
      // x >= c -> x > (c-1)
      return rewriteAs(ICmpPredicate::ugt, op.getLhs(), makeConstant(rhs - 1));

    case ICmpPredicate::eq:
    case ICmpPredicate::ne:
      // On a single bit the comparison is either the operand itself or its
      // inversion:
      //   x == 1 -> x        x != 0 -> x
      //   x == 0 -> x ^ 1    x != 1 -> x ^ 1
      if (rhs.getBitWidth() == 1) {
        const bool isEq = predicate == ICmpPredicate::eq;
        const bool needsInversion = isEq == rhs.isZero();
        if (needsInversion)
          replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
                                                   makeConstant(APInt(1, 1)),
                                                   op.getTwoState());
        else
          replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
        return success();
      }
      break;

    case ICmpPredicate::ceq:
    case ICmpPredicate::cne:
    case ICmpPredicate::weq:
    case ICmpPredicate::wne:
      break;
    }

    // The remaining constant-RHS optimizations are only supported for
    // equality comparisons.
    if (predicate == ICmpPredicate::eq || predicate == ICmpPredicate::ne) {
      // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
      // with a smaller constant.
      auto knownBits = computeKnownBits(op.getLhs());
      if (!knownBits.isUnknown()) {
        combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
                                                    rewriter);
        return success();
      }

      // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
      // cst1^cst2).
      auto xorOp = op.getLhs().getDefiningOp<XorOp>();
      if (xorOp &&
          xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>()) {
        combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter);
        return success();
      }

      // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
      // all one.
      auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>();
      if (replicateOp && (rhs.isAllOnes() || rhs.isZero())) {
        auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
        auto cst = rewriter.create<hw::ConstantOp>(
            op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
                                         : APInt::getZero(width));
        replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
            rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
            op.getTwoState());
        return success();
      }
    }
  }

  // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
  // b), cat(c, d)). Contains special handling for the sign bit in signed
  // comparisons.
  Operation *opLHS = op.getLhs().getDefiningOp();
  Operation *opRHS = op.getRhs().getDefiningOp();
  if (!opLHS || !opRHS)
    return failure();
  if (!isa<ConcatOp, ReplicateOp>(opLHS) || !isa<ConcatOp, ReplicateOp>(opRHS))
    return failure();
  if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
    return success();

  return failure();
}
|