[Datapath] Create Datapath to Comb Pass (#8736)

* Initiate datapath to comb pass

* Add tests and tidy datapath to compress implementation

* Improve comments

* Formatting and test corrections

* Correct CAPI

* Move wallace tree reduction and full-adder to comb ops

* Adding integration tests using circt-lec and correcting review comments

* Fix bug in Booth code for final sign correction row and add testing using lec. Add a forceBooth option largely for testing purposes

* Minor fix

* Formatting

* Removing populate patterns function

* Formatting
This commit is contained in:
Samuel Coward 2025-07-25 14:13:24 +01:00 committed by GitHub
parent a681fefd5e
commit 30a4a38fca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 585 additions and 82 deletions

View File

@ -0,0 +1,21 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_CONVERSION_DATAPATHTOCOMB_H
#define CIRCT_CONVERSION_DATAPATHTOCOMB_H
#include "circt/Support/LLVM.h"
namespace circt {
#define GEN_PASS_DECL_CONVERTDATAPATHTOCOMB
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
#endif // CIRCT_CONVERSION_DATAPATHTOCOMB_H

View File

@ -26,6 +26,7 @@
#include "circt/Conversion/CombToSMT.h"
#include "circt/Conversion/ConvertToArcs.h"
#include "circt/Conversion/DCToHW.h"
#include "circt/Conversion/DatapathToComb.h"
#include "circt/Conversion/DatapathToSMT.h"
#include "circt/Conversion/ExportChiselInterface.h"
#include "circt/Conversion/ExportVerilog.h"

View File

@ -871,4 +871,22 @@ def ConvertDatapathToSMT : Pass<"convert-datapath-to-smt"> {
];
}
//===----------------------------------------------------------------------===//
// ConvertDatapathToComb
//===----------------------------------------------------------------------===//
def ConvertDatapathToComb : Pass<"convert-datapath-to-comb"> {
let summary = "Convert Datapath ops to Comb ops";
let dependentDialects = [
"circt::comb::CombDialect", "circt::datapath::DatapathDialect",
"circt::hw::HWDialect"
];
let options = [
Option<"lowerCompressToAdd", "lower-compress-to-add", "bool", "false",
"Lower compress operators to variadic add.">,
Option<"forceBooth", "lower-partial-product-to-booth", "bool", "false",
"Force all partial products to be lowered to Booth arrays.">
];
}
#endif // CIRCT_CONVERSION_PASSES_TD

View File

@ -23,6 +23,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
namespace llvm {
struct KnownBits;
@ -85,6 +86,18 @@ Value createDynamicInject(OpBuilder &builder, Location loc, Value value,
Value createInject(OpBuilder &builder, Location loc, Value value,
unsigned offset, Value replacement);
/// Construct a full adder for three 1-bit inputs.
std::pair<Value, Value> fullAdder(OpBuilder &builder, Location loc, Value a,
Value b, Value c);
/// Perform Wallace tree reduction on partial products.
/// See https://en.wikipedia.org/wiki/Wallace_tree
/// \param targetAddends The number of addends to reduce to (2 for carry-save).
/// \param inputAddends The rows of bits to be summed.
SmallVector<Value> wallaceReduction(OpBuilder &builder, Location loc,
size_t width, size_t targetAddends,
SmallVector<SmallVector<Value>> &addends);
} // namespace comb
} // namespace circt

View File

@ -0,0 +1,48 @@
// REQUIRES: libz3
// REQUIRES: circt-lec-jit
// RUN: circt-opt %s --convert-datapath-to-comb -o %t.mlir
// RUN: circt-lec %t.mlir %s -c1=partial_product_5 -c2=partial_product_5 --shared-libs=%libz3 | FileCheck %s --check-prefix=AND5
// AND5: c1 == c2
hw.module @partial_product_5(in %a : i5, in %b : i5, out sum : i5) {
%0:5 = datapath.partial_product %a, %b : (i5, i5) -> (i5, i5, i5, i5, i5)
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4 : i5
hw.output %1 : i5
}
// RUN: circt-lec %t.mlir %s -c1=partial_product_4 -c2=partial_product_4 --shared-libs=%libz3 | FileCheck %s --check-prefix=AND4
// AND4: c1 == c2
hw.module @partial_product_4(in %a : i4, in %b : i4, out sum : i4) {
%0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3 : i4
hw.output %1 : i4
}
// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMP3
// COMP3: c1 == c2
hw.module @compress_3(in %a : i4, in %b : i4, in %c : i4, out sum : i4) {
%0:2 = datapath.compress %a, %b, %c : i4 [3 -> 2]
%1 = comb.add bin %0#0, %0#1 : i4
hw.output %1 : i4
}
// RUN: circt-lec %t.mlir %s -c1=compress_6 -c2=compress_6 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMP6
// COMP6: c1 == c2
hw.module @compress_6(in %a : i4, in %b : i4, in %c : i4, in %d : i4, in %e : i4, in %f : i4, out sum : i4) {
%0:3 = datapath.compress %a, %b, %c, %d, %e, %f : i4 [6 -> 3]
%1 = comb.add bin %0#0, %0#1, %0#2 : i4
hw.output %1 : i4
}
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-partial-product-to-booth=true lower-compress-to-add=true}))" -o %t.mlir
// RUN: circt-lec %t.mlir %s -c1=partial_product_5 -c2=partial_product_5 --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH5
// BOOTH5: c1 == c2
// RUN: circt-lec %t.mlir %s -c1=partial_product_4 -c2=partial_product_4 --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH4
// BOOTH4: c1 == c2
// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMPADD3
// COMPADD3: c1 == c2
// RUN: circt-lec %t.mlir %s -c1=compress_6 -c2=compress_6 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMPADD6
// COMPADD6: c1 == c2

View File

@ -14,6 +14,7 @@ add_circt_public_c_api_library(CIRCTCAPIConversion
CIRCTCombToLLVM
CIRCTCombToSMT
CIRCTConvertToArcs
CIRCTDatapathToComb
CIRCTDatapathToSMT
CIRCTDCToHW
CIRCTExportChiselInterface

View File

@ -9,6 +9,7 @@ add_subdirectory(CombToDatapath)
add_subdirectory(CombToLLVM)
add_subdirectory(CombToSMT)
add_subdirectory(ConvertToArcs)
add_subdirectory(DatapathToComb)
add_subdirectory(DatapathToSMT)
add_subdirectory(DCToHW)
add_subdirectory(ExportAIGER)

View File

@ -668,21 +668,6 @@ struct CombSubOpConversion : OpConversionPattern<SubOp> {
}
};
// Construct a full adder for three 1-bit inputs.
std::pair<Value, Value> fullAdder(ConversionPatternRewriter &rewriter,
Location loc, Value a, Value b, Value c) {
auto aXorB = rewriter.createOrFold<comb::XorOp>(loc, a, b, true);
Value sum = rewriter.createOrFold<comb::XorOp>(loc, aXorB, c, true);
auto carry = rewriter.createOrFold<comb::OrOp>(
loc,
ArrayRef<Value>{rewriter.createOrFold<comb::AndOp>(loc, a, b, true),
rewriter.createOrFold<comb::AndOp>(loc, aXorB, c, true)},
true);
return {sum, carry};
}
struct CombMulOpConversion : OpConversionPattern<MulOp> {
using OpConversionPattern<MulOp>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
@ -729,74 +714,14 @@ struct CombMulOpConversion : OpConversionPattern<MulOp> {
return success();
}
// Wallace tree reduction
replaceOpAndCopyNamehint(
rewriter, op,
wallaceReduction(falseValue, width, rewriter, loc, partialProducts));
// Wallace tree reduction - reduce to two addends.
auto addends =
comb::wallaceReduction(rewriter, loc, width, 2, partialProducts);
// Sum the two addends using a carry-propagate adder
auto newAdd = rewriter.create<comb::AddOp>(loc, addends, true);
replaceOpAndCopyNamehint(rewriter, op, newAdd);
return success();
}
private:
// Perform Wallace tree reduction on partial products.
// See https://en.wikipedia.org/wiki/Wallace_tree
static Value
wallaceReduction(Value falseValue, size_t width,
ConversionPatternRewriter &rewriter, Location loc,
SmallVector<SmallVector<Value>> &partialProducts) {
SmallVector<SmallVector<Value>> newPartialProducts;
newPartialProducts.reserve(partialProducts.size());
// Continue reduction until we have only two rows. The length of
// `partialProducts` is reduced by 1/3 in each iteration.
while (partialProducts.size() > 2) {
newPartialProducts.clear();
// Take three rows at a time and reduce to two rows(sum and carry).
for (unsigned i = 0; i < partialProducts.size(); i += 3) {
if (i + 2 < partialProducts.size()) {
// We have three rows to reduce
auto &row1 = partialProducts[i];
auto &row2 = partialProducts[i + 1];
auto &row3 = partialProducts[i + 2];
assert(row1.size() == width && row2.size() == width &&
row3.size() == width);
SmallVector<Value> sumRow, carryRow;
sumRow.reserve(width);
carryRow.reserve(width);
carryRow.push_back(falseValue);
// Process each bit position
for (unsigned j = 0; j < width; ++j) {
// Full adder logic
auto [sum, carry] =
fullAdder(rewriter, loc, row1[j], row2[j], row3[j]);
sumRow.push_back(sum);
if (j + 1 < width)
carryRow.push_back(carry);
}
newPartialProducts.push_back(std::move(sumRow));
newPartialProducts.push_back(std::move(carryRow));
} else {
// Add remaining rows as is
newPartialProducts.append(partialProducts.begin() + i,
partialProducts.end());
}
}
std::swap(newPartialProducts, partialProducts);
}
assert(partialProducts.size() == 2);
// Reverse the order of the bits
std::reverse(partialProducts[0].begin(), partialProducts[0].end());
std::reverse(partialProducts[1].begin(), partialProducts[1].end());
auto lhs = rewriter.create<comb::ConcatOp>(loc, partialProducts[0]);
auto rhs = rewriter.create<comb::ConcatOp>(loc, partialProducts[1]);
// Use comb.add for the final addition.
return rewriter.create<comb::AddOp>(loc, ArrayRef<Value>{lhs, rhs}, true);
}
};
template <typename OpTy>

View File

@ -0,0 +1,15 @@
add_circt_conversion_library(CIRCTDatapathToComb
DatapathToComb.cpp
DEPENDS
CIRCTConversionPassIncGen
LINK_LIBS PUBLIC
CIRCTComb
CIRCTDatapath
CIRCTHW
MLIRIR
MLIRPass
MLIRSupport
MLIRTransforms
)

View File

@ -0,0 +1,274 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Conversion/DatapathToComb.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "datapath-to-comb"
namespace circt {
#define GEN_PASS_DEF_CONVERTDATAPATHTOCOMB
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
using namespace circt;
using namespace datapath;
// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
SmallVector<Value> bits;
comb::extractBits(builder, val, bits);
return bits;
}
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
namespace {
// Replace compressor by an adder of the inputs and zero for the other results:
// compress(a,b,c,d) -> {a+b+c+d, 0}
// Facilitates use of downstream compression algorithms e.g. Yosys
struct DatapathCompressOpAddConversion : OpConversionPattern<CompressOp> {
using OpConversionPattern<CompressOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto inputs = op.getOperands();
unsigned width = inputs[0].getType().getIntOrFloatBitWidth();
// Sum all the inputs - set that to result value 0
auto addOp = rewriter.create<comb::AddOp>(loc, inputs, true);
// Replace remaining results with zeros
auto zeroOp = rewriter.create<hw::ConstantOp>(loc, APInt(width, 0));
SmallVector<Value> results(op.getNumResults() - 1, zeroOp);
results.push_back(addOp);
rewriter.replaceOp(op, results);
return success();
}
};
// Replace compressor by a wallace tree of full-adders
struct DatapathCompressOpConversion : OpConversionPattern<CompressOp> {
using OpConversionPattern<CompressOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto inputs = op.getOperands();
unsigned width = inputs[0].getType().getIntOrFloatBitWidth();
SmallVector<SmallVector<Value>> addends;
for (auto input : inputs) {
addends.push_back(
extractBits(rewriter, input)); // Extract bits from each input
}
// Wallace tree reduction
// TODO - implement a more efficient compression algorithm to compete with
// yosys's `alumacc` lowering - a coarse grained timing model would help to
// sort the inputs according to arrival time.
auto targetAddends = op.getNumResults();
rewriter.replaceOp(op, comb::wallaceReduction(rewriter, loc, width,
targetAddends, addends));
return success();
}
};
struct DatapathPartialProductOpConversion
: OpConversionPattern<PartialProductOp> {
using OpConversionPattern<PartialProductOp>::OpConversionPattern;
DatapathPartialProductOpConversion(MLIRContext *context, bool forceBooth)
: OpConversionPattern<PartialProductOp>(context),
forceBooth(forceBooth){};
const bool forceBooth;
LogicalResult
matchAndRewrite(PartialProductOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value a = op.getLhs();
Value b = op.getRhs();
unsigned width = a.getType().getIntOrFloatBitWidth();
// Skip a zero width value.
if (width == 0) {
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(0), 0);
return success();
}
// Use width as a heuristic to guide partial product implementation
if (width > 16 || forceBooth)
return lowerBoothArray(rewriter, a, b, op, width);
else
return lowerAndArray(rewriter, a, b, op, width);
}
private:
static LogicalResult lowerAndArray(ConversionPatternRewriter &rewriter,
Value a, Value b, PartialProductOp op,
unsigned width) {
Location loc = op.getLoc();
// Keep a as a bitvector - multiply by each digit of b
SmallVector<Value> bBits = extractBits(rewriter, b);
SmallVector<Value> partialProducts;
partialProducts.reserve(width);
// AND Array Construction:
// partialProducts[i] = ({b[i],..., b[i]} & a) << i
assert(op.getNumResults() <= width &&
"Cannot return more results than the operator width");
for (unsigned i = 0; i < op.getNumResults(); ++i) {
auto repl = rewriter.create<comb::ReplicateOp>(loc, bBits[i], width);
auto ppRow = rewriter.create<comb::AndOp>(loc, repl, a);
auto shiftBy = rewriter.create<hw::ConstantOp>(loc, APInt(width, i));
auto ppAlign = rewriter.create<comb::ShlOp>(loc, ppRow, shiftBy);
partialProducts.push_back(ppAlign);
}
rewriter.replaceOp(op, partialProducts);
return success();
}
static LogicalResult lowerBoothArray(ConversionPatternRewriter &rewriter,
Value a, Value b, PartialProductOp op,
unsigned width) {
Location loc = op.getLoc();
auto zeroFalse = rewriter.create<hw::ConstantOp>(loc, APInt(1, 0));
auto zeroWidth = rewriter.create<hw::ConstantOp>(loc, APInt(width, 0));
auto oneWidth = rewriter.create<hw::ConstantOp>(loc, APInt(width, 1));
Value twoA = rewriter.create<comb::ShlOp>(loc, a, oneWidth);
SmallVector<Value> bBits = extractBits(rewriter, b);
SmallVector<Value> partialProducts;
partialProducts.reserve(width);
// Booth encoding halves array height by grouping three bits at a time:
// partialProducts[i] = a * (-2*b[2*i+1] + b[2*i] + b[2*i-1]) << 2*i
// encNeg \approx (-2*b[2*i+1] + b[2*i] + b[2*i-1]) <= 0
// encOne = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 1
// encTwo = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 2
Value encNegPrev;
// For even width - additional row contains the final sign correction
for (unsigned i = 0; i <= width; i += 2) {
// Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0)
Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1];
Value bi = (i < width) ? bBits[i] : zeroFalse;
Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse;
// Is the encoding zero or negative (an approximation)
Value encNeg = bip1;
// Is the encoding one = b[i] xor b[i-1]
Value encOne = rewriter.create<comb::XorOp>(loc, bi, bim1, true);
// Is the encoding two = (bip1 & ~bi & ~bim1) | (~bip1 & bi & bim1)
Value constOne = rewriter.create<hw::ConstantOp>(loc, APInt(1, 1));
Value biInv = rewriter.create<comb::XorOp>(loc, bi, constOne, true);
Value bip1Inv = rewriter.create<comb::XorOp>(loc, bip1, constOne, true);
Value bim1Inv = rewriter.create<comb::XorOp>(loc, bim1, constOne, true);
Value andLeft = rewriter.create<comb::AndOp>(
loc, ValueRange{bip1Inv, bi, bim1}, true);
Value andRight = rewriter.create<comb::AndOp>(
loc, ValueRange{bip1, biInv, bim1Inv}, true);
Value encTwo = rewriter.create<comb::OrOp>(loc, andLeft, andRight, true);
Value encNegRepl = rewriter.create<comb::ReplicateOp>(loc, encNeg, width);
Value encOneRepl = rewriter.create<comb::ReplicateOp>(loc, encOne, width);
Value encTwoRepl = rewriter.create<comb::ReplicateOp>(loc, encTwo, width);
// Select between 2*a or 1*a or 0*a
Value selTwoA = rewriter.create<comb::AndOp>(loc, encTwoRepl, twoA);
Value selOneA = rewriter.create<comb::AndOp>(loc, encOneRepl, a);
Value magA = rewriter.create<comb::OrOp>(loc, selTwoA, selOneA, true);
// Conditionally invert the row
Value ppRow = rewriter.create<comb::XorOp>(loc, magA, encNegRepl, true);
// No sign-correction in the first row
if (i == 0) {
partialProducts.push_back(ppRow);
encNegPrev = encNeg;
continue;
}
// Insert a sign-correction from the previous row
assert(i >= 2 && "Expected i to be at least 2 for sign correction");
// {ppRow, 0, encNegPrev} << 2*(i-1)
Value withSignCorrection = rewriter.create<comb::ConcatOp>(
loc, ValueRange{ppRow, zeroFalse, encNegPrev});
Value ppAlignPre =
rewriter.create<comb::ExtractOp>(loc, withSignCorrection, 0, width);
Value shiftBy = rewriter.create<hw::ConstantOp>(loc, APInt(width, i - 2));
Value ppAlign = rewriter.create<comb::ShlOp>(loc, ppAlignPre, shiftBy);
partialProducts.push_back(ppAlign);
encNegPrev = encNeg;
if (partialProducts.size() == op.getNumResults())
break;
}
// Zero-pad to match the required output width
while (partialProducts.size() < op.getNumResults())
partialProducts.push_back(zeroWidth);
assert(partialProducts.size() == op.getNumResults() &&
"Expected number of booth partial products to match results");
rewriter.replaceOp(op, partialProducts);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert Datapath to Comb pass
//===----------------------------------------------------------------------===//
namespace {
struct ConvertDatapathToCombPass
: public impl::ConvertDatapathToCombBase<ConvertDatapathToCombPass> {
void runOnOperation() override;
using ConvertDatapathToCombBase<
ConvertDatapathToCombPass>::ConvertDatapathToCombBase;
};
} // namespace
void ConvertDatapathToCombPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
target.addIllegalDialect<DatapathDialect>();
RewritePatternSet patterns(&getContext());
patterns.add<DatapathPartialProductOpConversion>(patterns.getContext(),
forceBooth);
if (lowerCompressToAdd)
// Lower compressors to simple add operations for downstream optimisations
patterns.add<DatapathCompressOpAddConversion>(patterns.getContext());
else
// Lower compressors to a complete gate-level implementation
patterns.add<DatapathCompressOpConversion>(patterns.getContext());
if (failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}

View File

@ -216,6 +216,86 @@ Value comb::createInject(OpBuilder &builder, Location loc, Value value,
return builder.createOrFold<comb::ConcatOp>(loc, fragments);
}
// Construct a full adder for three 1-bit inputs.
std::pair<Value, Value> comb::fullAdder(OpBuilder &builder, Location loc,
Value a, Value b, Value c) {
auto aXorB = builder.createOrFold<comb::XorOp>(loc, a, b, true);
Value sum = builder.createOrFold<comb::XorOp>(loc, aXorB, c, true);
auto carry = builder.createOrFold<comb::OrOp>(
loc,
ArrayRef<Value>{builder.createOrFold<comb::AndOp>(loc, a, b, true),
builder.createOrFold<comb::AndOp>(loc, aXorB, c, true)},
true);
return {sum, carry};
}
// Perform Wallace tree reduction on partial products.
// See https://en.wikipedia.org/wiki/Wallace_tree
SmallVector<Value>
comb::wallaceReduction(OpBuilder &builder, Location loc, size_t width,
size_t targetAddends,
SmallVector<SmallVector<Value>> &addends) {
auto falseValue = builder.create<hw::ConstantOp>(loc, APInt(1, 0));
SmallVector<SmallVector<Value>> newAddends;
newAddends.reserve(addends.size());
// Continue reduction until we have only two rows. The length of
// `addends` is reduced by 1/3 in each iteration.
while (addends.size() > targetAddends) {
newAddends.clear();
// Take three rows at a time and reduce to two rows(sum and carry).
for (unsigned i = 0; i < addends.size(); i += 3) {
if (i + 2 < addends.size()) {
// We have three rows to reduce
auto &row1 = addends[i];
auto &row2 = addends[i + 1];
auto &row3 = addends[i + 2];
assert(row1.size() == width && row2.size() == width &&
row3.size() == width);
SmallVector<Value> sumRow, carryRow;
sumRow.reserve(width);
carryRow.reserve(width);
carryRow.push_back(falseValue);
// Process each bit position
for (unsigned j = 0; j < width; ++j) {
// Full adder logic
auto [sum, carry] =
comb::fullAdder(builder, loc, row1[j], row2[j], row3[j]);
sumRow.push_back(sum);
if (j + 1 < width)
carryRow.push_back(carry);
}
newAddends.push_back(std::move(sumRow));
newAddends.push_back(std::move(carryRow));
} else {
// Add remaining rows as is
newAddends.append(addends.begin() + i, addends.end());
}
}
std::swap(newAddends, addends);
}
assert(addends.size() <= targetAddends);
SmallVector<Value> carrySave;
for (auto &addend : addends) {
// Reverse the order of the bits
std::reverse(addend.begin(), addend.end());
carrySave.push_back(builder.create<comb::ConcatOp>(loc, addend));
}
// Pad with zeros
auto zero = builder.create<hw::ConstantOp>(loc, APInt(width, 0));
while (carrySave.size() < targetAddends)
carrySave.push_back(zero);
return carrySave;
}
//===----------------------------------------------------------------------===//
// ICmpOp
//===----------------------------------------------------------------------===//

View File

@ -67,7 +67,6 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
// ALLOW_ADD-NEXT: %[[RHS0:.+]] = comb.extract %rhs from 0 : (i3) -> i1
// ALLOW_ADD-NEXT: %[[RHS1:.+]] = comb.extract %rhs from 1 : (i3) -> i1
// ALLOW_ADD-NEXT: %[[RHS2:.+]] = comb.extract %rhs from 2 : (i3) -> i1
// ALLOW_ADD-NEXT: %false = hw.constant false
// Partial Products
// ALLOW_ADD-NEXT: %[[P_0_0:.+]] = comb.and %[[LHS0]], %[[RHS0]] : i1
// ALLOW_ADD-NEXT: %[[P_1_0:.+]] = comb.and %[[LHS1]], %[[RHS0]] : i1
@ -76,6 +75,7 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
// ALLOW_ADD-NEXT: %[[P_1_1:.+]] = comb.and %[[LHS1]], %[[RHS1]] : i1
// ALLOW_ADD-NEXT: %[[P_2_1:.+]] = comb.and %[[LHS0]], %[[RHS2]] : i1
// Wallace Tree Reduction
// ALLOW_ADD-NEXT: %false = hw.constant false
// ALLOW_ADD-NEXT: %[[XOR0:.+]] = comb.xor bin %[[P_1_0]], %[[P_0_1]] : i1
// ALLOW_ADD-NEXT: %[[AND0:.+]] = comb.and bin %[[P_1_0]], %[[P_0_1]] : i1
// ALLOW_ADD-NEXT: %[[XOR1:.+]] = comb.xor bin %[[P_2_0]], %[[P_1_1]] : i1

View File

@ -0,0 +1,106 @@
// RUN: circt-opt %s --convert-datapath-to-comb | FileCheck %s
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-compress-to-add=true}))" | FileCheck %s --check-prefix=TO-ADD
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-partial-product-to-booth=true}, canonicalize))" | FileCheck %s --check-prefix=FORCE-BOOTH
// CHECK-LABEL: @compressor
hw.module @compressor(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) {
//CHECK-NEXT: %[[A0:.+]] = comb.extract %a from 0 : (i2) -> i1
//CHECK-NEXT: %[[A1:.+]] = comb.extract %a from 1 : (i2) -> i1
//CHECK-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i2) -> i1
//CHECK-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i2) -> i1
//CHECK-NEXT: %[[C0:.+]] = comb.extract %c from 0 : (i2) -> i1
//CHECK-NEXT: %[[C1:.+]] = comb.extract %c from 1 : (i2) -> i1
//CHECK-NEXT: %false = hw.constant false
//CHECK-NEXT: %[[AxB0:.+]] = comb.xor bin %[[A0]], %[[B0]] : i1
//CHECK-NEXT: %[[AxBxC0:.+]] = comb.xor bin %[[AxB0]], %[[C0]] : i1
//CHECK-NEXT: %[[AB0:.+]] = comb.and bin %[[A0]], %[[B0]] : i1
//CHECK-NEXT: %[[AxBC0:.+]] = comb.and bin %[[AxB0]], %[[C0]] : i1
//CHECK-NEXT: %[[AB0oAxBC0:.+]] = comb.or bin %[[AB0]], %[[AxBC0]] : i1
//CHECK-NEXT: %[[AxB1:.+]] = comb.xor bin %[[A1]], %[[B1]] : i1
//CHECK-NEXT: %[[AxBxC1:.+]] = comb.xor bin %[[AxB1]], %[[C1]] : i1
//CHECK-NEXT: %[[AB1:.+]] = comb.and bin %[[A1]], %[[B1]] : i1
//CHECK-NEXT: %[[AxBC1:.+]] = comb.and bin %[[AxB1]], %[[C1]] : i1
//CHECK-NEXT: comb.or bin %[[AB1]], %[[AxBC1]] : i1
//CHECK-NEXT: comb.concat %[[AxBxC1]], %[[AxBxC0]] : i1, i1
//CHECK-NEXT: comb.concat %[[AB0oAxBC0]], %false : i1, i1
%0:2 = datapath.compress %a, %b, %c : i2 [3 -> 2]
hw.output %0#0, %0#1 : i2, i2
}
// CHECK-LABEL: @compressor_add
// TO-ADD-LABEL: @compressor_add
// TO-ADD-NEXT: %[[ADD:.+]] = comb.add bin %a, %b, %c : i2
// TO-ADD-NEXT: %c0_i2 = hw.constant 0 : i2
// TO-ADD-NEXT: hw.output %c0_i2, %[[ADD]] : i2, i2
hw.module @compressor_add(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) {
%0:2 = datapath.compress %a, %b, %c : i2 [3 -> 2]
hw.output %0#0, %0#1 : i2, i2
}
// CHECK-LABEL: @partial_product
hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
// CHECK-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i3) -> i1
// CHECK-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i3) -> i1
// CHECK-NEXT: %[[B2:.+]] = comb.extract %b from 2 : (i3) -> i1
// CHECK-NEXT: %[[B0R:.+]] = comb.replicate %[[B0]] : (i1) -> i3
// CHECK-NEXT: %[[PP0:.+]] = comb.and %[[B0R]], %a : i3
// CHECK-NEXT: %c0_i3 = hw.constant 0 : i3
// CHECK-NEXT: comb.shl %[[PP0]], %c0_i3 : i3
// CHECK-NEXT: %[[B1R:.+]] = comb.replicate %[[B1]] : (i1) -> i3
// CHECK-NEXT: %[[PP1:.+]] = comb.and %[[B1R]], %a : i3
// CHECK-NEXT: %c1_i3 = hw.constant 1 : i3
// CHECK-NEXT: comb.shl %[[PP1]], %c1_i3 : i3
// CHECK-NEXT: %[[B2R:.+]] = comb.replicate %[[B2]] : (i1) -> i3
// CHECK-NEXT: %[[PP2:.+]] = comb.and %[[B2R]], %a : i3
// CHECK-NEXT: %c2_i3 = hw.constant 2 : i3
// CHECK-NEXT: comb.shl %[[PP2]], %c2_i3 : i3
%0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3)
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
}
// CHECK-LABEL: @partial_product_booth
// FORCE-BOOTH-LABEL: @partial_product_booth
// Constants
// FORCE-BOOTH-NEXT: %true = hw.constant true
// FORCE-BOOTH-NEXT: %false = hw.constant false
// FORCE-BOOTH-NEXT: %c0_i3 = hw.constant 0 : i3
// 2*a
// FORCE-BOOTH-NEXT: %0 = comb.extract %a from 0 : (i3) -> i2
// FORCE-BOOTH-NEXT: %[[TWOA:.+]] = comb.concat %0, %false : i2, i1
// FORCE-BOOTH-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i3) -> i1
// FORCE-BOOTH-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i3) -> i1
// FORCE-BOOTH-NEXT: %[[B2:.+]] = comb.extract %b from 2 : (i3) -> i1
// PP0
// FORCE-BOOTH-NEXT: %[[NB0:.+]] = comb.xor bin %[[B0]], %true : i1
// FORCE-BOOTH-NEXT: %[[TWO0:.+]] = comb.and %[[B1]], %[[NB0]] : i1
// FORCE-BOOTH-NEXT: %[[PPOSGN:.+]] = comb.replicate %[[B1]] : (i1) -> i3
// FORCE-BOOTH-NEXT: %[[ONER:.+]] = comb.replicate %[[B0]] : (i1) -> i3
// FORCE-BOOTH-NEXT: %[[TWO0R:.+]] = comb.replicate %[[TWO0]] : (i1) -> i3
// FORCE-BOOTH-NEXT: %[[PP0TWOA:.+]] = comb.and %[[TWO0R]], %[[TWOA]] : i3
// FORCE-BOOTH-NEXT: %[[PP0ONEA:.+]] = comb.and %[[ONER]], %a : i3
// FORCE-BOOTH-NEXT: %[[PP0MAG:.+]] = comb.or bin %[[PP0TWOA]], %[[PP0ONEA]] : i3
// FORCE-BOOTH-NEXT: %[[PP0:.+]] = comb.xor bin %[[PP0MAG]], %[[PPOSGN]] : i3
// PP1
// FORCE-BOOTH-NEXT: %[[B2XORB1:.+]] = comb.xor bin %4, %3 : i1
// FORCE-BOOTH-NEXT: %[[A0:.+]] = comb.extract %a from 0 : (i3) -> i1
// FORCE-BOOTH-NEXT: %[[PP1MSB:.+]] = comb.and %[[B2XORB1]], %[[A0]] : i1
// FORCE-BOOTH-NEXT: %[[PP1:.+]] = comb.concat %[[PP1MSB]], %false, %[[B1]] : i1, i1, i1
// FORCE-BOOTH-NEXT: hw.output %[[PP0]], %[[PP1]], %c0_i3 : i3, i3, i3
hw.module @partial_product_booth(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
%0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3)
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
}
// CHECK-LABEL: @partial_product_24
hw.module @partial_product_24(in %a : i24, in %b : i24, out sum : i24) {
%0:24 = datapath.partial_product %a, %b : (i24, i24) -> (i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24)
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15, %0#16, %0#17, %0#18, %0#19, %0#20, %0#21, %0#22, %0#23 : i24
hw.output %1 : i24
}
// CHECK-LABEL: @partial_product_25
hw.module @partial_product_25(in %a : i25, in %b : i25, out sum : i25) {
%0:25 = datapath.partial_product %a, %b : (i25, i25) -> (i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25)
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15, %0#16, %0#17, %0#18, %0#19, %0#20, %0#21, %0#22, %0#23 : i25
hw.output %1 : i25
}