[Datapath] Add Datapath to SMT conversion pass (#8682)

To verify correctness of transformations involving datapath operations,
include a lowering from datapath to SMT. Will later be integrated into
circt-lec to enable verification of these operators. Each operator
satisfied a contract, rather than providing a precise semantics for
every operator. This is because the datapath operators return values in
redundant number representations, meaning there are many valid
implementations.

For example:
```mlir
%0:2 = datapath.compress %a, %b, %c : i8 [3 -> 2]
```
Will be verified by introducing free variables for each return value 
`(%0#0, %0#1)` then asserting that the sum of the associated free 
variables is equal to the sum of the inputs:
`assert(%0#0 + %0#1 == %a + %b + %c)`.

Whilst this is encoded as an assert it really represents an assumption
that must be satisfied by a valid implementation of datapath.compress.
This commit is contained in:
Samuel Coward 2025-07-11 16:13:11 +01:00 committed by GitHub
parent c43ec9809f
commit 097604cfa1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 253 additions and 0 deletions

View File

@ -0,0 +1,25 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_CONVERSION_DATAPATHTOSMT_H
#define CIRCT_CONVERSION_DATAPATHTOSMT_H
#include "circt/Support/LLVM.h"
namespace circt {
/// Get the Datapath to SMT conversion patterns.
void populateDatapathToSMTConversionPatterns(TypeConverter &converter,
RewritePatternSet &patterns);
#define GEN_PASS_DECL_CONVERTDATAPATHTOSMT
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
#endif // CIRCT_CONVERSION_DATAPATHTOSMT_H

View File

@ -26,6 +26,7 @@
#include "circt/Conversion/CombToSMT.h"
#include "circt/Conversion/ConvertToArcs.h"
#include "circt/Conversion/DCToHW.h"
#include "circt/Conversion/DatapathToSMT.h"
#include "circt/Conversion/ExportChiselInterface.h"
#include "circt/Conversion/ExportVerilog.h"
#include "circt/Conversion/FIRRTLToHW.h"

View File

@ -866,4 +866,15 @@ def ConvertCombToDatapath: Pass<"convert-comb-to-datapath", "hw::HWModuleOp"> {
];
}
//===----------------------------------------------------------------------===//
// ConvertDatapathToSMT
//===----------------------------------------------------------------------===//
def ConvertDatapathToSMT : Pass<"convert-datapath-to-smt"> {
let summary = "Convert datapath ops to SMT ops";
let dependentDialects = [
"mlir::smt::SMTDialect"
];
}
#endif // CIRCT_CONVERSION_PASSES_TD

View File

@ -14,6 +14,7 @@ add_circt_public_c_api_library(CIRCTCAPIConversion
CIRCTCombToLLVM
CIRCTCombToSMT
CIRCTConvertToArcs
CIRCTDatapathToSMT
CIRCTDCToHW
CIRCTExportChiselInterface
CIRCTExportVerilog

View File

@ -9,6 +9,7 @@ add_subdirectory(CombToDatapath)
add_subdirectory(CombToLLVM)
add_subdirectory(CombToSMT)
add_subdirectory(ConvertToArcs)
add_subdirectory(DatapathToSMT)
add_subdirectory(DCToHW)
add_subdirectory(ExportAIGER)
add_subdirectory(ExportChiselInterface)

View File

@ -0,0 +1,16 @@
add_circt_conversion_library(CIRCTDatapathToSMT
DatapathToSMT.cpp
DEPENDS
CIRCTConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
CIRCTDatapath
CIRCTComb
CIRCTHWToSMT
MLIRSMT
MLIRTransforms
)

View File

@ -0,0 +1,158 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Conversion/DatapathToSMT.h"
#include "circt/Conversion/HWToSMT.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace circt {
#define GEN_PASS_DEF_CONVERTDATAPATHTOSMT
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
using namespace mlir;
using namespace circt;
using namespace datapath;
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
namespace {
// Lower to an SMT assertion that summing the results is equivalent to summing
// the compress inputs
// d:2 = compress(a, b, c) ->
// assert(d#0 + d#1 == a + b + c)
struct CompressOpConversion : OpConversionPattern<CompressOp> {
using OpConversionPattern<CompressOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ValueRange operands = adaptor.getOperands();
ValueRange results = op.getResults();
// Sum operands
Value operandRunner = operands[0];
for (Value operand : operands.drop_front())
operandRunner =
rewriter.create<smt::BVAddOp>(op.getLoc(), operandRunner, operand);
// Create free variables
SmallVector<Value, 2> newResults;
newResults.reserve(results.size());
for (Value result : results) {
auto declareFunOp = rewriter.create<smt::DeclareFunOp>(
op.getLoc(), typeConverter->convertType(result.getType()));
newResults.push_back(declareFunOp.getResult());
}
// Sum the free variables
Value resultRunner = newResults.front();
for (auto freeVar : llvm::drop_begin(newResults, 1))
resultRunner =
rewriter.create<smt::BVAddOp>(op.getLoc(), resultRunner, freeVar);
// Assert sum operands == sum results (free variables)
auto premise =
rewriter.create<smt::EqOp>(op.getLoc(), operandRunner, resultRunner);
// Encode via an assertion (could be relaxed to an assumption).
rewriter.create<smt::AssertOp>(op.getLoc(), premise);
if (newResults.size() != results.size())
return rewriter.notifyMatchFailure(op, "expected same number of results");
rewriter.replaceOp(op, newResults);
return success();
}
};
// Lower to an SMT assertion that summing the results is equivalent to the
// product of the partial_product inputs
// c:<N> = partial_product(a, b) ->
// assert(c#0 + ... + c#<N-1> == a * b)
struct PartialProductOpConversion : OpConversionPattern<PartialProductOp> {
using OpConversionPattern<PartialProductOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PartialProductOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ValueRange operands = adaptor.getOperands();
ValueRange results = op.getResults();
// Multiply the operands
auto mulResult =
rewriter.create<smt::BVMulOp>(op.getLoc(), operands[0], operands[1]);
// Create free variables
SmallVector<Value, 2> newResults;
newResults.reserve(results.size());
for (Value result : results) {
auto declareFunOp = rewriter.create<smt::DeclareFunOp>(
op.getLoc(), typeConverter->convertType(result.getType()));
newResults.push_back(declareFunOp.getResult());
}
// Sum the free variables
Value resultRunner = newResults.front();
for (auto freeVar : llvm::drop_begin(newResults, 1))
resultRunner =
rewriter.create<smt::BVAddOp>(op.getLoc(), resultRunner, freeVar);
// Assert product of operands == sum results (free variables)
auto premise =
rewriter.create<smt::EqOp>(op.getLoc(), mulResult, resultRunner);
// Encode via an assertion (could be relaxed to an assumption).
rewriter.create<smt::AssertOp>(op.getLoc(), premise);
if (newResults.size() != results.size())
return rewriter.notifyMatchFailure(op, "expected same number of results");
rewriter.replaceOp(op, newResults);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert Datapath to SMT pass
//===----------------------------------------------------------------------===//
namespace {
struct ConvertDatapathToSMTPass
: public circt::impl::ConvertDatapathToSMTBase<ConvertDatapathToSMTPass> {
void runOnOperation() override;
};
} // namespace
void circt::populateDatapathToSMTConversionPatterns(
TypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<CompressOpConversion, PartialProductOpConversion>(
converter, patterns.getContext());
}
void ConvertDatapathToSMTPass::runOnOperation() {
ConversionTarget target(getContext());
target.addIllegalDialect<datapath::DatapathDialect>();
target.addLegalDialect<smt::SMTDialect>();
RewritePatternSet patterns(&getContext());
TypeConverter converter;
populateHWToSMTTypeConverter(converter);
populateDatapathToSMTConversionPatterns(converter, patterns);
if (failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}

View File

@ -0,0 +1,40 @@
// RUN: circt-opt %s --convert-datapath-to-smt | FileCheck %s
// CHECK-LABEL: @compressor
hw.module @compressor(in %a : i4, in %b : i4, in %c : i4, out carry : i4, out save : i4) {
// CHECK-NEXT: %[[C:.+]] = builtin.unrealized_conversion_cast %c : i4 to !smt.bv<4>
// CHECK-NEXT: %[[B:.+]] = builtin.unrealized_conversion_cast %b : i4 to !smt.bv<4>
// CHECK-NEXT: %[[A:.+]] = builtin.unrealized_conversion_cast %a : i4 to !smt.bv<4>
// CHECK-NEXT: %[[AB:.+]] = smt.bv.add %[[A]], %[[B]] : !smt.bv<4>
// CHECK-NEXT: %[[INS:.+]] = smt.bv.add %[[AB]], %[[C]] : !smt.bv<4>
// CHECK-NEXT: %[[COMP0:.+]] = smt.declare_fun : !smt.bv<4>
// CHECK-NEXT: %[[COMP0_BV:.+]] = builtin.unrealized_conversion_cast %[[COMP0]] : !smt.bv<4> to i4
// CHECK-NEXT: %[[COMP1:.+]] = smt.declare_fun : !smt.bv<4>
// CHECK-NEXT: %[[COMP1_BV:.+]] = builtin.unrealized_conversion_cast %7 : !smt.bv<4> to i4
// CHECK-NEXT: %[[OUT:.+]] = smt.bv.add %[[COMP0]], %[[COMP1]] : !smt.bv<4>
// CHECK-NEXT: %[[P:.+]] = smt.eq %[[INS]], %[[OUT]] : !smt.bv<4>
// CHECK-NEXT: smt.assert %[[P]]
// CHECK-NEXT: hw.output %[[COMP0_BV]], %[[COMP1_BV]] : i4, i4
%0:2 = datapath.compress %a, %b, %c : i4 [3 -> 2]
hw.output %0#0, %0#1 : i4, i4
}
// CHECK-LABEL: @partial_product
hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
//CHECK-NEXT: %[[B:.+]] = builtin.unrealized_conversion_cast %b : i3 to !smt.bv<3>
//CHECK-NEXT: %[[A:.+]] = builtin.unrealized_conversion_cast %a : i3 to !smt.bv<3>
//CHECK-NEXT: %[[MUL:.+]] = smt.bv.mul %[[A]], %[[B]] : !smt.bv<3>
//CHECK-NEXT: %[[PP0:.+]] = smt.declare_fun : !smt.bv<3>
//CHECK-NEXT: %[[PP0_BV:.+]] = builtin.unrealized_conversion_cast %[[PP0]] : !smt.bv<3> to i3
//CHECK-NEXT: %[[PP1:.+]] = smt.declare_fun : !smt.bv<3>
//CHECK-NEXT: %[[PP1_BV:.+]] = builtin.unrealized_conversion_cast %[[PP1]] : !smt.bv<3> to i3
//CHECK-NEXT: %[[PP2:.+]] = smt.declare_fun : !smt.bv<3>
//CHECK-NEXT: %[[PP2_BV:.+]] = builtin.unrealized_conversion_cast %[[PP2]] : !smt.bv<3> to i3
//CHECK-NEXT: %[[ADD01:.+]] = smt.bv.add %[[PP0]], %[[PP1]] : !smt.bv<3>
//CHECK-NEXT: %[[ADD012:.+]] = smt.bv.add %[[ADD01]], %[[PP2]] : !smt.bv<3>
//CHECK-NEXT: %[[P:.+]] = smt.eq %[[MUL]], %[[ADD012]] : !smt.bv<3>
//CHECK-NEXT: smt.assert %[[P]]
//CHECK-NEXT: hw.output %[[PP0_BV]], %[[PP1_BV]], %[[PP2_BV]] : i3, i3, i3
%0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3)
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
}