circt/lib/Conversion/HWToSMT/HWToSMT.cpp

363 lines
14 KiB
C++

//===- HWToSMT.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Conversion/HWToSMT.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace circt {
#define GEN_PASS_DEF_CONVERTHWTOSMT
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
using namespace circt;
using namespace hw;
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
namespace {
/// Lower a `hw.constant` into an `smt.bv.constant` carrying the same APInt
/// value. Zero-width constants have no bit-vector counterpart, so the pattern
/// reports a match failure for them instead of rewriting.
struct HWConstantOpConversion : OpConversionPattern<ConstantOp> {
  using OpConversionPattern<ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto constValue = adaptor.getValue();
    // Bit-vector widths are unsigned, so "less than one bit" means exactly 0.
    if (constValue.getBitWidth() == 0)
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "0-bit constants not supported");

    rewriter.replaceOpWithNewOp<mlir::smt::BVConstantOp>(op, constValue);
    return success();
  }
};
/// Lower a `hw.module` into a `func.func` with the same symbol name.
///
/// The module's port signature is translated with the pattern's type
/// converter, the body region's block-argument types are converted, and the
/// region is then moved wholesale into the new function before the original
/// module is erased. Fails if any port type cannot be converted.
struct HWModuleOpConversion : OpConversionPattern<HWModuleOp> {
  using OpConversionPattern<HWModuleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(HWModuleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto moduleFuncTy = op.getModuleType().getFuncType();

    // Translate the port signature first (inputs, then results).
    SmallVector<Type> argTypes;
    SmallVector<Type> retTypes;
    if (failed(typeConverter->convertTypes(moduleFuncTy.getInputs(),
                                           argTypes)) ||
        failed(typeConverter->convertTypes(moduleFuncTy.getResults(),
                                           retTypes)))
      return failure();

    // Convert the body's block-argument types before the region is moved.
    auto &body = op.getBody();
    if (failed(rewriter.convertRegionTypes(&body, *typeConverter)))
      return failure();

    auto funcOp = mlir::func::FuncOp::create(
        rewriter, op.getLoc(), adaptor.getSymNameAttr(),
        rewriter.getFunctionType(argTypes, retTypes));
    rewriter.inlineRegionBefore(body, funcOp.getBody(), funcOp.end());
    rewriter.eraseOp(op);
    return success();
  }
};
/// Lower the `hw.output` terminator of a module body into a `func.return`
/// that returns the already type-converted output operands.
struct OutputOpConversion : OpConversionPattern<OutputOp> {
  using OpConversionPattern<OutputOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OutputOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto returnedValues = adaptor.getOutputs();
    rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, returnedValues);
    return success();
  }
};
/// Lower a `hw.instance` into a `func.call` of the symbol named by the
/// instance's module-name attribute, passing the converted inputs.
/// The call's result types come from the type converter; if any instance
/// result type cannot be converted, the pattern fails.
struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
  using OpConversionPattern<InstanceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Type> callResultTypes;
    auto converted =
        typeConverter->convertTypes(op->getResultTypes(), callResultTypes);
    if (succeeded(converted)) {
      rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
          op, adaptor.getModuleNameAttr(), callResultTypes,
          adaptor.getInputs());
      return success();
    }
    return failure();
  }
};
/// Lower a `hw.array_create` into an unconstrained SMT array (declared with
/// `smt.declare_fun`) followed by one `smt.array.store` per element.
///
/// The first operand is stored at the highest index and the last operand at
/// index 0, i.e. operand `i` of `n` lands at index `n - 1 - i`. Index
/// constants are `ceil(log2(n))` bits wide.
struct ArrayCreateOpConversion : OpConversionPattern<ArrayCreateOp> {
  using OpConversionPattern<ArrayCreateOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type smtArrayTy = typeConverter->convertType(op.getType());
    if (!smtArrayTy)
      return rewriter.notifyMatchFailure(loc, "unsupported array type");

    auto elements = adaptor.getInputs();
    unsigned numElements = elements.size();
    unsigned indexWidth = llvm::Log2_64_Ceil(numElements);

    Value result = mlir::smt::DeclareFunOp::create(rewriter, loc, smtArrayTy);
    for (unsigned pos = 0; pos < numElements; ++pos) {
      unsigned slot = numElements - 1 - pos;
      Value slotIdx =
          mlir::smt::BVConstantOp::create(rewriter, loc, slot, indexWidth);
      result = mlir::smt::ArrayStoreOp::create(rewriter, loc, result, slotIdx,
                                               elements[pos]);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Lower a `hw.array_get` into an `smt.array.select` guarded by a bounds check.
///
/// When the index is unsigned-less-or-equal to `numElements - 1`, the result
/// is the selected element. Otherwise it is a fresh unconstrained value from
/// `smt.declare_fun`, so out-of-bounds reads are left unspecified.
struct ArrayGetOpConversion : OpConversionPattern<ArrayGetOp> {
  using OpConversionPattern<ArrayGetOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto hwArrayTy = cast<hw::ArrayType>(op.getInput().getType());
    unsigned numElements = hwArrayTy.getNumElements();

    Type elementTy = typeConverter->convertType(op.getType());
    if (!elementTy)
      return rewriter.notifyMatchFailure(loc,
                                         "unsupported array element type");

    Value index = adaptor.getIndex();
    // Ops are created in the same order as before so the emitted IR matches.
    Value fallback = mlir::smt::DeclareFunOp::create(rewriter, loc, elementTy);
    Value lastIndex = mlir::smt::BVConstantOp::create(
        rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
    Value isInRange = mlir::smt::BVCmpOp::create(
        rewriter, loc, mlir::smt::BVCmpPredicate::ule, index, lastIndex);
    Value selected = mlir::smt::ArraySelectOp::create(
        rewriter, loc, adaptor.getInput(), index);

    rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, isInRange, selected,
                                                  fallback);
    return success();
  }
};
/// Lower a `hw.array_inject` into an `smt.array.store` guarded by a bounds
/// check.
///
/// When the index is unsigned-less-or-equal to `numElements - 1`, the result
/// is the array with the element written. Otherwise the input array is
/// returned unchanged.
struct ArrayInjectOpConversion : OpConversionPattern<ArrayInjectOp> {
  using OpConversionPattern<ArrayInjectOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArrayInjectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto hwArrayTy = cast<hw::ArrayType>(op.getInput().getType());
    unsigned numElements = hwArrayTy.getNumElements();

    Type smtArrayTy = typeConverter->convertType(op.getType());
    if (!smtArrayTy)
      return rewriter.notifyMatchFailure(loc, "unsupported array type");

    Value original = adaptor.getInput();
    Value index = adaptor.getIndex();

    // Bounds check: index <= numElements - 1 (unsigned).
    Value lastIndex = mlir::smt::BVConstantOp::create(
        rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
    Value isInRange = mlir::smt::BVCmpOp::create(
        rewriter, loc, mlir::smt::BVCmpPredicate::ule, index, lastIndex);

    // Candidate result with the element written at `index`.
    Value updated = mlir::smt::ArrayStoreOp::create(rewriter, loc, original,
                                                    index, adaptor.getElement());

    // Keep the write only when it is in range.
    rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, isInRange, updated,
                                                  original);
    return success();
  }
};
/// Remove an op by forwarding its converted operands as its results.
///
/// Instantiated for seq::ToClockOp and seq::FromClockOp. The type converter
/// maps both `!seq.clock` and `i1` to `!smt.bv<1>`, so these casts are no-ops
/// once types are converted.
template <typename OpTy>
struct ReplaceWithInput : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  // Re-declared here because names from the dependent base class are not
  // found by unqualified lookup inside this template.
  using OpAdaptor = typename OpTy::Adaptor;

  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto forwarded = adaptor.getOperands();
    rewriter.replaceOp(op, forwarded);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert HW to SMT pass
//===----------------------------------------------------------------------===//
namespace {
/// Pass driver for the HW-to-SMT conversion; the implementation of
/// runOnOperation() lives further down in this file.
class ConvertHWToSMTPass
    : public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
public:
  void runOnOperation() override;
};
} // namespace
/// Register the HW/Seq-to-SMT type conversions and the materializations
/// needed to bridge converted and unconverted values.
///
/// Type mappings:
///   - `iN` (N >= 1)      -> `!smt.bv<N>`
///   - `!seq.clock`       -> `!smt.bv<1>`
///   - `!hw.array<NxT>`   -> `!smt.array<[!smt.bv<ceil(log2 N)> -> T']>`
///     where `T'` is the converted element type.
///
/// Zero-width integers and single-element arrays (which would need a 0-bit
/// index) have no SMT bit-vector representation and are left unconverted.
/// As a result, patterns relying on them fail instead of building invalid
/// zero-width bit-vector types.
void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
  // The semantics of the builtin integer at the CIRCT core level is currently
  // not very well defined. It is used for two-valued, four-valued, and possibly
  // other multi-valued logic. Here, we interpret it as two-valued for now.
  // From a formal perspective, CIRCT would ideally define its own types for
  // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream
  // the integer type also carries poison information (which we don't have in
  // CIRCT?).
  converter.addConversion([](IntegerType type) -> std::optional<Type> {
    // SMT bit-vectors must be at least one bit wide, so i0 has no counterpart.
    if (type.getWidth() == 0)
      return std::nullopt;
    return mlir::smt::BitVectorType::get(type.getContext(), type.getWidth());
  });

  // Clocks are modeled as plain single-bit values.
  converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
    return mlir::smt::BitVectorType::get(type.getContext(), 1);
  });

  // Arrays become SMT arrays whose domain is a bit-vector just wide enough to
  // address every element. The converter itself is captured to recursively
  // convert the element type.
  converter.addConversion([&](ArrayType type) -> std::optional<Type> {
    auto rangeType = converter.convertType(type.getElementType());
    if (!rangeType)
      return std::nullopt;
    // A single-element array would need a 0-bit index, but SMT bit-vector
    // types must be at least one bit wide; building one would violate the
    // type's invariant. Leave such arrays unconverted instead.
    unsigned indexWidth = llvm::Log2_64_Ceil(type.getNumElements());
    if (indexWidth == 0)
      return std::nullopt;
    auto domainType =
        mlir::smt::BitVectorType::get(type.getContext(), indexWidth);
    return mlir::smt::ArrayType::get(type.getContext(), domainType, rangeType);
  });

  // Note: MLIR tries materializations most-recently-registered first, so the
  // generic cast below is the fallback used when none of the more specific
  // target materializations registered after it applies.

  // Default target materialization to convert from illegal types to legal
  // types, e.g., at the boundary of an inlined child block.
  converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
                                        ValueRange inputs,
                                        Location loc) -> Value {
    return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
                                                    inputs)
        ->getResult(0);
  });

  // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value
  // (true -> 1, false -> 0).
  converter.addTargetMaterialization(
      [](OpBuilder &builder, mlir::smt::BitVectorType resultType,
         ValueRange inputs, Location loc) -> Value {
        if (inputs.size() != 1)
          return Value();
        if (!isa<mlir::smt::BoolType>(inputs[0].getType()))
          return Value();
        unsigned width = resultType.getWidth();
        Value constZero =
            mlir::smt::BVConstantOp::create(builder, loc, 0, width);
        Value constOne =
            mlir::smt::BVConstantOp::create(builder, loc, 1, width);
        return mlir::smt::IteOp::create(builder, loc, inputs[0], constOne,
                                        constZero);
      });

  // Materialize a 'smt.bv<1>' from an i1 that is itself just an
  // unrealized_conversion_cast of a 'smt.bool': look through that cast and
  // convert the boolean directly instead of stacking another cast on top.
  converter.addTargetMaterialization(
      [](OpBuilder &builder, mlir::smt::BitVectorType resultType,
         ValueRange inputs, Location loc) -> Value {
        if (inputs.size() != 1 || resultType.getWidth() != 1)
          return Value();
        auto intType = dyn_cast<IntegerType>(inputs[0].getType());
        if (!intType || intType.getWidth() != 1)
          return Value();
        auto castOp =
            inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
        if (!castOp || castOp.getInputs().size() != 1)
          return Value();
        if (!isa<mlir::smt::BoolType>(castOp.getInputs()[0].getType()))
          return Value();
        Value constZero = mlir::smt::BVConstantOp::create(builder, loc, 0, 1);
        Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
        return mlir::smt::IteOp::create(builder, loc, castOp.getInputs()[0],
                                        constOne, constZero);
      });

  // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value
  // (the bool is true iff the bit equals 1).
  converter.addTargetMaterialization(
      [](OpBuilder &builder, mlir::smt::BoolType resultType, ValueRange inputs,
         Location loc) -> Value {
        if (inputs.size() != 1)
          return Value();
        auto bvType = dyn_cast<mlir::smt::BitVectorType>(inputs[0].getType());
        if (!bvType || bvType.getWidth() != 1)
          return Value();
        Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
        return mlir::smt::EqOp::create(builder, loc, inputs[0], constOne);
      });

  // Default source materialization to convert from legal types back to the
  // original (illegal) types, e.g., at the boundary of an inlined child block.
  converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
                                        ValueRange inputs,
                                        Location loc) -> Value {
    return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
                                                    inputs)
        ->getResult(0);
  });
}
/// Add the HW (and the Seq clock-cast) to SMT/Func conversion patterns to
/// `patterns`. All patterns share `converter` for operand and result types.
/// Registration order matches the original single list.
void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
                                              RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();

  // Structural ops: constants, modules, terminators, and instances.
  patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
               InstanceOpConversion>(converter, ctx);

  // Clock casts that become no-ops once types are converted.
  patterns.add<ReplaceWithInput<seq::ToClockOp>,
               ReplaceWithInput<seq::FromClockOp>>(converter, ctx);

  // Array construction and access.
  patterns.add<ArrayCreateOpConversion, ArrayGetOpConversion,
               ArrayInjectOpConversion>(converter, ctx);
}
/// Convert all HW ops (plus seq clock casts) to SMT/Func, then topologically
/// sort each resulting single-block `func.func` body.
///
/// The sort is required because `hw.module` bodies are graph regions while
/// `func.func` bodies are SSACFG regions that require defs to dominate uses.
/// If a body cannot be sorted, it contains a real combinational cycle or a
/// pseudo cycle through module instances. Neither is supported, so an error
/// is emitted on that function and the pass fails. This avoids silently
/// producing IR that uses values before their definition.
void ConvertHWToSMTPass::runOnOperation() {
  ConversionTarget target(getContext());
  target.addIllegalDialect<hw::HWDialect>();
  target.addIllegalOp<seq::FromClockOp>();
  target.addIllegalOp<seq::ToClockOp>();
  target.addLegalDialect<mlir::smt::SMTDialect>();
  target.addLegalDialect<mlir::func::FuncDialect>();

  RewritePatternSet patterns(&getContext());
  TypeConverter converter;
  populateHWToSMTTypeConverter(converter);
  populateHWToSMTConversionPatterns(converter, patterns);

  if (failed(mlir::applyPartialConversion(getOperation(), target,
                                          std::move(patterns))))
    return signalPassFailure();

  // Sort the functions topologically because 'hw.module' has a graph region
  // while 'func.func' is a regular SSACFG region. Real combinational cycles or
  // pseudo cycles through module instances are not supported yet.
  bool foundCycle = false;
  for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
    // Skip functions that are definitely not the result of lowering from
    // 'hw.module' (e.g. declarations or multi-block bodies).
    if (func.getBody().getBlocks().size() != 1)
      continue;
    // sortTopologically returns false when some operations could not be
    // ordered, i.e. the block contains a cycle. Report every such function
    // before failing instead of leaving invalid SSACFG IR behind.
    if (mlir::sortTopologically(&func.getBody().front()))
      continue;
    func.emitError("failed to topologically sort function body: "
                   "combinational cycles and pseudo cycles through module "
                   "instances are not supported");
    foundCycle = true;
  }
  if (foundCycle)
    signalPassFailure();
}