//===- CombToAIG.cpp - Comb to AIG Conversion Pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the main Comb to AIG Conversion Pass Implementation.
//
//===----------------------------------------------------------------------===//

#include "circt/Conversion/CombToAIG.h"
|
|
#include "circt/Dialect/AIG/AIGOps.h"
|
|
#include "circt/Dialect/Comb/CombOps.h"
|
|
#include "circt/Dialect/HW/HWOps.h"
|
|
#include "circt/Support/Naming.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "llvm/ADT/PointerUnion.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "comb-to-aig"
|
|
|
|
namespace circt {
|
|
#define GEN_PASS_DEF_CONVERTCOMBTOAIG
|
|
#include "circt/Conversion/Passes.h.inc"
|
|
} // namespace circt
|
|
|
|
using namespace circt;
|
|
using namespace comb;
|
|
|
|
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
  SmallVector<Value> bits;
  comb::extractBits(builder, val, bits);
  return bits;
}

// Construct a mux tree for shift operations. `isLeftShift` controls the
// direction of the shift and determines the order in which the padding and
// extracted bits are concatenated. The callbacks `getPadding` and `getExtract`
// return the padding and the extracted bits for each shift amount.
// `getPadding` may return a null Value to denote a zero-width (i0) padding;
// apart from that, both callbacks must return a valid value for each shift
// amount in the range [0, maxShiftAmount]. The value for `maxShiftAmount` is
// used as the out-of-bounds value.
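//
// As an illustrative sketch (not an original comment): lowering
// `%0 = comb.shl %x, %s : i4` builds nodes[i] = {x[3-i:0], i'b0} for
// i in [0, 3], selects among them with a mux tree indexed by the bits of %s,
// and muxes in the out-of-bounds value (all zeros) whenever %s >= 4.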
template <bool isLeftShift>
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
                              Value shiftAmount, int64_t maxShiftAmount,
                              llvm::function_ref<Value(int64_t)> getPadding,
                              llvm::function_ref<Value(int64_t)> getExtract) {
  // Extract individual bits from shift amount
  auto bits = extractBits(rewriter, shiftAmount);

  // Create nodes for each possible shift amount
  SmallVector<Value> nodes;
  nodes.reserve(maxShiftAmount);
  for (int64_t i = 0; i < maxShiftAmount; ++i) {
    Value extract = getExtract(i);
    Value padding = getPadding(i);

    if (!padding) {
      nodes.push_back(extract);
      continue;
    }

    // Concatenate extracted bits with padding
    if (isLeftShift)
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
    else
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
  }

  // Create out-of-bounds value
  auto outOfBoundsValue = getPadding(maxShiftAmount);
  assert(outOfBoundsValue && "outOfBoundsValue must be valid");

  // Construct mux tree for shift operation
  auto result =
      comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

  // Add bounds checking
  auto inBound = rewriter.createOrFold<comb::ICmpOp>(
      loc, ICmpPredicate::ult, shiftAmount,
      hw::ConstantOp::create(rewriter, loc, shiftAmount.getType(),
                             maxShiftAmount));

  return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
                                            outOfBoundsValue);
}

namespace {
// A union of Value and IntegerAttr to cleanly handle constant values.
using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
} // namespace

// Return the number of unknown bits and populate the concatenated values.
static int64_t getNumUnknownBitsAndPopulateValues(
    Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
  // Zero-width values are fully known.
  if (value.getType().isInteger(0))
    return 0;

  // Recursively count unknown bits for concat.
  if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
    int64_t totalUnknownBits = 0;
    for (auto concatInput : llvm::reverse(concat.getInputs())) {
      auto unknownBits =
          getNumUnknownBitsAndPopulateValues(concatInput, values);
      if (unknownBits < 0)
        return unknownBits;
      totalUnknownBits += unknownBits;
    }
    return totalUnknownBits;
  }

  // Constant values are fully known.
  if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
    values.push_back(constant.getValueAttr());
    return 0;
  }

  // Consider all other operations to be entirely unknown bits.
  // TODO: We can handle replicate, extract, etc.
  values.push_back(value);
  return hw::getBitWidth(value.getType());
}

// Return an APInt in which each run of unknown bits is substituted with the
// corresponding bits of `mask`, while constant bits are copied verbatim.
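//
// For example (illustrative, not an original comment): with
// constantOrValues = [i2 constant 0b10, i2 unknown value] (LSB first) and
// mask = 0b01, the result is 0b0110: bits [1:0] come from the constant and
// bits [3:2] are taken from the low two bits of the mask.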
static APInt
substituteMaskToValues(size_t width,
                       llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
                       uint32_t mask) {
  uint32_t bitPos = 0, unknownPos = 0;
  APInt result(width, 0);
  for (auto constantOrValue : constantOrValues) {
    int64_t elemWidth;
    if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
      elemWidth = constant.getValue().getBitWidth();
      result.insertBits(constant.getValue(), bitPos);
    } else {
      elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
      assert(elemWidth >= 0 && "unknown bit width");
      assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
      // Create a mask for the unknown bits.
      uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
      result.insertBits(APInt(elemWidth, usedBits), bitPos);
      unknownPos += elemWidth;
    }
    bitPos += elemWidth;
  }

  return result;
}

// Emulate a binary operation with unknown bits using a table lookup.
// This function enumerates all possible combinations of unknown bits and
// emulates the operation for each combination.
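//
// An illustrative sketch (not an original comment): for
// `comb.divu %a, %c5 : i3` where the divisor is the constant 5 and %a is
// fully unknown, the eight quotients i/5 for i in [0, 7] are precomputed as
// constants and selected by a mux tree indexed by the three bits of %a.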
static LogicalResult emulateBinaryOpForUnknownBits(
    ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
    Operation *op,
    llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
  SmallVector<ConstantOrValue> lhsValues, rhsValues;

  assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
         "op must be a single result binary operation");

  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
  auto loc = op->getLoc();
  auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
  auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);

  // If an unknown bit width is detected, abort the lowering.
  if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
    return failure();

  int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
  if (totalUnknownBits > maxEmulationUnknownBits)
    return failure();

  SmallVector<Value> emulatedResults;
  emulatedResults.reserve(1 << totalUnknownBits);

  // Emulate all possible cases.
  DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
  auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
    auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
    auto it = constantPool.find(attr);
    if (it != constantPool.end())
      return it->second;
    auto constant = hw::ConstantOp::create(rewriter, loc, value);
    constantPool[attr] = constant;
    return constant;
  };

  for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
       lhsMask < lhsMaskEnd; ++lhsMask) {
    APInt lhsValue = substituteMaskToValues(width, lhsValues, lhsMask);
    for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
         rhsMask < rhsMaskEnd; ++rhsMask) {
      APInt rhsValue = substituteMaskToValues(width, rhsValues, rhsMask);
      // Emulate.
      emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
    }
  }

  // Create selectors for mux tree.
  SmallVector<Value> selectors;
  selectors.reserve(totalUnknownBits);
  for (auto &concatedValues : {rhsValues, lhsValues})
    for (auto valueOrConstant : concatedValues) {
      auto value = dyn_cast<Value>(valueOrConstant);
      if (!value)
        continue;
      extractBits(rewriter, value, selectors);
    }

  assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
         "number of selectors must match");
  auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
                                getConstant(APInt::getZero(width)));

  replaceOpAndCopyNamehint(rewriter, op, muxed);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Lower a comb::AndOp operation to aig::AndInverterOp
struct CombAndOpConversion : OpConversionPattern<AndOp> {
  using OpConversionPattern<AndOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AndOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
    replaceOpWithNewOpAndCopyNamehint<aig::AndInverterOp>(
        rewriter, op, adaptor.getInputs(), nonInverts);
    return success();
  }
};

/// Lower a comb::OrOp operation to aig::AndInverterOp with invert flags
struct CombOrOpConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Implement Or using And and invert flags: a | b = ~(~a & ~b)
    SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
    auto andOp = aig::AndInverterOp::create(rewriter, op.getLoc(),
                                            adaptor.getInputs(), allInverts);
    replaceOpWithNewOpAndCopyNamehint<aig::AndInverterOp>(rewriter, op, andOp,
                                                          /*invert=*/true);
    return success();
  }
};

/// Lower a comb::XorOp operation to AIG operations
struct CombXorOpConversion : OpConversionPattern<XorOp> {
  using OpConversionPattern<XorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(XorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();
    // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)

    // (a | b) = ~(~a & ~b)
    // (~a | ~b) = ~(a & b)
    // Hence a ^ b = ~(~a & ~b) & ~(a & b), i.e. a single AndInverterOp over
    // the two intermediate AND nodes with both operands inverted.
    auto inputs = adaptor.getInputs();
    SmallVector<bool> allInverts(inputs.size(), true);
    SmallVector<bool> allNotInverts(inputs.size(), false);

    auto notAAndNotB =
        aig::AndInverterOp::create(rewriter, op.getLoc(), inputs, allInverts);
    auto aAndB = aig::AndInverterOp::create(rewriter, op.getLoc(), inputs,
                                            allNotInverts);

    replaceOpWithNewOpAndCopyNamehint<aig::AndInverterOp>(rewriter, op,
                                                          notAAndNotB, aAndB,
                                                          /*lhs_invert=*/true,
                                                          /*rhs_invert=*/true);
    return success();
  }
};

template <typename OpTy>
struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }

  static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
                                       ConversionPatternRewriter &rewriter) {
    Value lhs, rhs;
    switch (operands.size()) {
    case 0:
      llvm_unreachable("cannot be called with empty operand range");
    case 1:
      return operands[0];
    case 2:
      lhs = operands[0];
      rhs = operands[1];
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    default:
      auto firstHalf = operands.size() / 2;
      lhs =
          lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
      rhs =
          lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    }
  }
};

// Lower comb::MuxOp to AIG operations.
struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
  using OpConversionPattern<MuxOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MuxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Implement: c ? a : b = (replicate(c) & a) | (~replicate(c) & b)

    Value cond = op.getCond();
    auto trueVal = op.getTrueValue();
    auto falseVal = op.getFalseValue();

    if (!op.getType().isInteger()) {
      // If the type of the mux is not integer, bitcast the operands first.
      auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
      trueVal =
          hw::BitcastOp::create(rewriter, op->getLoc(), widthType, trueVal);
      falseVal =
          hw::BitcastOp::create(rewriter, op->getLoc(), widthType, falseVal);
    }

    // Replicate condition if needed
    if (!trueVal.getType().isInteger(1))
      cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
                                       cond);

    // c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
    auto lhs = aig::AndInverterOp::create(rewriter, op.getLoc(), cond, trueVal);
    auto rhs = aig::AndInverterOp::create(rewriter, op.getLoc(), cond, falseVal,
                                          true, false);

    Value result = comb::OrOp::create(rewriter, op.getLoc(), lhs, rhs);
    // Insert the bitcast if the type of the mux is not integer.
    if (result.getType() != op.getType())
      result =
          hw::BitcastOp::create(rewriter, op.getLoc(), op.getType(), result);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombAddOpConversion : OpConversionPattern<AddOp> {
  using OpConversionPattern<AddOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputs = adaptor.getInputs();
    // Lower only when there are two inputs.
    // Variadic operands must be lowered in a different pattern.
    if (inputs.size() != 2)
      return failure();

    auto width = op.getType().getIntOrFloatBitWidth();
    // Skip a zero width value.
    if (width == 0) {
      replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                        op.getType(), 0);
      return success();
    }

    if (width < 8)
      lowerRippleCarryAdder(op, inputs, rewriter);
    else
      lowerParallelPrefixAdder(op, inputs, rewriter);

    return success();
  }

  // Implement a basic ripple-carry adder for small bitwidths.
  void lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
                             ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    // Implement a naive ripple-carry full adder.
    Value carry;

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);
    SmallVector<Value> results;
    results.resize(width);
    for (int64_t i = 0; i < width; ++i) {
      SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
      if (carry)
        xorOperands.push_back(carry);

      // sum[i] = xor(carry[i-1], a[i], b[i])
      // NOTE: The result is stored in reverse order.
      results[width - i - 1] =
          comb::XorOp::create(rewriter, op.getLoc(), xorOperands, true);

      // If this is the last bit, we are done.
      if (i == width - 1)
        break;

      // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
      Value nextCarry = comb::AndOp::create(
          rewriter, op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
      if (!carry) {
        // This is the first bit, so the carry is the next carry.
        carry = nextCarry;
        continue;
      }

      auto aXorB = comb::XorOp::create(rewriter, op.getLoc(),
                                       ValueRange{aBits[i], bBits[i]}, true);
      auto andOp = comb::AndOp::create(rewriter, op.getLoc(),
                                       ValueRange{carry, aXorB}, true);
      carry = comb::OrOp::create(rewriter, op.getLoc(),
                                 ValueRange{andOp, nextCarry}, true);
    }
    LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width "
                            << width << "\n");

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
  }

  // Implement a parallel prefix adder with Kogge-Stone or Brent-Kung trees.
  // This introduces unused signals for the carry bits, but those are removed
  // later by the AIG pipeline.
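  //
  // Background (a sketch of standard prefix-adder theory, not an original
  // comment): per-bit signals are p_i = a_i ^ b_i and g_i = a_i & b_i, and
  // the trees below combine them with the associative operator
  //   (g, p) o (g', p') = (g | (p & g'), p & p')
  // so that the group generate G_{i:0} equals carry_i, and
  // sum_i = p_i ^ G_{i-1:0}.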
  void lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
                                ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);
    // Construct propagate (p) and generate (g) signals
    SmallVector<Value> p, g;
    p.reserve(width);
    g.reserve(width);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      // p_i = a_i XOR b_i
      p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
      // g_i = a_i AND b_i
      g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
    }

    LLVM_DEBUG({
      llvm::dbgs() << "Lower comb.add to Parallel-Prefix of width " << width
                   << "\n--------------------------------------- Init\n";

      for (int64_t i = 0; i < width; ++i) {
        // p_i = a_i XOR b_i
        llvm::dbgs() << "P0" << i << " = A" << i << " XOR B" << i << "\n";
        // g_i = a_i AND b_i
        llvm::dbgs() << "G0" << i << " = A" << i << " AND B" << i << "\n";
      }
    });

    // Create copies of p and g for the prefix computation
    SmallVector<Value> pPrefix = p;
    SmallVector<Value> gPrefix = g;
    if (width < 32)
      lowerKoggeStonePrefixTree(op, inputs, rewriter, pPrefix, gPrefix);
    else
      lowerBrentKungPrefixTree(op, inputs, rewriter, pPrefix, gPrefix);

    // Generate result sum bits
    // NOTE: The result is stored in reverse order.
    SmallVector<Value> results;
    results.resize(width);
    // Sum bit 0 is just p[0] since carry_in = 0
    results[width - 1] = p[0];

    // For the remaining bits, sum_i = p_i XOR c_(i-1).
    // The carry into position i is the group generate from position i-1.
    for (int64_t i = 1; i < width; ++i)
      results[width - 1 - i] =
          comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);

    LLVM_DEBUG({
      llvm::dbgs() << "--------------------------------------- Completion\n"
                   << "RES0 = P0\n";
      for (int64_t i = 1; i < width; ++i)
        llvm::dbgs() << "RES" << i << " = P" << i << " XOR G" << i - 1 << "\n";
    });
  }

  // Implement the Kogge-Stone parallel prefix tree
  // Described in https://en.wikipedia.org/wiki/Kogge%E2%80%93Stone_adder
  // Slightly better delay than Brent-Kung, but more area.
  void lowerKoggeStonePrefixTree(comb::AddOp op, ValueRange inputs,
                                 ConversionPatternRewriter &rewriter,
                                 SmallVector<Value> &pPrefix,
                                 SmallVector<Value> &gPrefix) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    // Kogge-Stone parallel prefix computation
    for (int64_t stride = 1; stride < width; stride *= 2) {
      for (int64_t i = stride; i < width; ++i) {
        int64_t j = i - stride;
        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] =
            comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefix[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }
    LLVM_DEBUG({
      int64_t stage = 0;
      for (int64_t stride = 1; stride < width; stride *= 2) {
        llvm::dbgs()
            << "--------------------------------------- Kogge-Stone Stage "
            << stage << "\n";
        for (int64_t i = stride; i < width; ++i) {
          int64_t j = i - stride;
          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        ++stage;
      }
    });
  }

  // Implement the Brent-Kung parallel prefix tree
  // Described in https://en.wikipedia.org/wiki/Brent%E2%80%93Kung_adder
  // Slightly worse delay than Kogge-Stone, but less area.
  void lowerBrentKungPrefixTree(comb::AddOp op, ValueRange inputs,
                                ConversionPatternRewriter &rewriter,
                                SmallVector<Value> &pPrefix,
                                SmallVector<Value> &gPrefix) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    // Brent-Kung parallel prefix computation
    // Forward phase
    int64_t stride;
    for (stride = 1; stride < width; stride *= 2) {
      for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] =
            comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefix[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }

    // Backward phase
    for (; stride > 0; stride /= 2) {
      for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] =
            comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefix[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }

    LLVM_DEBUG({
      int64_t stage = 0;
      for (stride = 1; stride < width; stride *= 2) {
        llvm::dbgs() << "--------------------------------------- Brent-Kung FW "
                     << stage << " : Stride " << stride << "\n";
        for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
          int64_t j = i - stride;

          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        ++stage;
      }

      for (; stride > 0; stride /= 2) {
        if (stride * 3 - 1 < width)
          llvm::dbgs()
              << "--------------------------------------- Brent-Kung BW "
              << stage << " : Stride " << stride << "\n";

        for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
          int64_t j = i - stride;

          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        --stage;
      }
    });
  }
};

struct CombSubOpConversion : OpConversionPattern<SubOp> {
  using OpConversionPattern<SubOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(SubOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = op.getLhs();
    auto rhs = op.getRhs();
    // Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
    // sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
    // => add(lhs, ~rhs, 1)
    auto notRhs = aig::AndInverterOp::create(rewriter, op.getLoc(), rhs,
                                             /*invert=*/true);
    auto one = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(), 1);
    replaceOpWithNewOpAndCopyNamehint<comb::AddOp>(
        rewriter, op, ValueRange{lhs, notRhs, one}, true);
    return success();
  }
};

struct CombMulOpConversion : OpConversionPattern<MulOp> {
  using OpConversionPattern<MulOp>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
  LogicalResult
  matchAndRewrite(MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getInputs().size() != 2)
      return failure();

    Location loc = op.getLoc();
    Value a = adaptor.getInputs()[0];
    Value b = adaptor.getInputs()[1];
    unsigned width = op.getType().getIntOrFloatBitWidth();

    // Skip a zero width value.
    if (width == 0) {
      rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
      return success();
    }

    // Extract individual bits from operands
    SmallVector<Value> aBits = extractBits(rewriter, a);
    SmallVector<Value> bBits = extractBits(rewriter, b);

    auto falseValue = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));

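    // Illustrative sketch (not an original comment): for width = 3, the rows
    // built below are, LSB first and truncated to 3 bits,
    //   row0 = [a0&b0, a1&b0, a2&b0]
    //   row1 = [  0  , a0&b1, a1&b1]
    //   row2 = [  0  ,   0  , a0&b2]
    // and their sum modulo 2^3 is the product a * b.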
    // Generate partial products
    SmallVector<SmallVector<Value>> partialProducts;
    partialProducts.reserve(width);
    for (unsigned i = 0; i < width; ++i) {
      SmallVector<Value> row(i, falseValue);
      row.reserve(width);
      // Generate partial product bits
      for (unsigned j = 0; i + j < width; ++j)
        row.push_back(
            rewriter.createOrFold<comb::AndOp>(loc, aBits[j], bBits[i]));

      partialProducts.push_back(row);
    }

    // If the width is 1, we are done.
    if (width == 1) {
      rewriter.replaceOp(op, partialProducts[0][0]);
      return success();
    }

    // Wallace tree reduction - reduce to two addends.
    auto addends =
        comb::wallaceReduction(rewriter, loc, width, 2, partialProducts);
    // Sum the two addends using a carry-propagate adder
    auto newAdd = comb::AddOp::create(rewriter, loc, addends, true);
    replaceOpAndCopyNamehint(rewriter, op, newAdd);
    return success();
  }
};

template <typename OpTy>
struct DivModOpConversionBase : OpConversionPattern<OpTy> {
  DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
      : OpConversionPattern<OpTy>(context),
        maxEmulationUnknownBits(maxEmulationUnknownBits) {
    assert(maxEmulationUnknownBits < 32 &&
           "maxEmulationUnknownBits must be less than 32");
  }
  const int64_t maxEmulationUnknownBits;
};

struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
  using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(DivUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the divisor is a power of two.
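    // If so, the quotient is just the upper bits of the dividend,
    // zero-extended back to the original width. Illustrative example (not an
    // original comment): for an i8 value x, x udiv 8 == {3'b0, x[7:3]}.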
    if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
      if (rhsConstantOp.getValue().isPowerOf2()) {
        // Extract upper bits.
        size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
        size_t width = op.getType().getIntOrFloatBitWidth();
        Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
            op.getLoc(), adaptor.getLhs(), extractAmount,
            width - extractAmount);
        Value constZero = hw::ConstantOp::create(rewriter, op.getLoc(),
                                                 APInt::getZero(extractAmount));
        replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
            rewriter, op, op.getType(), ArrayRef<Value>{constZero, upperBits});
        return success();
      }

    // When the divisor is not a power of two and the number of unknown bits
    // is small, create a mux tree that emulates all possible cases.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.udiv(rhs);
        });
  }
};

struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
  using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the divisor is a power of two.
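    // If so, the remainder is just the lower bits of the dividend,
    // zero-extended back to the original width. Illustrative example (not an
    // original comment): for an i8 value x, x urem 8 == {5'b0, x[2:0]}.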
    if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
      if (rhsConstantOp.getValue().isPowerOf2()) {
        // Extract lower bits.
        size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
        size_t width = op.getType().getIntOrFloatBitWidth();
        Value lowerBits = rewriter.createOrFold<comb::ExtractOp>(
            op.getLoc(), adaptor.getLhs(), 0, extractAmount);
        Value constZero = hw::ConstantOp::create(
            rewriter, op.getLoc(), APInt::getZero(width - extractAmount));
        replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
            rewriter, op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
        return success();
      }

    // When the divisor is not a power of two and the number of unknown bits
    // is small, create a mux tree that emulates all possible cases.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.urem(rhs);
        });
  }
};

struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
  using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(DivSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently only lower with emulation.
    // TODO: Implement a signed division lowering at least for power of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.sdiv(rhs);
        });
  }
};

struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
  using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently only lower with emulation.
    // TODO: Implement a signed modulus lowering at least for power of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.srem(rhs);
        });
  }
};

struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
  using OpConversionPattern<ICmpOp>::OpConversionPattern;
  static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
                                        ArrayRef<Value> bBits, bool isLess,
                                        bool includeEq,
                                        ConversionPatternRewriter &rewriter) {
    // Construct the following unsigned comparison expressions.
    // a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0])
    // a <  b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <  b[n-1:0])
    // a >= b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >= b[n-1:0])
    // a >  b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >  b[n-1:0])
    // The loop below unrolls this recurrence from the LSB to the MSB: after
    // processing bit i, `acc` holds the comparison result of the low i+1
    // bits, seeded with `includeEq` for the zero-width base case.
    Value acc =
        hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(), includeEq);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      auto aBitXorBBit =
          rewriter.createOrFold<comb::XorOp>(op.getLoc(), aBit, bBit, true);
      auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
          op.getLoc(), aBitXorBBit, true);
      auto pred = rewriter.createOrFold<aig::AndInverterOp>(
          op.getLoc(), aBit, bBit, isLess, !isLess);

      auto eqAndAcc = rewriter.createOrFold<comb::AndOp>(
          op.getLoc(), ValueRange{aEqualB, acc}, true);
      acc = rewriter.createOrFold<comb::OrOp>(op.getLoc(), pred, eqAndAcc,
                                              true);
    }
    return acc;
  }

  LogicalResult
  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();

    switch (op.getPredicate()) {
    default:
      return failure();

    case ICmpPredicate::eq:
    case ICmpPredicate::ceq: {
      // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      auto xorBits = extractBits(rewriter, xorOp);
      SmallVector<bool> allInverts(xorBits.size(), true);
      replaceOpWithNewOpAndCopyNamehint<aig::AndInverterOp>(
          rewriter, op, xorBits, allInverts);
      return success();
    }

    case ICmpPredicate::ne:
    case ICmpPredicate::cne: {
      // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
          rewriter, op, extractBits(rewriter, xorOp), true);
      return success();
    }

    case ICmpPredicate::uge:
    case ICmpPredicate::ugt:
    case ICmpPredicate::ule:
    case ICmpPredicate::ult: {
      bool isLess = op.getPredicate() == ICmpPredicate::ult ||
                    op.getPredicate() == ICmpPredicate::ule;
      bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
                       op.getPredicate() == ICmpPredicate::ule;
      auto aBits = extractBits(rewriter, lhs);
      auto bBits = extractBits(rewriter, rhs);
      replaceOpAndCopyNamehint(rewriter, op,
                               constructUnsignedCompare(op, aBits, bBits,
                                                        isLess, includeEq,
                                                        rewriter));
      return success();
    }
    case ICmpPredicate::slt:
    case ICmpPredicate::sle:
    case ICmpPredicate::sgt:
    case ICmpPredicate::sge: {
      if (lhs.getType().getIntOrFloatBitWidth() == 0)
        return rewriter.notifyMatchFailure(
            op.getLoc(), "i0 signed comparison is unsupported");
      bool isLess = op.getPredicate() == ICmpPredicate::slt ||
                    op.getPredicate() == ICmpPredicate::sle;
      bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
                       op.getPredicate() == ICmpPredicate::sle;

      auto aBits = extractBits(rewriter, lhs);
      auto bBits = extractBits(rewriter, rhs);

      // Get the sign bits.
      auto signA = aBits.back();
      auto signB = bBits.back();

      // Compare magnitudes (all bits except the sign).
      auto sameSignResult = constructUnsignedCompare(
          op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
          includeEq, rewriter);

      // XOR of signs: true if signs are different
      auto signsDiffer =
          comb::XorOp::create(rewriter, op.getLoc(), signA, signB);

      // Result when signs are different
      Value diffSignResult = isLess ? signA : signB;

      // Final result: choose based on whether signs differ
      replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
          rewriter, op, signsDiffer, diffSignResult, sameSignResult);
      return success();
    }
    }
  }
};

struct CombParityOpConversion : OpConversionPattern<ParityOp> {
  using OpConversionPattern<ParityOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Parity is the XOR of all bits.
    replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
        rewriter, op, extractBits(rewriter, adaptor.getInput()), true);
    return success();
  }
};

struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
  using OpConversionPattern<comb::ShlOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/true>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // Padding is 0 for left shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the LSB side.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
                                                        width - index);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
  using OpConversionPattern<comb::ShrUOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // Padding is 0 for right shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the MSB side.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
  using OpConversionPattern<comb::ShrSOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    if (width == 0)
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "i0 signed shift is unsupported");
    auto lhs = adaptor.getLhs();
    // Get the sign bit.
    auto sign =
        rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);

    // NOTE: The max shift amount is width - 1, because shifting by width - 1
    // or more yields a result consisting entirely of sign bits.
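    // Illustrative example (not an original comment): for an i4 value x,
    // x >>s 2 == {x[3], x[3], x[3:2]}, and any shift amount >= 3 gives
    // {x[3], x[3], x[3], x[3]}, which is exactly the out-of-bounds value.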
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
        /*getPadding=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
                                                          index + 1);
        },
        /*getExtract=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index - 1);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Convert Comb to AIG pass
//===----------------------------------------------------------------------===//

namespace {
struct ConvertCombToAIGPass
    : public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
  void runOnOperation() override;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::maxEmulationUnknownBits;
};
} // namespace

static void
populateCombToAIGConversionPatterns(RewritePatternSet &patterns,
                                    uint32_t maxEmulationUnknownBits) {
  patterns.add<
      // Bitwise Logical Ops
      CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
      CombMuxOpConversion, CombParityOpConversion,
      // Arithmetic Ops
      CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
      CombICmpOpConversion,
      // Shift Ops
      CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
      // Variadic ops that must be lowered to binary operations
      CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
      CombLowerVariadicOp<MulOp>>(patterns.getContext());

  // Add div/mod patterns with a threshold given by the pass option.
  patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
               CombModSOpConversion>(patterns.getContext(),
                                     maxEmulationUnknownBits);
}

void ConvertCombToAIGPass::runOnOperation() {
  ConversionTarget target(getContext());

  // Comb is the source dialect.
  target.addIllegalDialect<comb::CombDialect>();
  // Keep data movement operations like Extract, Concat and Replicate.
  target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
                    hw::BitcastOp, hw::ConstantOp>();

  // Treat array operations as illegal. Strictly speaking, array operations
  // other than an array get with a non-constant index are legal in AIG, but
  // array types prevent a number of optimizations, so we lower them to
  // integer operations as well. The HWAggregateToComb pass must be run before
  // this pass.
  target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
                      hw::AggregateConstantOp>();

  // AIG is the target dialect.
  target.addLegalDialect<aig::AIGDialect>();

  // If additional legal ops are specified, add them to the target.
  if (!additionalLegalOps.empty())
    for (const auto &opName : additionalLegalOps)
      target.addLegalOp(OperationName(opName, &getContext()));
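  // For example (a hypothetical invocation; the exact option spelling is
  // generated from the pass tablegen definition and may differ):
  //   circt-opt --convert-comb-to-aig="additional-legal-ops=comb.mul" in.mlir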

  RewritePatternSet patterns(&getContext());
  populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits);

  if (failed(mlir::applyPartialConversion(getOperation(), target,
                                          std::move(patterns))))
    return signalPassFailure();
}