[Datapath] Operator definitions and canonicalization patterns (#8647)

Building on the boilerplate dialect definition, this adds two operators:
* `datapath.compress` - compressor tree circuit
* `datapath.partial_product` - partial product generation circuit

The key idea is to view datapath operators as generators of circuits that satisfy some contract. For example, in the case of `datapath.compress`, summing its results is equivalent to summing its inputs. This allows us to defer implementing these critical circuits until later in the synthesis flow.

In a simple example, we can fold a*b+c using the datapath dialect to remove a carry-propagate adder:
```mlir
%0 = comb.mul %a, %b : i4
%1 = comb.add %0, %c : i4
```
Which is equivalent to:
```mlir
%0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
%1:2 = datapath.compress %0#0, %0#1, %0#2, %0#3, %c : i4 [5 -> 2]
%2 = comb.add %1#0, %1#1 : i4
```
This commit is contained in:
Samuel Coward 2025-07-07 19:50:56 +01:00 committed by GitHub
parent 74736548c8
commit f26f1ed5c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 532 additions and 2 deletions

View File

@ -27,7 +27,7 @@ In a simple example, we can fold a*b+c using the datapath dialect to remove a ca
``` ```
Which is equivalent to: Which is equivalent to:
```mlir ```mlir
%0:4 = datapath.pp %a, %b : 4 x i4 %0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
%1:2 = datapath.compress %0#0, %0#1, %0#2, %0#3, %c : 5 x i4 -> (i4, i4) %1:2 = datapath.compress %0#0, %0#1, %0#2, %0#3, %c : i4 [5 -> 2]
%2 = comb.add %1#0, %1#1 : i4 %2 = comb.add %1#0, %1#1 : i4
``` ```

View File

@ -23,5 +23,83 @@ include "circt/Dialect/HW/HWTypes.td"
class DatapathOp<string mnemonic, list<Trait> traits = []> : class DatapathOp<string mnemonic, list<Trait> traits = []> :
Op<DatapathDialect, mnemonic, traits>; Op<DatapathDialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Compressor-tree operator: takes a variadic list of integer inputs and
// produces a variadic list of results of the same type. The op verifier
// (hasVerifier) is implemented in C++ as CompressOp::verify.
def CompressOp : DatapathOp<"compress",
[Pure, SameTypeOperands, SameOperandsAndResultType, Commutative]> {
let summary = "Reduce a set of bitvectors to a carry-save representation";
let description = [{
Reduce an array of bitvectors to a smaller set of bitvectors (at least 2).
A compressor tree sums multiple bitvectors (often partial products in
multipliers or adders). Instead of adding all bitvectors sequentially, a
compressor tree reduces the number of operands in parallel stages. The
result is stored in a redundant (carry-save) representation, deferring the
compressor tree implementation to a later stage.
Example:
```mlir
%0:2 = datapath.compress %a, %b, %c : i16 [3 -> 2]
```
}];
let arguments = (ins Variadic<HWIntegerType>:$inputs);
let results = (outs Variadic<HWIntegerType>:$results);
// The type list is printed/parsed by the custom `CompressFormat` directive as
// `<type> [<num-inputs> -> <num-results>]` rather than spelling out every type.
let assemblyFormat = [{
$inputs attr-dict `:` custom<CompressFormat>(type($inputs), type($results))
}];
let hasVerifier = 1;
let hasCanonicalizer = true;
// Convenience builder: creates `targetRows` results, each with the type of the
// first value in the given range. Despite the `$lhs` name, the range is
// forwarded as the complete operand list. The range must be non-empty because
// the body reads `lhs.front()`.
let builders = [
OpBuilder<(ins "ValueRange":$lhs, "int32_t":$targetRows), [{
auto inputType = lhs.front().getType();
SmallVector<Type> resultTypes(targetRows, inputType);
return build($_builder, $_state, resultTypes, lhs);
}]>
];
}
// Partial-product operator: two integer operands and a variadic list of
// results; all operands and results share one type
// (SameOperandsAndResultType). No op verifier is declared for this op here.
def PartialProductOp : DatapathOp<"partial_product",
[Pure, SameTypeOperands, SameOperandsAndResultType, Commutative]> {
let summary = "Generate partial products from multiplying the operands";
let description = [{
The first step in a multiplication is to generate partial products, which
when summed, yield the product of the two operands. The partial
product operator does not specify an implementation, only that summing the
results will yield the product of the two operands. The number of results
corresponds to the rows of a partial product array, which by default is
equal to the width of the inputs.
Verilog Example 4-bit multiplication:
```verilog
partial_product[0][3:0] = {4{a[0]}} & b
...
partial_product[3][3:0] = {4{a[3]}} & b
ab[3:0] = partial_product[0] + ... + partial_product[3] // = a*b
```
Example using `datapath` dialect:
```mlir
%0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
```
}];
let arguments = (ins HWIntegerType:$lhs,
HWIntegerType:$rhs);
let results = (outs Variadic<HWIntegerType>:$results);
// Uses the standard functional type syntax: `(operand types) -> (result types)`.
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = true;
// Convenience builder: creates `targetRows` results, each with the type of the
// first value in the given range. Despite the `$lhs` name, the range is
// forwarded as the complete operand list (presumably both multiplicands -
// confirm at call sites). The range must be non-empty because the body reads
// `lhs.front()`.
let builders = [
OpBuilder<(ins "ValueRange":$lhs, "int32_t":$targetRows), [{
auto inputType = lhs.front().getType();
SmallVector<Type> resultTypes(targetRows, inputType);
return build($_builder, $_state, resultTypes, lhs);
}]>
];
}
#endif // CIRCT_DIALECT_DATAPATH_OPS_TD #endif // CIRCT_DIALECT_DATAPATH_OPS_TD

View File

@ -1,5 +1,7 @@
add_circt_dialect_library(CIRCTDatapath add_circt_dialect_library(CIRCTDatapath
DatapathDialect.cpp DatapathDialect.cpp
DatapathFolds.cpp
DatapathOps.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/Datapath ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/Datapath

View File

@ -0,0 +1,259 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include <algorithm>
using namespace mlir;
using namespace circt;
using namespace datapath;
using namespace matchers;
//===----------------------------------------------------------------------===//
// Compress Operation
//===----------------------------------------------------------------------===//
// Returns true only when every result of a compressor is present in
// `operands`. Only the sum of *all* compressor results is guaranteed to equal
// the sum of its inputs, so rewriting a subset of the results independently
// could easily introduce a non-equivalent representation.
static bool areAllCompressorResultsSummed(ValueRange compressResults,
                                          ValueRange operands) {
  return llvm::all_of(compressResults, [&](Value compressResult) {
    return llvm::is_contained(operands, compressResult);
  });
}
// Flattens nested compressors and single-use adds feeding a compressor into
// one larger compressor.
struct FoldCompressIntoCompress
    : public OpRewritePattern<datapath::CompressOp> {
  using OpRewritePattern::OpRewritePattern;

  // compress(compress(a,b,c), add(e,f)) -> compress(a,b,c,e,f)
  //
  // Returns failure() when no additional operands could be collected, or when
  // a single-use compressor result is consumed here without all of its sibling
  // results.
  LogicalResult matchAndRewrite(datapath::CompressOp compOp,
                                PatternRewriter &rewriter) const override {
    auto operands = compOp.getOperands();
    llvm::SmallSetVector<Value, 8> processedCompressorResults;
    SmallVector<Value, 8> newCompressOperands;

    for (Value operand : operands) {
      // Skip if already processed this compressor
      if (processedCompressorResults.contains(operand))
        continue;

      // If the operand has multiple uses, we do not fold it into a compress
      // operation, so we treat it as a regular operand to maintain sharing.
      if (!operand.hasOneUse()) {
        newCompressOperands.push_back(operand);
        continue;
      }

      // Found a compress op - add its operands to our new list
      if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {
        // Check that all results of the compressor are summed in this op
        if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
          return failure();

        // Every result must be single-use before the compressor is inlined.
        // A multi-use sibling is (or will be) kept as a regular operand by
        // the check above; inlining the compressor inputs as well would count
        // that sibling's contribution twice and change the computed sum.
        bool allResultsSingleUse =
            llvm::all_of(compressOp.getResults(),
                         [](Value result) { return result.hasOneUse(); });
        if (!allResultsSingleUse) {
          newCompressOperands.push_back(operand);
          continue;
        }

        llvm::append_range(newCompressOperands, compressOp.getOperands());
        // Only process each compressor once as multiple operands will point
        // to the same defining operation
        processedCompressorResults.insert(compressOp.getResults().begin(),
                                          compressOp.getResults().end());
        continue;
      }

      // A single-use add contributes its addends directly
      if (auto addOp = operand.getDefiningOp<comb::AddOp>()) {
        llvm::append_range(newCompressOperands, addOp.getOperands());
        continue;
      }

      // Regular operand - just add it to our list
      newCompressOperands.push_back(operand);
    }

    // If unable to collect more operands then this pattern doesn't apply
    if (newCompressOperands.size() <= compOp.getNumOperands())
      return failure();

    // Create a new CompressOp with all collected operands
    rewriter.replaceOpWithNewOp<datapath::CompressOp>(
        compOp, newCompressOperands, compOp.getNumResults());
    return success();
  }
};
// Rewrites an add with three or more operands into a compressor producing two
// rows followed by a two-operand add, absorbing compressors and single-use
// adds that feed it.
struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
  using OpRewritePattern::OpRewritePattern;

  // add(compress(a,b,c),d) -> add(compress(a,b,c,d))
  //
  // Returns failure() when the add has at most two operands, when a
  // single-use compressor result is summed without all of its siblings, or
  // when the rewrite would not form a larger compressor than already exists.
  LogicalResult matchAndRewrite(comb::AddOp addOp,
                                PatternRewriter &rewriter) const override {
    // comb.add canonicalization patterns handle folding add operations
    if (addOp.getNumOperands() <= 2)
      return failure();

    // Get operands of the AddOp
    auto operands = addOp.getOperands();
    llvm::SmallSetVector<Value, 8> processedCompressorResults;
    SmallVector<Value, 8> newCompressOperands;
    // Only construct compressor if can form a larger compressor than what
    // is currently an input of this add
    bool shouldFold = false;

    for (Value operand : operands) {
      // Skip if already processed this compressor
      if (processedCompressorResults.contains(operand))
        continue;

      // If the operand has multiple uses, we do not fold it into a compress
      // operation, so we treat it as a regular operand.
      if (!operand.hasOneUse()) {
        shouldFold |= !newCompressOperands.empty();
        newCompressOperands.push_back(operand);
        continue;
      }

      // Found a compress op - add its operands to our new list
      if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {
        // Check that all results of the compressor are summed in this add
        if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
          return failure();

        // Every result must be single-use before the compressor is inlined.
        // A multi-use sibling is (or will be) kept as a regular operand by
        // the check above; inlining the compressor inputs as well would count
        // that sibling's contribution twice and change the computed sum.
        bool allResultsSingleUse =
            llvm::all_of(compressOp.getResults(),
                         [](Value result) { return result.hasOneUse(); });
        if (!allResultsSingleUse) {
          shouldFold |= !newCompressOperands.empty();
          newCompressOperands.push_back(operand);
          continue;
        }

        // If we've already added one operand it should be folded
        shouldFold |= !newCompressOperands.empty();
        llvm::append_range(newCompressOperands, compressOp.getOperands());
        // Only process each compressor once
        processedCompressorResults.insert(compressOp.getResults().begin(),
                                          compressOp.getResults().end());
        continue;
      }

      // A single-use nested add contributes its addends directly. Named
      // differently from the matched `addOp` to avoid shadowing it.
      if (auto nestedAddOp = operand.getDefiningOp<comb::AddOp>()) {
        shouldFold |= !newCompressOperands.empty();
        llvm::append_range(newCompressOperands, nestedAddOp.getOperands());
        continue;
      }

      // Regular operand - just add it to our list
      shouldFold |= !newCompressOperands.empty();
      newCompressOperands.push_back(operand);
    }

    // Only fold if we have constructed a larger compressor than what was
    // already there
    if (!shouldFold)
      return failure();

    // The compress verifier requires at least three inputs; do not create an
    // operation that would be rejected.
    if (newCompressOperands.size() < 3)
      return failure();

    // Create a new CompressOp with all collected operands
    auto newCompressOp = rewriter.create<datapath::CompressOp>(
        addOp.getLoc(), newCompressOperands, 2);
    // Replace the original AddOp with a new add(compress(inputs))
    rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
                                             true);
    return success();
  }
};
// Removes a trailing constant-zero row from a compressor.
struct ConstantFoldCompress : public OpRewritePattern<CompressOp> {
  using OpRewritePattern::OpRewritePattern;

  // compress(..., 0) -> compress(...) -- identity
  LogicalResult matchAndRewrite(CompressOp op,
                                PatternRewriter &rewriter) const override {
    auto rows = op.getInputs();

    // Only the last input is inspected for a zero constant.
    APInt trailingConstant;
    bool lastRowIsZero =
        matchPattern(rows.back(), m_ConstantInt(&trailingConstant)) &&
        trailingConstant.isZero();
    if (!lastRowIsZero)
      return failure();

    auto remainingRows = rows.drop_back();

    // Dropping the zero row leaves exactly as many rows as results, so the
    // compressor is no longer needed - forward the remaining operands.
    if (remainingRows.size() == op.getNumResults()) {
      rewriter.replaceOp(op, remainingRows);
      return success();
    }

    // Otherwise rebuild the compressor with one fewer argument.
    rewriter.replaceOpWithNewOp<CompressOp>(op, remainingRows,
                                            op.getNumResults());
    return success();
  }
};
// Registers the compressor canonicalizations. FoldAddIntoCompress is rooted on
// comb.add but is registered here alongside the compress-rooted patterns.
void CompressOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<FoldCompressIntoCompress>(context);
  results.add<FoldAddIntoCompress>(context);
  results.add<ConstantFoldCompress>(context);
}
//===----------------------------------------------------------------------===//
// Partial Product Operation
//===----------------------------------------------------------------------===//
// Shrinks the number of partial-product rows when the leading bits of both
// operands are known to be zero, padding the removed rows with zero constants.
struct ReduceNumPartialProducts : public OpRewritePattern<PartialProductOp> {
  using OpRewritePattern::OpRewritePattern;

  // pp(concat(0,a), concat(0,b)) -> reduced number of results
  //
  // Returns failure() when nothing is known about an operand's bits or when
  // the row count cannot be reduced.
  LogicalResult matchAndRewrite(PartialProductOp op,
                                PatternRewriter &rewriter) const override {
    auto operands = op.getOperands();
    unsigned inputWidth = operands[0].getType().getIntOrFloatBitWidth();
    size_t numResults = op.getNumResults();

    // TODO: implement a constant multiplication for the PartialProductOp
    size_t maxNonZeroBits = 0;
    for (Value operand : operands) {
      // Determine how many leading bits of this operand are known to be zero.
      auto knownBits = comb::computeKnownBits(operand);
      if (knownBits.isUnknown())
        return failure(); // Skip if we don't know anything about the bits

      size_t nonZeroBits = inputWidth - knownBits.Zero.countLeadingOnes();
      // No reduction is possible if the op already has this many results or
      // fewer. Using `>=` (not `==`) also keeps the padding count computed
      // below from underflowing when the op has fewer results than
      // significant bits.
      if (nonZeroBits >= numResults)
        return failure();

      maxNonZeroBits = std::max(maxNonZeroBits, nonZeroBits);
    }

    auto newPP = rewriter.create<datapath::PartialProductOp>(
        op.getLoc(), op.getOperands(), maxNonZeroBits);
    auto zero = rewriter.create<hw::ConstantOp>(op.getLoc(),
                                                APInt::getZero(inputWidth));

    // Collect newPP results and pad with zeros if needed
    SmallVector<Value> newResults(newPP.getResults().begin(),
                                  newPP.getResults().end());
    newResults.append(numResults - newResults.size(), zero);

    rewriter.replaceOp(op, newResults);
    return success();
  }
};
void PartialProductOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ReduceNumPartialProducts>(context);
}

View File

@ -0,0 +1,68 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements datapath ops.
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Datapath/DatapathOps.h"
using namespace circt;
using namespace datapath;
// Verifies the compressor shape: at least three inputs, strictly fewer results
// than inputs, and at least two results. The checks run in this order, so the
// first violated constraint determines the diagnostic.
LogicalResult CompressOp::verify() {
  const unsigned numInputs = getNumOperands();
  const unsigned numOutputs = getNumResults();

  // Two inputs are already in carry-save form; nothing to compress.
  if (numInputs < 3)
    return emitOpError("requires 3 or more arguments - otherwise use add");

  // The compressor must reduce the number of operands by at least 1 otherwise
  // it fails to perform any reduction.
  if (numOutputs >= numInputs)
    return emitOpError("must reduce the number of operands by at least 1");

  if (numOutputs < 2)
    return emitOpError("must produce at least 2 results");

  return success();
}
// Parser for the custom type format
// "<input-type> [<num-inputs> -> <num-outputs>]".
//
// On success `inputTypes` holds <num-inputs> copies and `resultTypes` holds
// <num-outputs> copies of <input-type>. Fails on malformed syntax or when
// either count is negative.
static ParseResult parseCompressFormat(OpAsmParser &parser,
                                       SmallVectorImpl<Type> &inputTypes,
                                       SmallVectorImpl<Type> &resultTypes) {
  int64_t inputCount, resultCount;
  Type inputElementType;

  if (parser.parseType(inputElementType) || parser.parseLSquare())
    return failure();

  // Remember where the counts start so a bad count is reported at that spot.
  auto countLoc = parser.getCurrentLocation();
  if (parser.parseInteger(inputCount) || parser.parseArrow() ||
      parser.parseInteger(resultCount) || parser.parseRSquare())
    return failure();

  // The counts are parsed as signed integers; a negative value would be
  // converted to an enormous unsigned size by assign() below.
  if (inputCount < 0 || resultCount < 0)
    return parser.emitError(
        countLoc, "number of inputs and results must be non-negative");

  // Inputs and results have same type
  inputTypes.assign(inputCount, inputElementType);
  resultTypes.assign(resultCount, inputElementType);
  return success();
}
// Printer for "<input-type> [<num-inputs> -> <num-outputs>]".
// Only the first input type is printed; `inputTypes` must be non-empty.
static void printCompressFormat(OpAsmPrinter &printer, Operation *op,
                                TypeRange inputTypes, TypeRange resultTypes) {
  printer << inputTypes[0];
  printer << " [" << inputTypes.size();
  printer << " -> " << resultTypes.size();
  printer << "]";
}
//===----------------------------------------------------------------------===//
// TableGen generated logic.
//===----------------------------------------------------------------------===//
// Provide the autogenerated implementation guts for the Op classes.
#define GET_OP_CLASSES
#include "circt/Dialect/Datapath/Datapath.cpp.inc"

View File

@ -0,0 +1,15 @@
// RUN: circt-opt %s -verify-roundtrip | FileCheck %s
// Round-trip of the custom compress syntax `<type> [<inputs> -> <results>]`.
// CHECK-LABEL: @compressor
hw.module @compressor(in %a : i4, in %b : i4, in %c : i4, out carry : i4, out save : i4) {
// CHECK-NEXT: datapath.compress %a, %b, %c : i4 [3 -> 2]
%0:2 = datapath.compress %a, %b, %c : i4 [3 -> 2]
hw.output %0#0, %0#1 : i4, i4
}
// Round-trip of partial_product, which uses the functional type syntax.
// CHECK-LABEL: @partial_product
hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
// CHECK-NEXT: datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3)
%0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3)
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
}

View File

@ -0,0 +1,81 @@
// RUN: circt-opt %s --canonicalize | FileCheck %s
// Negative case: partial products feeding a compressor directly are expected
// to come out of canonicalization unchanged.
// CHECK-LABEL: @do_nothing
hw.module @do_nothing(in %a : i4, in %b : i4, out carry : i4, out save : i4) {
// CHECK-NEXT: %[[PP:.+]]:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
// CHECK-NEXT: datapath.compress %[[PP]]#0, %[[PP]]#1, %[[PP]]#2, %[[PP]]#3 : i4 [4 -> 2]
%0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
%1:2 = datapath.compress %0#0, %0#1, %0#2, %0#3 : i4 [4 -> 2]
hw.output %1#0, %1#1 : i4, i4
}
// A compressor consuming both results of another compressor is expected to be
// merged into a single 4-input compressor.
// CHECK-LABEL: @fold_compress
hw.module @fold_compress(in %a : i4, in %b : i4, in %c : i4, in %d : i4, out carry : i4, out save : i4) {
// CHECK-NEXT: datapath.compress %d, %a, %b, %c : i4 [4 -> 2]
%0:2 = datapath.compress %a, %b, %c : i4 [3 -> 2]
%1:2 = datapath.compress %d, %0#0, %0#1 : i4 [3 -> 2]
hw.output %1#0, %1#1 : i4, i4
}
// An add summing both compressor results plus one more value is expected to
// become one wider compressor followed by a two-operand add.
// CHECK-LABEL: @fold_add
hw.module @fold_add(in %a : i4, in %b : i4, in %c : i4, in %d : i4, out sum : i4) {
// CHECK-NEXT: %[[COMP:.+]]:2 = datapath.compress %d, %a, %b, %c : i4 [4 -> 2]
// CHECK-NEXT: comb.add bin %[[COMP]]#0, %[[COMP]]#1 : i4
%0:2 = datapath.compress %a, %b, %c : i4 [3 -> 2]
%1 = comb.add %d, %0#0, %0#1 : i4
hw.output %1 : i4
}
// Zero rows are dropped from compressors: the [3 -> 2] compressor with a zero
// row is expected to forward %a and %b, and the [4 -> 2] compressor is expected
// to shrink to three inputs.
// CHECK-LABEL: @constant_fold_compress
hw.module @constant_fold_compress(in %a : i4, in %b : i4, in %c : i4,
out sum0 : i4, out carry0 : i4, out sum1 : i4, out carry1 : i4) {
%c0_i4 = hw.constant 0 : i4
%0:2 = datapath.compress %a, %b, %c0_i4 : i4 [3 -> 2]
// CHECK-NEXT: %[[COMP:.+]]:2 = datapath.compress %a, %b, %c : i4 [3 -> 2]
%1:2 = datapath.compress %a, %b, %c0_i4, %c : i4 [4 -> 2]
// CHECK-NEXT: hw.output %a, %b, %[[COMP]]#0, %[[COMP]]#1 : i4, i4, i4, i4
hw.output %0#0, %0#1, %1#0, %1#1 : i4, i4, i4, i4
}
// Removing the zero row from a [4 -> 3] compressor leaves as many rows as
// results, so the compressor results are expected to be replaced by the
// remaining operands. The module must output the compressor results;
// otherwise the compressor is simply dead and the fold is never exercised.
// CHECK-LABEL: @constant_fold_compress_passthrough
hw.module @constant_fold_compress_passthrough(in %a : i4, in %b : i4, in %c : i4,
                                              out sum0 : i4, out sum1 : i4, out sum2 : i4) {
  %c0_i4 = hw.constant 0 : i4
  %0:3 = datapath.compress %a, %b, %c0_i4, %c : i4 [4 -> 3]
  // CHECK-NEXT: hw.output %a, %b, %c : i4, i4, i4
  hw.output %0#0, %0#1, %0#2 : i4, i4, i4
}
// Both multiplicands are zero-extended from i3, so their top bit is known to
// be zero and only three partial-product rows are expected to remain. The add
// of the rows is expected to become a compressor plus a two-operand add.
// CHECK-LABEL: @constant_fold_partial_product
hw.module @constant_fold_partial_product(in %a : i3, in %b : i3, out sum : i4) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %[[CONCAT_A:.+]] = comb.concat %false, %a : i1, i3
// CHECK-NEXT: %[[CONCAT_B:.+]] = comb.concat %false, %b : i1, i3
// CHECK-NEXT: %[[PP:.+]]:3 = datapath.partial_product %[[CONCAT_A]], %[[CONCAT_B]] : (i4, i4) -> (i4, i4, i4)
// CHECK-NEXT: %[[COMP:.+]]:2 = datapath.compress %[[PP]]#0, %[[PP]]#1, %[[PP]]#2 : i4 [3 -> 2]
// CHECK-NEXT: comb.add bin %[[COMP]]#0, %[[COMP]]#1 : i4
%false = hw.constant false
%0 = comb.concat %false, %a : i1, i3
%1 = comb.concat %false, %b : i1, i3
%2:4 = datapath.partial_product %0, %1 : (i4, i4) -> (i4, i4, i4, i4)
%3 = comb.add %2#0, %2#1, %2#2, %2#3 : i4
hw.output %3 : i4
}
// Negative case: only one multiplicand is zero-extended while %b is an
// unconstrained module input, so the partial product is expected to keep all
// four rows.
// CHECK-LABEL: @partial_product_do_nothing
hw.module @partial_product_do_nothing(in %a : i3, in %b : i4, out sum : i4) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %[[CONCAT_A:.+]] = comb.concat %false, %a : i1, i3
// CHECK-NEXT: %[[PP:.+]]:4 = datapath.partial_product %[[CONCAT_A]], %b : (i4, i4) -> (i4, i4, i4, i4)
// CHECK-NEXT: %[[COMP:.+]]:2 = datapath.compress %[[PP]]#0, %[[PP]]#1, %[[PP]]#2, %[[PP]]#3 : i4 [4 -> 2]
// CHECK-NEXT: comb.add bin %[[COMP]]#0, %[[COMP]]#1 : i4
%false = hw.constant false
%0 = comb.concat %false, %a : i1, i3
%1:4 = datapath.partial_product %0, %b : (i4, i4) -> (i4, i4, i4, i4)
%2:2 = datapath.compress %1#0, %1#1, %1#2, %1#3 : i4 [4 -> 2]
%3 = comb.add bin %2#0, %2#1 : i4
hw.output %3 : i4
}

View File

@ -0,0 +1,27 @@
// RUN: circt-opt %s -split-input-file -verify-diagnostics
// A compressor with only two inputs is rejected by the verifier.
hw.module @err(in %a: i8, in %b: i8) {
// expected-error @+1 {{'datapath.compress' op requires 3 or more arguments - otherwise use add}}
%0:2 = datapath.compress %a, %b : i8 [2 -> 2]
}
// -----
// A compressor producing a single result is rejected by the verifier.
hw.module @err(in %a: i8, in %b: i8, in %c: i8) {
// expected-error @+1 {{'datapath.compress' op must produce at least 2 results}}
%0 = datapath.compress %a, %b, %c : i8 [3 -> 1]
}
// -----
// A compressor with as many results as inputs performs no reduction and is
// rejected by the verifier.
hw.module @err(in %a: i8, in %b: i8, in %c: i8) {
// expected-error @+1 {{'datapath.compress' op must reduce the number of operands by at least 1}}
%0:3 = datapath.compress %a, %b, %c : i8 [3 -> 3]
}
// -----
// A compressor with more results than inputs is rejected by the verifier.
hw.module @err(in %a: i8, in %b: i8, in %c: i8) {
// expected-error @+1 {{'datapath.compress' op must reduce the number of operands by at least 1}}
%0:4 = datapath.compress %a, %b, %c : i8 [3 -> 4]
}