mirror of https://github.com/llvm/circt.git
[Datapath] Create Datapath to Comb Pass (#8736)
* Initiate datapath to comb pass * Add tests and tidy datapath to compress implementation * Improve comments * Formatting and test corrections * Correct CAPI * Move wallace tree reduction and full-adder to comb ops * Adding integration tests using circt-lec and correcting review comments * Fix bug in Booth code for final sign correction row and add testing using lec. Add a forceBooth option largely for testing purposes * Minor fix * Formatting * Removing populate patterns function * Formatting
This commit is contained in:
parent
a681fefd5e
commit
30a4a38fca
|
@ -0,0 +1,21 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CIRCT_CONVERSION_DATAPATHTOCOMB_H
|
||||
#define CIRCT_CONVERSION_DATAPATHTOCOMB_H
|
||||
|
||||
#include "circt/Support/LLVM.h"
|
||||
|
||||
namespace circt {
|
||||
|
||||
#define GEN_PASS_DECL_CONVERTDATAPATHTOCOMB
|
||||
#include "circt/Conversion/Passes.h.inc"
|
||||
|
||||
} // namespace circt
|
||||
|
||||
#endif // CIRCT_CONVERSION_DATAPATHTOCOMB_H
|
|
@ -26,6 +26,7 @@
|
|||
#include "circt/Conversion/CombToSMT.h"
|
||||
#include "circt/Conversion/ConvertToArcs.h"
|
||||
#include "circt/Conversion/DCToHW.h"
|
||||
#include "circt/Conversion/DatapathToComb.h"
|
||||
#include "circt/Conversion/DatapathToSMT.h"
|
||||
#include "circt/Conversion/ExportChiselInterface.h"
|
||||
#include "circt/Conversion/ExportVerilog.h"
|
||||
|
|
|
@ -871,4 +871,22 @@ def ConvertDatapathToSMT : Pass<"convert-datapath-to-smt"> {
|
|||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConvertDatapathToComb
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertDatapathToComb : Pass<"convert-datapath-to-comb"> {
|
||||
let summary = "Convert Datapath ops to Comb ops";
|
||||
let dependentDialects = [
|
||||
"circt::comb::CombDialect", "circt::datapath::DatapathDialect",
|
||||
"circt::hw::HWDialect"
|
||||
];
|
||||
let options = [
|
||||
Option<"lowerCompressToAdd", "lower-compress-to-add", "bool", "false",
|
||||
"Lower compress operators to variadic add.">,
|
||||
Option<"forceBooth", "lower-partial-product-to-booth", "bool", "false",
|
||||
"Force all partial products to be lowered to Booth arrays.">
|
||||
];
|
||||
}
|
||||
|
||||
#endif // CIRCT_CONVERSION_PASSES_TD
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace llvm {
|
||||
struct KnownBits;
|
||||
|
@ -85,6 +86,18 @@ Value createDynamicInject(OpBuilder &builder, Location loc, Value value,
|
|||
Value createInject(OpBuilder &builder, Location loc, Value value,
|
||||
unsigned offset, Value replacement);
|
||||
|
||||
/// Construct a full adder for three 1-bit inputs.
|
||||
std::pair<Value, Value> fullAdder(OpBuilder &builder, Location loc, Value a,
|
||||
Value b, Value c);
|
||||
|
||||
/// Perform Wallace tree reduction on partial products.
|
||||
/// See https://en.wikipedia.org/wiki/Wallace_tree
|
||||
/// \param targetAddends The number of addends to reduce to (2 for carry-save).
|
||||
/// \param addends The rows of bits to be summed (modified in place).
|
||||
SmallVector<Value> wallaceReduction(OpBuilder &builder, Location loc,
|
||||
size_t width, size_t targetAddends,
|
||||
SmallVector<SmallVector<Value>> &addends);
|
||||
|
||||
} // namespace comb
|
||||
} // namespace circt
|
||||
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
// REQUIRES: libz3
|
||||
// REQUIRES: circt-lec-jit
|
||||
|
||||
// RUN: circt-opt %s --convert-datapath-to-comb -o %t.mlir
|
||||
// RUN: circt-lec %t.mlir %s -c1=partial_product_5 -c2=partial_product_5 --shared-libs=%libz3 | FileCheck %s --check-prefix=AND5
|
||||
// AND5: c1 == c2
|
||||
hw.module @partial_product_5(in %a : i5, in %b : i5, out sum : i5) {
|
||||
%0:5 = datapath.partial_product %a, %b : (i5, i5) -> (i5, i5, i5, i5, i5)
|
||||
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4 : i5
|
||||
hw.output %1 : i5
|
||||
}
|
||||
|
||||
// RUN: circt-lec %t.mlir %s -c1=partial_product_4 -c2=partial_product_4 --shared-libs=%libz3 | FileCheck %s --check-prefix=AND4
|
||||
// AND4: c1 == c2
|
||||
hw.module @partial_product_4(in %a : i4, in %b : i4, out sum : i4) {
|
||||
%0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
|
||||
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3 : i4
|
||||
hw.output %1 : i4
|
||||
}
|
||||
|
||||
// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMP3
|
||||
// COMP3: c1 == c2
|
||||
hw.module @compress_3(in %a : i4, in %b : i4, in %c : i4, out sum : i4) {
|
||||
%0:2 = datapath.compress %a, %b, %c : i4 [3 -> 2]
|
||||
%1 = comb.add bin %0#0, %0#1 : i4
|
||||
hw.output %1 : i4
|
||||
}
|
||||
|
||||
// RUN: circt-lec %t.mlir %s -c1=compress_6 -c2=compress_6 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMP6
|
||||
// COMP6: c1 == c2
|
||||
hw.module @compress_6(in %a : i4, in %b : i4, in %c : i4, in %d : i4, in %e : i4, in %f : i4, out sum : i4) {
|
||||
%0:3 = datapath.compress %a, %b, %c, %d, %e, %f : i4 [6 -> 3]
|
||||
%1 = comb.add bin %0#0, %0#1, %0#2 : i4
|
||||
hw.output %1 : i4
|
||||
}
|
||||
|
||||
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-partial-product-to-booth=true lower-compress-to-add=true}))" -o %t.mlir
|
||||
// RUN: circt-lec %t.mlir %s -c1=partial_product_5 -c2=partial_product_5 --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH5
|
||||
// BOOTH5: c1 == c2
|
||||
|
||||
// RUN: circt-lec %t.mlir %s -c1=partial_product_4 -c2=partial_product_4 --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH4
|
||||
// BOOTH4: c1 == c2
|
||||
|
||||
// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMPADD3
|
||||
// COMPADD3: c1 == c2
|
||||
|
||||
// RUN: circt-lec %t.mlir %s -c1=compress_6 -c2=compress_6 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMPADD6
|
||||
// COMPADD6: c1 == c2
|
|
@ -14,6 +14,7 @@ add_circt_public_c_api_library(CIRCTCAPIConversion
|
|||
CIRCTCombToLLVM
|
||||
CIRCTCombToSMT
|
||||
CIRCTConvertToArcs
|
||||
CIRCTDatapathToComb
|
||||
CIRCTDatapathToSMT
|
||||
CIRCTDCToHW
|
||||
CIRCTExportChiselInterface
|
||||
|
|
|
@ -9,6 +9,7 @@ add_subdirectory(CombToDatapath)
|
|||
add_subdirectory(CombToLLVM)
|
||||
add_subdirectory(CombToSMT)
|
||||
add_subdirectory(ConvertToArcs)
|
||||
add_subdirectory(DatapathToComb)
|
||||
add_subdirectory(DatapathToSMT)
|
||||
add_subdirectory(DCToHW)
|
||||
add_subdirectory(ExportAIGER)
|
||||
|
|
|
@ -668,21 +668,6 @@ struct CombSubOpConversion : OpConversionPattern<SubOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// Construct a full adder for three 1-bit inputs.
|
||||
std::pair<Value, Value> fullAdder(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value a, Value b, Value c) {
|
||||
auto aXorB = rewriter.createOrFold<comb::XorOp>(loc, a, b, true);
|
||||
Value sum = rewriter.createOrFold<comb::XorOp>(loc, aXorB, c, true);
|
||||
|
||||
auto carry = rewriter.createOrFold<comb::OrOp>(
|
||||
loc,
|
||||
ArrayRef<Value>{rewriter.createOrFold<comb::AndOp>(loc, a, b, true),
|
||||
rewriter.createOrFold<comb::AndOp>(loc, aXorB, c, true)},
|
||||
true);
|
||||
|
||||
return {sum, carry};
|
||||
}
|
||||
|
||||
struct CombMulOpConversion : OpConversionPattern<MulOp> {
|
||||
using OpConversionPattern<MulOp>::OpConversionPattern;
|
||||
using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
|
||||
|
@ -729,74 +714,14 @@ struct CombMulOpConversion : OpConversionPattern<MulOp> {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Wallace tree reduction
|
||||
replaceOpAndCopyNamehint(
|
||||
rewriter, op,
|
||||
wallaceReduction(falseValue, width, rewriter, loc, partialProducts));
|
||||
// Wallace tree reduction - reduce to two addends.
|
||||
auto addends =
|
||||
comb::wallaceReduction(rewriter, loc, width, 2, partialProducts);
|
||||
// Sum the two addends using a carry-propagate adder
|
||||
auto newAdd = rewriter.create<comb::AddOp>(loc, addends, true);
|
||||
replaceOpAndCopyNamehint(rewriter, op, newAdd);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Perform Wallace tree reduction on partial products.
|
||||
// See https://en.wikipedia.org/wiki/Wallace_tree
|
||||
static Value
|
||||
wallaceReduction(Value falseValue, size_t width,
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
SmallVector<SmallVector<Value>> &partialProducts) {
|
||||
SmallVector<SmallVector<Value>> newPartialProducts;
|
||||
newPartialProducts.reserve(partialProducts.size());
|
||||
// Continue reduction until we have only two rows. The length of
|
||||
// `partialProducts` is reduced by 1/3 in each iteration.
|
||||
while (partialProducts.size() > 2) {
|
||||
newPartialProducts.clear();
|
||||
// Take three rows at a time and reduce to two rows(sum and carry).
|
||||
for (unsigned i = 0; i < partialProducts.size(); i += 3) {
|
||||
if (i + 2 < partialProducts.size()) {
|
||||
// We have three rows to reduce
|
||||
auto &row1 = partialProducts[i];
|
||||
auto &row2 = partialProducts[i + 1];
|
||||
auto &row3 = partialProducts[i + 2];
|
||||
|
||||
assert(row1.size() == width && row2.size() == width &&
|
||||
row3.size() == width);
|
||||
|
||||
SmallVector<Value> sumRow, carryRow;
|
||||
sumRow.reserve(width);
|
||||
carryRow.reserve(width);
|
||||
carryRow.push_back(falseValue);
|
||||
|
||||
// Process each bit position
|
||||
for (unsigned j = 0; j < width; ++j) {
|
||||
// Full adder logic
|
||||
auto [sum, carry] =
|
||||
fullAdder(rewriter, loc, row1[j], row2[j], row3[j]);
|
||||
sumRow.push_back(sum);
|
||||
if (j + 1 < width)
|
||||
carryRow.push_back(carry);
|
||||
}
|
||||
|
||||
newPartialProducts.push_back(std::move(sumRow));
|
||||
newPartialProducts.push_back(std::move(carryRow));
|
||||
} else {
|
||||
// Add remaining rows as is
|
||||
newPartialProducts.append(partialProducts.begin() + i,
|
||||
partialProducts.end());
|
||||
}
|
||||
}
|
||||
|
||||
std::swap(newPartialProducts, partialProducts);
|
||||
}
|
||||
|
||||
assert(partialProducts.size() == 2);
|
||||
// Reverse the order of the bits
|
||||
std::reverse(partialProducts[0].begin(), partialProducts[0].end());
|
||||
std::reverse(partialProducts[1].begin(), partialProducts[1].end());
|
||||
auto lhs = rewriter.create<comb::ConcatOp>(loc, partialProducts[0]);
|
||||
auto rhs = rewriter.create<comb::ConcatOp>(loc, partialProducts[1]);
|
||||
|
||||
// Use comb.add for the final addition.
|
||||
return rewriter.create<comb::AddOp>(loc, ArrayRef<Value>{lhs, rhs}, true);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
add_circt_conversion_library(CIRCTDatapathToComb
|
||||
DatapathToComb.cpp
|
||||
|
||||
DEPENDS
|
||||
CIRCTConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
CIRCTComb
|
||||
CIRCTDatapath
|
||||
CIRCTHW
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRSupport
|
||||
MLIRTransforms
|
||||
)
|
|
@ -0,0 +1,274 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "circt/Conversion/DatapathToComb.h"
|
||||
#include "circt/Dialect/Comb/CombOps.h"
|
||||
#include "circt/Dialect/Datapath/DatapathOps.h"
|
||||
#include "circt/Dialect/HW/HWOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "datapath-to-comb"
|
||||
|
||||
namespace circt {
|
||||
#define GEN_PASS_DEF_CONVERTDATAPATHTOCOMB
|
||||
#include "circt/Conversion/Passes.h.inc"
|
||||
} // namespace circt
|
||||
|
||||
using namespace circt;
|
||||
using namespace datapath;
|
||||
|
||||
// A wrapper for comb::extractBits that returns a SmallVector<Value>.
|
||||
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
|
||||
SmallVector<Value> bits;
|
||||
comb::extractBits(builder, val, bits);
|
||||
return bits;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// Replace compressor by an adder of the inputs and zero for the other results:
|
||||
// compress(a,b,c,d) -> {a+b+c+d, 0}
|
||||
// Facilitates use of downstream compression algorithms e.g. Yosys
|
||||
struct DatapathCompressOpAddConversion : OpConversionPattern<CompressOp> {
|
||||
using OpConversionPattern<CompressOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto inputs = op.getOperands();
|
||||
unsigned width = inputs[0].getType().getIntOrFloatBitWidth();
|
||||
// Sum all the inputs - set that to result value 0
|
||||
auto addOp = rewriter.create<comb::AddOp>(loc, inputs, true);
|
||||
// Replace remaining results with zeros
|
||||
auto zeroOp = rewriter.create<hw::ConstantOp>(loc, APInt(width, 0));
|
||||
SmallVector<Value> results(op.getNumResults() - 1, zeroOp);
|
||||
results.push_back(addOp);
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Replace compressor by a wallace tree of full-adders
|
||||
struct DatapathCompressOpConversion : OpConversionPattern<CompressOp> {
|
||||
using OpConversionPattern<CompressOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto inputs = op.getOperands();
|
||||
unsigned width = inputs[0].getType().getIntOrFloatBitWidth();
|
||||
|
||||
SmallVector<SmallVector<Value>> addends;
|
||||
for (auto input : inputs) {
|
||||
addends.push_back(
|
||||
extractBits(rewriter, input)); // Extract bits from each input
|
||||
}
|
||||
|
||||
// Wallace tree reduction
|
||||
// TODO - implement a more efficient compression algorithm to compete with
|
||||
// yosys's `alumacc` lowering - a coarse grained timing model would help to
|
||||
// sort the inputs according to arrival time.
|
||||
auto targetAddends = op.getNumResults();
|
||||
rewriter.replaceOp(op, comb::wallaceReduction(rewriter, loc, width,
|
||||
targetAddends, addends));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct DatapathPartialProductOpConversion
|
||||
: OpConversionPattern<PartialProductOp> {
|
||||
using OpConversionPattern<PartialProductOp>::OpConversionPattern;
|
||||
|
||||
DatapathPartialProductOpConversion(MLIRContext *context, bool forceBooth)
|
||||
: OpConversionPattern<PartialProductOp>(context),
|
||||
forceBooth(forceBooth){};
|
||||
|
||||
const bool forceBooth;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(PartialProductOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Value a = op.getLhs();
|
||||
Value b = op.getRhs();
|
||||
unsigned width = a.getType().getIntOrFloatBitWidth();
|
||||
|
||||
// Skip a zero width value.
|
||||
if (width == 0) {
|
||||
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(0), 0);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Use width as a heuristic to guide partial product implementation
|
||||
if (width > 16 || forceBooth)
|
||||
return lowerBoothArray(rewriter, a, b, op, width);
|
||||
else
|
||||
return lowerAndArray(rewriter, a, b, op, width);
|
||||
}
|
||||
|
||||
private:
|
||||
static LogicalResult lowerAndArray(ConversionPatternRewriter &rewriter,
|
||||
Value a, Value b, PartialProductOp op,
|
||||
unsigned width) {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
// Keep a as a bitvector - multiply by each digit of b
|
||||
SmallVector<Value> bBits = extractBits(rewriter, b);
|
||||
|
||||
SmallVector<Value> partialProducts;
|
||||
partialProducts.reserve(width);
|
||||
// AND Array Construction:
|
||||
// partialProducts[i] = ({b[i],..., b[i]} & a) << i
|
||||
assert(op.getNumResults() <= width &&
|
||||
"Cannot return more results than the operator width");
|
||||
|
||||
for (unsigned i = 0; i < op.getNumResults(); ++i) {
|
||||
auto repl = rewriter.create<comb::ReplicateOp>(loc, bBits[i], width);
|
||||
auto ppRow = rewriter.create<comb::AndOp>(loc, repl, a);
|
||||
auto shiftBy = rewriter.create<hw::ConstantOp>(loc, APInt(width, i));
|
||||
auto ppAlign = rewriter.create<comb::ShlOp>(loc, ppRow, shiftBy);
|
||||
partialProducts.push_back(ppAlign);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, partialProducts);
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult lowerBoothArray(ConversionPatternRewriter &rewriter,
|
||||
Value a, Value b, PartialProductOp op,
|
||||
unsigned width) {
|
||||
Location loc = op.getLoc();
|
||||
auto zeroFalse = rewriter.create<hw::ConstantOp>(loc, APInt(1, 0));
|
||||
auto zeroWidth = rewriter.create<hw::ConstantOp>(loc, APInt(width, 0));
|
||||
auto oneWidth = rewriter.create<hw::ConstantOp>(loc, APInt(width, 1));
|
||||
Value twoA = rewriter.create<comb::ShlOp>(loc, a, oneWidth);
|
||||
|
||||
SmallVector<Value> bBits = extractBits(rewriter, b);
|
||||
|
||||
SmallVector<Value> partialProducts;
|
||||
partialProducts.reserve(width);
|
||||
|
||||
// Booth encoding halves array height by grouping three bits at a time:
|
||||
// partialProducts[i] = a * (-2*b[2*i+1] + b[2*i] + b[2*i-1]) << 2*i
|
||||
// encNeg \approx (-2*b[2*i+1] + b[2*i] + b[2*i-1]) <= 0
|
||||
// encOne = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 1
|
||||
// encTwo = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 2
|
||||
Value encNegPrev;
|
||||
|
||||
// For even width - additional row contains the final sign correction
|
||||
for (unsigned i = 0; i <= width; i += 2) {
|
||||
// Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0)
|
||||
Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1];
|
||||
Value bi = (i < width) ? bBits[i] : zeroFalse;
|
||||
Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse;
|
||||
|
||||
// Is the encoding zero or negative (an approximation)
|
||||
Value encNeg = bip1;
|
||||
// Is the encoding one = b[i] xor b[i-1]
|
||||
Value encOne = rewriter.create<comb::XorOp>(loc, bi, bim1, true);
|
||||
// Is the encoding two = (bip1 & ~bi & ~bim1) | (~bip1 & bi & bim1)
|
||||
Value constOne = rewriter.create<hw::ConstantOp>(loc, APInt(1, 1));
|
||||
Value biInv = rewriter.create<comb::XorOp>(loc, bi, constOne, true);
|
||||
Value bip1Inv = rewriter.create<comb::XorOp>(loc, bip1, constOne, true);
|
||||
Value bim1Inv = rewriter.create<comb::XorOp>(loc, bim1, constOne, true);
|
||||
|
||||
Value andLeft = rewriter.create<comb::AndOp>(
|
||||
loc, ValueRange{bip1Inv, bi, bim1}, true);
|
||||
Value andRight = rewriter.create<comb::AndOp>(
|
||||
loc, ValueRange{bip1, biInv, bim1Inv}, true);
|
||||
Value encTwo = rewriter.create<comb::OrOp>(loc, andLeft, andRight, true);
|
||||
|
||||
Value encNegRepl = rewriter.create<comb::ReplicateOp>(loc, encNeg, width);
|
||||
Value encOneRepl = rewriter.create<comb::ReplicateOp>(loc, encOne, width);
|
||||
Value encTwoRepl = rewriter.create<comb::ReplicateOp>(loc, encTwo, width);
|
||||
|
||||
// Select between 2*a or 1*a or 0*a
|
||||
Value selTwoA = rewriter.create<comb::AndOp>(loc, encTwoRepl, twoA);
|
||||
Value selOneA = rewriter.create<comb::AndOp>(loc, encOneRepl, a);
|
||||
Value magA = rewriter.create<comb::OrOp>(loc, selTwoA, selOneA, true);
|
||||
|
||||
// Conditionally invert the row
|
||||
Value ppRow = rewriter.create<comb::XorOp>(loc, magA, encNegRepl, true);
|
||||
|
||||
// No sign-correction in the first row
|
||||
if (i == 0) {
|
||||
partialProducts.push_back(ppRow);
|
||||
encNegPrev = encNeg;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Insert a sign-correction from the previous row
|
||||
assert(i >= 2 && "Expected i to be at least 2 for sign correction");
|
||||
// {ppRow, 0, encNegPrev} << 2*(i-1)
|
||||
Value withSignCorrection = rewriter.create<comb::ConcatOp>(
|
||||
loc, ValueRange{ppRow, zeroFalse, encNegPrev});
|
||||
Value ppAlignPre =
|
||||
rewriter.create<comb::ExtractOp>(loc, withSignCorrection, 0, width);
|
||||
Value shiftBy = rewriter.create<hw::ConstantOp>(loc, APInt(width, i - 2));
|
||||
Value ppAlign = rewriter.create<comb::ShlOp>(loc, ppAlignPre, shiftBy);
|
||||
partialProducts.push_back(ppAlign);
|
||||
encNegPrev = encNeg;
|
||||
|
||||
if (partialProducts.size() == op.getNumResults())
|
||||
break;
|
||||
}
|
||||
|
||||
// Zero-pad to match the required output width
|
||||
while (partialProducts.size() < op.getNumResults())
|
||||
partialProducts.push_back(zeroWidth);
|
||||
|
||||
assert(partialProducts.size() == op.getNumResults() &&
|
||||
"Expected number of booth partial products to match results");
|
||||
|
||||
rewriter.replaceOp(op, partialProducts);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert Datapath to Comb pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct ConvertDatapathToCombPass
|
||||
: public impl::ConvertDatapathToCombBase<ConvertDatapathToCombPass> {
|
||||
void runOnOperation() override;
|
||||
using ConvertDatapathToCombBase<
|
||||
ConvertDatapathToCombPass>::ConvertDatapathToCombBase;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConvertDatapathToCombPass::runOnOperation() {
|
||||
ConversionTarget target(getContext());
|
||||
|
||||
target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
|
||||
target.addIllegalDialect<DatapathDialect>();
|
||||
|
||||
RewritePatternSet patterns(&getContext());
|
||||
|
||||
patterns.add<DatapathPartialProductOpConversion>(patterns.getContext(),
|
||||
forceBooth);
|
||||
|
||||
if (lowerCompressToAdd)
|
||||
// Lower compressors to simple add operations for downstream optimisations
|
||||
patterns.add<DatapathCompressOpAddConversion>(patterns.getContext());
|
||||
else
|
||||
// Lower compressors to a complete gate-level implementation
|
||||
patterns.add<DatapathCompressOpConversion>(patterns.getContext());
|
||||
|
||||
if (failed(mlir::applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
|
@ -216,6 +216,86 @@ Value comb::createInject(OpBuilder &builder, Location loc, Value value,
|
|||
return builder.createOrFold<comb::ConcatOp>(loc, fragments);
|
||||
}
|
||||
|
||||
// Construct a full adder for three 1-bit inputs.
|
||||
std::pair<Value, Value> comb::fullAdder(OpBuilder &builder, Location loc,
|
||||
Value a, Value b, Value c) {
|
||||
auto aXorB = builder.createOrFold<comb::XorOp>(loc, a, b, true);
|
||||
Value sum = builder.createOrFold<comb::XorOp>(loc, aXorB, c, true);
|
||||
|
||||
auto carry = builder.createOrFold<comb::OrOp>(
|
||||
loc,
|
||||
ArrayRef<Value>{builder.createOrFold<comb::AndOp>(loc, a, b, true),
|
||||
builder.createOrFold<comb::AndOp>(loc, aXorB, c, true)},
|
||||
true);
|
||||
|
||||
return {sum, carry};
|
||||
}
|
||||
|
||||
// Perform Wallace tree reduction on partial products.
|
||||
// See https://en.wikipedia.org/wiki/Wallace_tree
|
||||
SmallVector<Value>
|
||||
comb::wallaceReduction(OpBuilder &builder, Location loc, size_t width,
|
||||
size_t targetAddends,
|
||||
SmallVector<SmallVector<Value>> &addends) {
|
||||
auto falseValue = builder.create<hw::ConstantOp>(loc, APInt(1, 0));
|
||||
SmallVector<SmallVector<Value>> newAddends;
|
||||
newAddends.reserve(addends.size());
|
||||
// Continue reduction until we have only two rows. The length of
|
||||
// `addends` is reduced by 1/3 in each iteration.
|
||||
while (addends.size() > targetAddends) {
|
||||
newAddends.clear();
|
||||
// Take three rows at a time and reduce to two rows(sum and carry).
|
||||
for (unsigned i = 0; i < addends.size(); i += 3) {
|
||||
if (i + 2 < addends.size()) {
|
||||
// We have three rows to reduce
|
||||
auto &row1 = addends[i];
|
||||
auto &row2 = addends[i + 1];
|
||||
auto &row3 = addends[i + 2];
|
||||
|
||||
assert(row1.size() == width && row2.size() == width &&
|
||||
row3.size() == width);
|
||||
|
||||
SmallVector<Value> sumRow, carryRow;
|
||||
sumRow.reserve(width);
|
||||
carryRow.reserve(width);
|
||||
carryRow.push_back(falseValue);
|
||||
|
||||
// Process each bit position
|
||||
for (unsigned j = 0; j < width; ++j) {
|
||||
// Full adder logic
|
||||
auto [sum, carry] =
|
||||
comb::fullAdder(builder, loc, row1[j], row2[j], row3[j]);
|
||||
sumRow.push_back(sum);
|
||||
if (j + 1 < width)
|
||||
carryRow.push_back(carry);
|
||||
}
|
||||
|
||||
newAddends.push_back(std::move(sumRow));
|
||||
newAddends.push_back(std::move(carryRow));
|
||||
} else {
|
||||
// Add remaining rows as is
|
||||
newAddends.append(addends.begin() + i, addends.end());
|
||||
}
|
||||
}
|
||||
std::swap(newAddends, addends);
|
||||
}
|
||||
|
||||
assert(addends.size() <= targetAddends);
|
||||
SmallVector<Value> carrySave;
|
||||
for (auto &addend : addends) {
|
||||
// Reverse the order of the bits
|
||||
std::reverse(addend.begin(), addend.end());
|
||||
carrySave.push_back(builder.create<comb::ConcatOp>(loc, addend));
|
||||
}
|
||||
|
||||
// Pad with zeros
|
||||
auto zero = builder.create<hw::ConstantOp>(loc, APInt(width, 0));
|
||||
while (carrySave.size() < targetAddends)
|
||||
carrySave.push_back(zero);
|
||||
|
||||
return carrySave;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ICmpOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -67,7 +67,6 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
|
|||
// ALLOW_ADD-NEXT: %[[RHS0:.+]] = comb.extract %rhs from 0 : (i3) -> i1
|
||||
// ALLOW_ADD-NEXT: %[[RHS1:.+]] = comb.extract %rhs from 1 : (i3) -> i1
|
||||
// ALLOW_ADD-NEXT: %[[RHS2:.+]] = comb.extract %rhs from 2 : (i3) -> i1
|
||||
// ALLOW_ADD-NEXT: %false = hw.constant false
|
||||
// Partial Products
|
||||
// ALLOW_ADD-NEXT: %[[P_0_0:.+]] = comb.and %[[LHS0]], %[[RHS0]] : i1
|
||||
// ALLOW_ADD-NEXT: %[[P_1_0:.+]] = comb.and %[[LHS1]], %[[RHS0]] : i1
|
||||
|
@ -76,6 +75,7 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
|
|||
// ALLOW_ADD-NEXT: %[[P_1_1:.+]] = comb.and %[[LHS1]], %[[RHS1]] : i1
|
||||
// ALLOW_ADD-NEXT: %[[P_2_1:.+]] = comb.and %[[LHS0]], %[[RHS2]] : i1
|
||||
// Wallace Tree Reduction
|
||||
// ALLOW_ADD-NEXT: %false = hw.constant false
|
||||
// ALLOW_ADD-NEXT: %[[XOR0:.+]] = comb.xor bin %[[P_1_0]], %[[P_0_1]] : i1
|
||||
// ALLOW_ADD-NEXT: %[[AND0:.+]] = comb.and bin %[[P_1_0]], %[[P_0_1]] : i1
|
||||
// ALLOW_ADD-NEXT: %[[XOR1:.+]] = comb.xor bin %[[P_2_0]], %[[P_1_1]] : i1
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
// RUN: circt-opt %s --convert-datapath-to-comb | FileCheck %s
|
||||
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-compress-to-add=true}))" | FileCheck %s --check-prefix=TO-ADD
|
||||
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-partial-product-to-booth=true}, canonicalize))" | FileCheck %s --check-prefix=FORCE-BOOTH
|
||||
|
||||
// CHECK-LABEL: @compressor
|
||||
hw.module @compressor(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) {
|
||||
//CHECK-NEXT: %[[A0:.+]] = comb.extract %a from 0 : (i2) -> i1
|
||||
//CHECK-NEXT: %[[A1:.+]] = comb.extract %a from 1 : (i2) -> i1
|
||||
//CHECK-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i2) -> i1
|
||||
//CHECK-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i2) -> i1
|
||||
//CHECK-NEXT: %[[C0:.+]] = comb.extract %c from 0 : (i2) -> i1
|
||||
//CHECK-NEXT: %[[C1:.+]] = comb.extract %c from 1 : (i2) -> i1
|
||||
//CHECK-NEXT: %false = hw.constant false
|
||||
//CHECK-NEXT: %[[AxB0:.+]] = comb.xor bin %[[A0]], %[[B0]] : i1
|
||||
//CHECK-NEXT: %[[AxBxC0:.+]] = comb.xor bin %[[AxB0]], %[[C0]] : i1
|
||||
//CHECK-NEXT: %[[AB0:.+]] = comb.and bin %[[A0]], %[[B0]] : i1
|
||||
//CHECK-NEXT: %[[AxBC0:.+]] = comb.and bin %[[AxB0]], %[[C0]] : i1
|
||||
//CHECK-NEXT: %[[AB0oAxBC0:.+]] = comb.or bin %[[AB0]], %[[AxBC0]] : i1
|
||||
//CHECK-NEXT: %[[AxB1:.+]] = comb.xor bin %[[A1]], %[[B1]] : i1
|
||||
//CHECK-NEXT: %[[AxBxC1:.+]] = comb.xor bin %[[AxB1]], %[[C1]] : i1
|
||||
//CHECK-NEXT: %[[AB1:.+]] = comb.and bin %[[A1]], %[[B1]] : i1
|
||||
//CHECK-NEXT: %[[AxBC1:.+]] = comb.and bin %[[AxB1]], %[[C1]] : i1
|
||||
//CHECK-NEXT: comb.or bin %[[AB1]], %[[AxBC1]] : i1
|
||||
//CHECK-NEXT: comb.concat %[[AxBxC1]], %[[AxBxC0]] : i1, i1
|
||||
//CHECK-NEXT: comb.concat %[[AB0oAxBC0]], %false : i1, i1
|
||||
%0:2 = datapath.compress %a, %b, %c : i2 [3 -> 2]
|
||||
hw.output %0#0, %0#1 : i2, i2
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @compressor_add
|
||||
// TO-ADD-LABEL: @compressor_add
|
||||
// TO-ADD-NEXT: %[[ADD:.+]] = comb.add bin %a, %b, %c : i2
|
||||
// TO-ADD-NEXT: %c0_i2 = hw.constant 0 : i2
|
||||
// TO-ADD-NEXT: hw.output %c0_i2, %[[ADD]] : i2, i2
|
||||
hw.module @compressor_add(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) {
|
||||
%0:2 = datapath.compress %a, %b, %c : i2 [3 -> 2]
|
||||
hw.output %0#0, %0#1 : i2, i2
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @partial_product
|
||||
hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
|
||||
// CHECK-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i3) -> i1
|
||||
// CHECK-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i3) -> i1
|
||||
// CHECK-NEXT: %[[B2:.+]] = comb.extract %b from 2 : (i3) -> i1
|
||||
// CHECK-NEXT: %[[B0R:.+]] = comb.replicate %[[B0]] : (i1) -> i3
|
||||
// CHECK-NEXT: %[[PP0:.+]] = comb.and %[[B0R]], %a : i3
|
||||
// CHECK-NEXT: %c0_i3 = hw.constant 0 : i3
|
||||
// CHECK-NEXT: comb.shl %[[PP0]], %c0_i3 : i3
|
||||
// CHECK-NEXT: %[[B1R:.+]] = comb.replicate %[[B1]] : (i1) -> i3
|
||||
// CHECK-NEXT: %[[PP1:.+]] = comb.and %[[B1R]], %a : i3
|
||||
// CHECK-NEXT: %c1_i3 = hw.constant 1 : i3
|
||||
// CHECK-NEXT: comb.shl %[[PP1]], %c1_i3 : i3
|
||||
// CHECK-NEXT: %[[B2R:.+]] = comb.replicate %[[B2]] : (i1) -> i3
|
||||
// CHECK-NEXT: %[[PP2:.+]] = comb.and %[[B2R]], %a : i3
|
||||
// CHECK-NEXT: %c2_i3 = hw.constant 2 : i3
|
||||
// CHECK-NEXT: comb.shl %[[PP2]], %c2_i3 : i3
|
||||
%0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3)
|
||||
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @partial_product_booth
|
||||
// FORCE-BOOTH-LABEL: @partial_product_booth
|
||||
// Constants
|
||||
// FORCE-BOOTH-NEXT: %true = hw.constant true
|
||||
// FORCE-BOOTH-NEXT: %false = hw.constant false
|
||||
// FORCE-BOOTH-NEXT: %c0_i3 = hw.constant 0 : i3
|
||||
// 2*a
|
||||
// FORCE-BOOTH-NEXT: %0 = comb.extract %a from 0 : (i3) -> i2
|
||||
// FORCE-BOOTH-NEXT: %[[TWOA:.+]] = comb.concat %0, %false : i2, i1
|
||||
// FORCE-BOOTH-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i3) -> i1
|
||||
// FORCE-BOOTH-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i3) -> i1
|
||||
// FORCE-BOOTH-NEXT: %[[B2:.+]] = comb.extract %b from 2 : (i3) -> i1
|
||||
// PP0
|
||||
// FORCE-BOOTH-NEXT: %[[NB0:.+]] = comb.xor bin %[[B0]], %true : i1
|
||||
// FORCE-BOOTH-NEXT: %[[TWO0:.+]] = comb.and %[[B1]], %[[NB0]] : i1
|
||||
// FORCE-BOOTH-NEXT: %[[PPOSGN:.+]] = comb.replicate %[[B1]] : (i1) -> i3
|
||||
// FORCE-BOOTH-NEXT: %[[ONER:.+]] = comb.replicate %[[B0]] : (i1) -> i3
|
||||
// FORCE-BOOTH-NEXT: %[[TWO0R:.+]] = comb.replicate %[[TWO0]] : (i1) -> i3
|
||||
// FORCE-BOOTH-NEXT: %[[PP0TWOA:.+]] = comb.and %[[TWO0R]], %[[TWOA]] : i3
|
||||
// FORCE-BOOTH-NEXT: %[[PP0ONEA:.+]] = comb.and %[[ONER]], %a : i3
|
||||
// FORCE-BOOTH-NEXT: %[[PP0MAG:.+]] = comb.or bin %[[PP0TWOA]], %[[PP0ONEA]] : i3
|
||||
// FORCE-BOOTH-NEXT: %[[PP0:.+]] = comb.xor bin %[[PP0MAG]], %[[PPOSGN]] : i3
|
||||
// PP1
|
||||
// FORCE-BOOTH-NEXT: %[[B2XORB1:.+]] = comb.xor bin %4, %3 : i1
|
||||
// FORCE-BOOTH-NEXT: %[[A0:.+]] = comb.extract %a from 0 : (i3) -> i1
|
||||
// FORCE-BOOTH-NEXT: %[[PP1MSB:.+]] = comb.and %[[B2XORB1]], %[[A0]] : i1
|
||||
// FORCE-BOOTH-NEXT: %[[PP1:.+]] = comb.concat %[[PP1MSB]], %false, %[[B1]] : i1, i1, i1
|
||||
// FORCE-BOOTH-NEXT: hw.output %[[PP0]], %[[PP1]], %c0_i3 : i3, i3, i3
|
||||
hw.module @partial_product_booth(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
|
||||
%0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3)
|
||||
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @partial_product_24
|
||||
hw.module @partial_product_24(in %a : i24, in %b : i24, out sum : i24) {
|
||||
%0:24 = datapath.partial_product %a, %b : (i24, i24) -> (i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24)
|
||||
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15, %0#16, %0#17, %0#18, %0#19, %0#20, %0#21, %0#22, %0#23 : i24
|
||||
hw.output %1 : i24
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @partial_product_25
|
||||
// 25-bit partial product: all 25 result rows are summed.
hw.module @partial_product_25(in %a : i25, in %b : i25, out sum : i25) {
  %0:25 = datapath.partial_product %a, %b : (i25, i25) -> (i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25)
  %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15, %0#16, %0#17, %0#18, %0#19, %0#20, %0#21, %0#22, %0#23, %0#24 : i25
  hw.output %1 : i25
}
|
Loading…
Reference in New Issue