mirror of https://github.com/llvm/circt.git
[Datapath] Add Datapath to SMT conversion pass (#8682)
To verify correctness of transformations involving datapath operations, include a lowering from datapath to SMT. Will later be integrated into circt-lec to enable verification of these operators. Each operator satisfied a contract, rather than providing a precise semantics for every operator. This is because the datapath operators return values in redundant number representations, meaning there are many valid implementations. For example: ```mlir %0:2 = datapath.compress %a, %b, %c : i8 [3 -> 2] ``` Will be verified by introducing free variables for each return value `(%0#0, %0#1)` then asserting that the sum of the associated free variables is equal to the sum of the inputs: `assert(%0#0 + %0#1 == %a + %b + %c)`. Whilst this is encoded as an assert it really represents an assumption that must be satisfied by a valid implementation of datapath.compress.
This commit is contained in:
parent
c43ec9809f
commit
097604cfa1
|
@ -0,0 +1,25 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CIRCT_CONVERSION_DATAPATHTOSMT_H
|
||||
#define CIRCT_CONVERSION_DATAPATHTOSMT_H
|
||||
|
||||
#include "circt/Support/LLVM.h"
|
||||
|
||||
namespace circt {
|
||||
|
||||
/// Get the Datapath to SMT conversion patterns.
|
||||
void populateDatapathToSMTConversionPatterns(TypeConverter &converter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
#define GEN_PASS_DECL_CONVERTDATAPATHTOSMT
|
||||
#include "circt/Conversion/Passes.h.inc"
|
||||
|
||||
} // namespace circt
|
||||
|
||||
#endif // CIRCT_CONVERSION_DATAPATHTOSMT_H
|
|
@ -26,6 +26,7 @@
|
|||
#include "circt/Conversion/CombToSMT.h"
|
||||
#include "circt/Conversion/ConvertToArcs.h"
|
||||
#include "circt/Conversion/DCToHW.h"
|
||||
#include "circt/Conversion/DatapathToSMT.h"
|
||||
#include "circt/Conversion/ExportChiselInterface.h"
|
||||
#include "circt/Conversion/ExportVerilog.h"
|
||||
#include "circt/Conversion/FIRRTLToHW.h"
|
||||
|
|
|
@ -866,4 +866,15 @@ def ConvertCombToDatapath: Pass<"convert-comb-to-datapath", "hw::HWModuleOp"> {
|
|||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConvertDatapathToSMT
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertDatapathToSMT : Pass<"convert-datapath-to-smt"> {
|
||||
let summary = "Convert datapath ops to SMT ops";
|
||||
let dependentDialects = [
|
||||
"mlir::smt::SMTDialect"
|
||||
];
|
||||
}
|
||||
|
||||
#endif // CIRCT_CONVERSION_PASSES_TD
|
||||
|
|
|
@ -14,6 +14,7 @@ add_circt_public_c_api_library(CIRCTCAPIConversion
|
|||
CIRCTCombToLLVM
|
||||
CIRCTCombToSMT
|
||||
CIRCTConvertToArcs
|
||||
CIRCTDatapathToSMT
|
||||
CIRCTDCToHW
|
||||
CIRCTExportChiselInterface
|
||||
CIRCTExportVerilog
|
||||
|
|
|
@ -9,6 +9,7 @@ add_subdirectory(CombToDatapath)
|
|||
add_subdirectory(CombToLLVM)
|
||||
add_subdirectory(CombToSMT)
|
||||
add_subdirectory(ConvertToArcs)
|
||||
add_subdirectory(DatapathToSMT)
|
||||
add_subdirectory(DCToHW)
|
||||
add_subdirectory(ExportAIGER)
|
||||
add_subdirectory(ExportChiselInterface)
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
add_circt_conversion_library(CIRCTDatapathToSMT
|
||||
DatapathToSMT.cpp
|
||||
|
||||
DEPENDS
|
||||
CIRCTConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
CIRCTDatapath
|
||||
CIRCTComb
|
||||
CIRCTHWToSMT
|
||||
MLIRSMT
|
||||
MLIRTransforms
|
||||
)
|
|
@ -0,0 +1,158 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "circt/Conversion/DatapathToSMT.h"
|
||||
#include "circt/Conversion/HWToSMT.h"
|
||||
#include "circt/Dialect/Datapath/DatapathOps.h"
|
||||
#include "mlir/Dialect/SMT/IR/SMTOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace circt {
|
||||
#define GEN_PASS_DEF_CONVERTDATAPATHTOSMT
|
||||
#include "circt/Conversion/Passes.h.inc"
|
||||
} // namespace circt
|
||||
|
||||
using namespace mlir;
|
||||
using namespace circt;
|
||||
using namespace datapath;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
// Lower to an SMT assertion that summing the results is equivalent to summing
|
||||
// the compress inputs
|
||||
// d:2 = compress(a, b, c) ->
|
||||
// assert(d#0 + d#1 == a + b + c)
|
||||
struct CompressOpConversion : OpConversionPattern<CompressOp> {
|
||||
using OpConversionPattern<CompressOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
ValueRange operands = adaptor.getOperands();
|
||||
ValueRange results = op.getResults();
|
||||
|
||||
// Sum operands
|
||||
Value operandRunner = operands[0];
|
||||
for (Value operand : operands.drop_front())
|
||||
operandRunner =
|
||||
rewriter.create<smt::BVAddOp>(op.getLoc(), operandRunner, operand);
|
||||
|
||||
// Create free variables
|
||||
SmallVector<Value, 2> newResults;
|
||||
newResults.reserve(results.size());
|
||||
for (Value result : results) {
|
||||
auto declareFunOp = rewriter.create<smt::DeclareFunOp>(
|
||||
op.getLoc(), typeConverter->convertType(result.getType()));
|
||||
newResults.push_back(declareFunOp.getResult());
|
||||
}
|
||||
|
||||
// Sum the free variables
|
||||
Value resultRunner = newResults.front();
|
||||
for (auto freeVar : llvm::drop_begin(newResults, 1))
|
||||
resultRunner =
|
||||
rewriter.create<smt::BVAddOp>(op.getLoc(), resultRunner, freeVar);
|
||||
|
||||
// Assert sum operands == sum results (free variables)
|
||||
auto premise =
|
||||
rewriter.create<smt::EqOp>(op.getLoc(), operandRunner, resultRunner);
|
||||
// Encode via an assertion (could be relaxed to an assumption).
|
||||
rewriter.create<smt::AssertOp>(op.getLoc(), premise);
|
||||
|
||||
if (newResults.size() != results.size())
|
||||
return rewriter.notifyMatchFailure(op, "expected same number of results");
|
||||
|
||||
rewriter.replaceOp(op, newResults);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Lower to an SMT assertion that summing the results is equivalent to the
|
||||
// product of the partial_product inputs
|
||||
// c:<N> = partial_product(a, b) ->
|
||||
// assert(c#0 + ... + c#<N-1> == a * b)
|
||||
struct PartialProductOpConversion : OpConversionPattern<PartialProductOp> {
|
||||
using OpConversionPattern<PartialProductOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(PartialProductOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
ValueRange operands = adaptor.getOperands();
|
||||
ValueRange results = op.getResults();
|
||||
|
||||
// Multiply the operands
|
||||
auto mulResult =
|
||||
rewriter.create<smt::BVMulOp>(op.getLoc(), operands[0], operands[1]);
|
||||
|
||||
// Create free variables
|
||||
SmallVector<Value, 2> newResults;
|
||||
newResults.reserve(results.size());
|
||||
for (Value result : results) {
|
||||
auto declareFunOp = rewriter.create<smt::DeclareFunOp>(
|
||||
op.getLoc(), typeConverter->convertType(result.getType()));
|
||||
newResults.push_back(declareFunOp.getResult());
|
||||
}
|
||||
|
||||
// Sum the free variables
|
||||
Value resultRunner = newResults.front();
|
||||
for (auto freeVar : llvm::drop_begin(newResults, 1))
|
||||
resultRunner =
|
||||
rewriter.create<smt::BVAddOp>(op.getLoc(), resultRunner, freeVar);
|
||||
|
||||
// Assert product of operands == sum results (free variables)
|
||||
auto premise =
|
||||
rewriter.create<smt::EqOp>(op.getLoc(), mulResult, resultRunner);
|
||||
// Encode via an assertion (could be relaxed to an assumption).
|
||||
rewriter.create<smt::AssertOp>(op.getLoc(), premise);
|
||||
|
||||
if (newResults.size() != results.size())
|
||||
return rewriter.notifyMatchFailure(op, "expected same number of results");
|
||||
|
||||
rewriter.replaceOp(op, newResults);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert Datapath to SMT pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct ConvertDatapathToSMTPass
|
||||
: public circt::impl::ConvertDatapathToSMTBase<ConvertDatapathToSMTPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void circt::populateDatapathToSMTConversionPatterns(
|
||||
TypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.add<CompressOpConversion, PartialProductOpConversion>(
|
||||
converter, patterns.getContext());
|
||||
}
|
||||
|
||||
void ConvertDatapathToSMTPass::runOnOperation() {
|
||||
ConversionTarget target(getContext());
|
||||
target.addIllegalDialect<datapath::DatapathDialect>();
|
||||
target.addLegalDialect<smt::SMTDialect>();
|
||||
|
||||
RewritePatternSet patterns(&getContext());
|
||||
TypeConverter converter;
|
||||
populateHWToSMTTypeConverter(converter);
|
||||
populateDatapathToSMTConversionPatterns(converter, patterns);
|
||||
|
||||
if (failed(mlir::applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
// RUN: circt-opt %s --convert-datapath-to-smt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @compressor
|
||||
hw.module @compressor(in %a : i4, in %b : i4, in %c : i4, out carry : i4, out save : i4) {
|
||||
// CHECK-NEXT: %[[C:.+]] = builtin.unrealized_conversion_cast %c : i4 to !smt.bv<4>
|
||||
// CHECK-NEXT: %[[B:.+]] = builtin.unrealized_conversion_cast %b : i4 to !smt.bv<4>
|
||||
// CHECK-NEXT: %[[A:.+]] = builtin.unrealized_conversion_cast %a : i4 to !smt.bv<4>
|
||||
// CHECK-NEXT: %[[AB:.+]] = smt.bv.add %[[A]], %[[B]] : !smt.bv<4>
|
||||
// CHECK-NEXT: %[[INS:.+]] = smt.bv.add %[[AB]], %[[C]] : !smt.bv<4>
|
||||
// CHECK-NEXT: %[[COMP0:.+]] = smt.declare_fun : !smt.bv<4>
|
||||
// CHECK-NEXT: %[[COMP0_BV:.+]] = builtin.unrealized_conversion_cast %[[COMP0]] : !smt.bv<4> to i4
|
||||
// CHECK-NEXT: %[[COMP1:.+]] = smt.declare_fun : !smt.bv<4>
|
||||
// CHECK-NEXT: %[[COMP1_BV:.+]] = builtin.unrealized_conversion_cast %7 : !smt.bv<4> to i4
|
||||
// CHECK-NEXT: %[[OUT:.+]] = smt.bv.add %[[COMP0]], %[[COMP1]] : !smt.bv<4>
|
||||
// CHECK-NEXT: %[[P:.+]] = smt.eq %[[INS]], %[[OUT]] : !smt.bv<4>
|
||||
// CHECK-NEXT: smt.assert %[[P]]
|
||||
// CHECK-NEXT: hw.output %[[COMP0_BV]], %[[COMP1_BV]] : i4, i4
|
||||
%0:2 = datapath.compress %a, %b, %c : i4 [3 -> 2]
|
||||
hw.output %0#0, %0#1 : i4, i4
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @partial_product
|
||||
hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
|
||||
//CHECK-NEXT: %[[B:.+]] = builtin.unrealized_conversion_cast %b : i3 to !smt.bv<3>
|
||||
//CHECK-NEXT: %[[A:.+]] = builtin.unrealized_conversion_cast %a : i3 to !smt.bv<3>
|
||||
//CHECK-NEXT: %[[MUL:.+]] = smt.bv.mul %[[A]], %[[B]] : !smt.bv<3>
|
||||
//CHECK-NEXT: %[[PP0:.+]] = smt.declare_fun : !smt.bv<3>
|
||||
//CHECK-NEXT: %[[PP0_BV:.+]] = builtin.unrealized_conversion_cast %[[PP0]] : !smt.bv<3> to i3
|
||||
//CHECK-NEXT: %[[PP1:.+]] = smt.declare_fun : !smt.bv<3>
|
||||
//CHECK-NEXT: %[[PP1_BV:.+]] = builtin.unrealized_conversion_cast %[[PP1]] : !smt.bv<3> to i3
|
||||
//CHECK-NEXT: %[[PP2:.+]] = smt.declare_fun : !smt.bv<3>
|
||||
//CHECK-NEXT: %[[PP2_BV:.+]] = builtin.unrealized_conversion_cast %[[PP2]] : !smt.bv<3> to i3
|
||||
//CHECK-NEXT: %[[ADD01:.+]] = smt.bv.add %[[PP0]], %[[PP1]] : !smt.bv<3>
|
||||
//CHECK-NEXT: %[[ADD012:.+]] = smt.bv.add %[[ADD01]], %[[PP2]] : !smt.bv<3>
|
||||
//CHECK-NEXT: %[[P:.+]] = smt.eq %[[MUL]], %[[ADD012]] : !smt.bv<3>
|
||||
//CHECK-NEXT: smt.assert %[[P]]
|
||||
//CHECK-NEXT: hw.output %[[PP0_BV]], %[[PP1_BV]], %[[PP2_BV]] : i3, i3, i3
|
||||
%0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3)
|
||||
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
|
||||
}
|
Loading…
Reference in New Issue