mirror of https://github.com/llvm/circt.git
159 lines
5.6 KiB
C++
159 lines
5.6 KiB
C++
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Conversion/DatapathToSMT.h"
|
|
#include "circt/Conversion/HWToSMT.h"
|
|
#include "circt/Dialect/Datapath/DatapathOps.h"
|
|
#include "mlir/Dialect/SMT/IR/SMTOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
namespace circt {
|
|
#define GEN_PASS_DEF_CONVERTDATAPATHTOSMT
|
|
#include "circt/Conversion/Passes.h.inc"
|
|
} // namespace circt
|
|
|
|
using namespace mlir;
|
|
using namespace circt;
|
|
using namespace datapath;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Conversion patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
// Lower to an SMT assertion that summing the results is equivalent to summing
|
|
// the compress inputs
|
|
// d:2 = compress(a, b, c) ->
|
|
// assert(d#0 + d#1 == a + b + c)
|
|
struct CompressOpConversion : OpConversionPattern<CompressOp> {
|
|
using OpConversionPattern<CompressOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
ValueRange operands = adaptor.getOperands();
|
|
ValueRange results = op.getResults();
|
|
|
|
// Sum operands
|
|
Value operandRunner = operands[0];
|
|
for (Value operand : operands.drop_front())
|
|
operandRunner =
|
|
smt::BVAddOp::create(rewriter, op.getLoc(), operandRunner, operand);
|
|
|
|
// Create free variables
|
|
SmallVector<Value, 2> newResults;
|
|
newResults.reserve(results.size());
|
|
for (Value result : results) {
|
|
auto declareFunOp = smt::DeclareFunOp::create(
|
|
rewriter, op.getLoc(), typeConverter->convertType(result.getType()));
|
|
newResults.push_back(declareFunOp.getResult());
|
|
}
|
|
|
|
// Sum the free variables
|
|
Value resultRunner = newResults.front();
|
|
for (auto freeVar : llvm::drop_begin(newResults, 1))
|
|
resultRunner =
|
|
smt::BVAddOp::create(rewriter, op.getLoc(), resultRunner, freeVar);
|
|
|
|
// Assert sum operands == sum results (free variables)
|
|
auto premise =
|
|
smt::EqOp::create(rewriter, op.getLoc(), operandRunner, resultRunner);
|
|
// Encode via an assertion (could be relaxed to an assumption).
|
|
smt::AssertOp::create(rewriter, op.getLoc(), premise);
|
|
|
|
if (newResults.size() != results.size())
|
|
return rewriter.notifyMatchFailure(op, "expected same number of results");
|
|
|
|
rewriter.replaceOp(op, newResults);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Lower to an SMT assertion that summing the results is equivalent to the
|
|
// product of the partial_product inputs
|
|
// c:<N> = partial_product(a, b) ->
|
|
// assert(c#0 + ... + c#<N-1> == a * b)
|
|
struct PartialProductOpConversion : OpConversionPattern<PartialProductOp> {
|
|
using OpConversionPattern<PartialProductOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(PartialProductOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
ValueRange operands = adaptor.getOperands();
|
|
ValueRange results = op.getResults();
|
|
|
|
// Multiply the operands
|
|
auto mulResult =
|
|
smt::BVMulOp::create(rewriter, op.getLoc(), operands[0], operands[1]);
|
|
|
|
// Create free variables
|
|
SmallVector<Value, 2> newResults;
|
|
newResults.reserve(results.size());
|
|
for (Value result : results) {
|
|
auto declareFunOp = smt::DeclareFunOp::create(
|
|
rewriter, op.getLoc(), typeConverter->convertType(result.getType()));
|
|
newResults.push_back(declareFunOp.getResult());
|
|
}
|
|
|
|
// Sum the free variables
|
|
Value resultRunner = newResults.front();
|
|
for (auto freeVar : llvm::drop_begin(newResults, 1))
|
|
resultRunner =
|
|
smt::BVAddOp::create(rewriter, op.getLoc(), resultRunner, freeVar);
|
|
|
|
// Assert product of operands == sum results (free variables)
|
|
auto premise =
|
|
smt::EqOp::create(rewriter, op.getLoc(), mulResult, resultRunner);
|
|
// Encode via an assertion (could be relaxed to an assumption).
|
|
smt::AssertOp::create(rewriter, op.getLoc(), premise);
|
|
|
|
if (newResults.size() != results.size())
|
|
return rewriter.notifyMatchFailure(op, "expected same number of results");
|
|
|
|
rewriter.replaceOp(op, newResults);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Convert Datapath to SMT pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct ConvertDatapathToSMTPass
|
|
: public circt::impl::ConvertDatapathToSMTBase<ConvertDatapathToSMTPass> {
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
/// Register the Datapath-to-SMT lowering patterns (compress and
/// partial_product) in `patterns`. Each pattern uses `converter` to map
/// result types into the SMT type system.
void circt::populateDatapathToSMTConversionPatterns(
    TypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<CompressOpConversion>(converter, context);
  patterns.add<PartialProductOpConversion>(converter, context);
}
|
|
|
|
void ConvertDatapathToSMTPass::runOnOperation() {
|
|
ConversionTarget target(getContext());
|
|
target.addIllegalDialect<datapath::DatapathDialect>();
|
|
target.addLegalDialect<smt::SMTDialect>();
|
|
|
|
RewritePatternSet patterns(&getContext());
|
|
TypeConverter converter;
|
|
populateHWToSMTTypeConverter(converter);
|
|
populateDatapathToSMTConversionPatterns(converter, patterns);
|
|
|
|
if (failed(mlir::applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|