[Datapath] Conversion Pass Comb to Datapath (#8664)

Add support for lowering variadic adders and two-input multipliers to datapath operations. Now automiatng 
```
%0 = comb.mul %a, %b : i4
%1 = comb.add %0, %c : i4
```

Resulting from comb-to-datapath and canonicalize:
```
%0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
%1:2 = datapath.compress %0#0, %0#1, %0#2, %0#3, %c : i4 [5 -> 2]
%2 = comb.add %1#0, %1#1 : i4
```
This commit is contained in:
Samuel Coward 2025-07-09 16:54:55 +01:00 committed by GitHub
parent f91e77c509
commit 713a91ddff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 203 additions and 0 deletions

View File

@ -0,0 +1,21 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_CONVERSION_COMBTODATAPATH_H
#define CIRCT_CONVERSION_COMBTODATAPATH_H
#include "circt/Support/LLVM.h"
namespace circt {
#define GEN_PASS_DECL_CONVERTCOMBTODATAPATH
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
#endif // CIRCT_CONVERSION_COMBTODATAPATH_H

View File

@ -22,6 +22,7 @@
#include "circt/Conversion/CalyxToHW.h"
#include "circt/Conversion/CombToAIG.h"
#include "circt/Conversion/CombToArith.h"
#include "circt/Conversion/CombToDatapath.h"
#include "circt/Conversion/CombToSMT.h"
#include "circt/Conversion/ConvertToArcs.h"
#include "circt/Conversion/DCToHW.h"

View File

@ -849,5 +849,21 @@ def ConvertAIGToComb: Pass<"convert-aig-to-comb", "hw::HWModuleOp"> {
];
}
//===----------------------------------------------------------------------===//
// ConvertCombToDatapath
//===----------------------------------------------------------------------===//
def ConvertCombToDatapath: Pass<"convert-comb-to-datapath", "hw::HWModuleOp"> {
let summary = "Lower Comb ops to Datapath ops";
let description = [{
This pass converts arithmetic Comb operations into Datapath operations that
leverage redundant number representations (carry save). Primarily for use
in the circt-synth flow.
}];
let dependentDialects = [
"circt::comb::CombDialect", "circt::datapath::DatapathDialect",
"circt::hw::HWDialect"
];
}
#endif // CIRCT_CONVERSION_PASSES_TD

View File

@ -10,6 +10,7 @@ add_circt_public_c_api_library(CIRCTCAPIConversion
CIRCTCalyxNative
CIRCTCombToAIG
CIRCTCombToArith
CIRCTCombToDatapath
CIRCTCombToLLVM
CIRCTCombToSMT
CIRCTConvertToArcs

View File

@ -5,6 +5,7 @@ add_subdirectory(CalyxToFSM)
add_subdirectory(CalyxToHW)
add_subdirectory(CombToAIG)
add_subdirectory(CombToArith)
add_subdirectory(CombToDatapath)
add_subdirectory(CombToLLVM)
add_subdirectory(CombToSMT)
add_subdirectory(ConvertToArcs)

View File

@ -0,0 +1,18 @@
add_circt_conversion_library(CIRCTCombToDatapath
CombToDatapath.cpp
ADDITIONAL_HEADER_DIRS
${CIRCT_MAIN_INCLUDE_DIR}/circt/Conversion/CombToDatapath
DEPENDS
CIRCTConversionPassIncGen
LINK_LIBS PUBLIC
CIRCTHW
CIRCTComb
CIRCTDatapath
MLIRIR
MLIRPass
MLIRSupport
MLIRTransforms
)

View File

@ -0,0 +1,116 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the main Comb to Datapath Conversion Pass Implementation.
//
//===----------------------------------------------------------------------===//
#include "circt/Conversion/CombToDatapath.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace circt {
#define GEN_PASS_DEF_CONVERTCOMBTODATAPATH
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
using namespace circt;
using namespace comb;
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
namespace {
// add(a1, a2, ...) -> add(compress(a1, a2, ...))
struct CombAddOpConversion : OpConversionPattern<AddOp> {
using OpConversionPattern<AddOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto width = op.getType().getIntOrFloatBitWidth();
// Skip a zero width value.
if (width == 0) {
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
return success();
}
// Reduce to two values (carry,save)
auto results =
rewriter.create<datapath::CompressOp>(op.getLoc(), op.getOperands(), 2);
// carry+saved
rewriter.replaceOpWithNewOp<AddOp>(op, results.getResults(), true);
return success();
}
};
// mul(a,b) -> add(pp(a,b))
// multi-input adder will be converted to a compressor by other pattern
struct CombMulOpConversion : OpConversionPattern<MulOp> {
using OpConversionPattern<MulOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: support for variadic multipliers
if (adaptor.getInputs().size() != 2)
return failure();
auto width = op.getType().getIntOrFloatBitWidth();
// Create partial product rows - number of rows == width
auto pp = rewriter.create<datapath::PartialProductOp>(
op.getLoc(), op.getInputs(), width);
// Sum partial products
rewriter.replaceOpWithNewOp<AddOp>(op, pp.getResults(), true);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert Comb to Datapath pass
//===----------------------------------------------------------------------===//
namespace {
struct ConvertCombToDatapathPass
: public impl::ConvertCombToDatapathBase<ConvertCombToDatapathPass> {
void runOnOperation() override;
using ConvertCombToDatapathBase<
ConvertCombToDatapathPass>::ConvertCombToDatapathBase;
};
} // namespace
static void
populateCombToDatapathConversionPatterns(RewritePatternSet &patterns) {
patterns.add<CombAddOpConversion, CombMulOpConversion>(patterns.getContext());
}
void ConvertCombToDatapathPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<datapath::DatapathDialect, comb::CombDialect,
hw::HWDialect>();
// Permit 2-input adders (carry-propagate adders)
target.addDynamicallyLegalOp<comb::AddOp>(
[](comb::AddOp op) { return op.getNumOperands() <= 2; });
// TODO: determine lowering of multi-input multipliers
target.addDynamicallyLegalOp<comb::MulOp>(
[](comb::MulOp op) { return op.getNumOperands() > 2; });
RewritePatternSet patterns(&getContext());
populateCombToDatapathConversionPatterns(patterns);
if (failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}

View File

@ -0,0 +1,29 @@
// RUN: circt-opt %s --convert-comb-to-datapath | FileCheck %s
// CHECK-LABEL: @test
hw.module @test(in %arg0: i4, in %arg1: i4, in %arg2: i4, in %arg3: i4) {
// CHECK-NEXT: %c42_i32 = hw.constant 42 : i32
%c42_i32 = hw.constant 42 : i32
// CHECK-NEXT: comb.add %arg0, %arg1 : i4
%0 = comb.add %arg0, %arg1 : i4
// CHECK-NEXT: %[[COMP1:.+]]:2 = datapath.compress %arg0, %arg1, %arg2, %arg3 : i4 [4 -> 2]
// CHECK-NEXT: comb.add bin %[[COMP1]]#0, %[[COMP1]]#1 : i4
%1 = comb.add %arg0, %arg1, %arg2, %arg3 : i4
// CHECK-NEXT: %[[PP:.+]]:4 = datapath.partial_product %arg0, %arg1 : (i4, i4) -> (i4, i4, i4, i4)
// CHECK-NEXT: %[[COMP2:.+]]:2 = datapath.compress %[[PP]]#0, %[[PP]]#1, %[[PP]]#2, %[[PP]]#3 : i4 [4 -> 2]
// CHECK-NEXT: comb.add bin %[[COMP2]]#0, %[[COMP2]]#1 : i4
%2 = comb.mul %arg0, %arg1 : i4
// CHECK-NEXT: comb.mul %arg0, %arg1, %arg2 : i4
%7 = comb.mul %arg0, %arg1, %arg2 : i4
}
// CHECK-LABEL: @zero_width
hw.module @zero_width(in %arg0: i0, in %arg1: i0, in %arg2: i0) {
// CHECK-NEXT: hw.constant 0 : i0
%0 = comb.add %arg0, %arg1, %arg2 : i0
}