mirror of https://github.com/llvm/circt.git
1590 lines
66 KiB
C++
1590 lines
66 KiB
C++
//===- LowerSMTToZ3LLVM.cpp -----------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Conversion/SMTToZ3LLVM.h"
|
|
#include "circt/Support/Namespace.h"
|
|
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
|
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
|
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SMT/IR/SMTOps.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "lower-smt-to-z3-llvm"
|
|
|
|
namespace circt {
|
|
#define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
|
|
#include "circt/Conversion/Passes.h.inc"
|
|
} // namespace circt
|
|
|
|
using namespace mlir;
|
|
using namespace circt;
|
|
using namespace smt;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SMTGlobalsHandler implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Create the two internal LLVM globals that hold the pointers to the Z3
/// context and the Z3 solver, and return a handler referring to them.
///
/// Both globals are inserted at the start of the body of 'module', have
/// pointer type, and are initialized to a null pointer. Their symbol names are
/// derived from the prefixes "ctx" and "solver" and are made unique with
/// respect to the symbols already defined in 'module'. The insertion point of
/// 'builder' is restored before returning.
SMTGlobalsHandler SMTGlobalsHandler::create(OpBuilder &builder,
                                            ModuleOp module) {
  OpBuilder::InsertionGuard restoreInsertionPoint(builder);
  builder.setInsertionPointToStart(module.getBody());

  // Seed the namespace with all symbols already present in the module such
  // that the names generated below cannot collide with them.
  Namespace names;
  {
    SymbolCache existingSymbols;
    existingSymbols.addDefinitions(module);
    names.add(existingSymbols);
  }

  const Location moduleLoc = module.getLoc();
  auto pointerType = LLVM::LLVMPointerType::get(builder.getContext());

  // Build a mutable, internal, pointer-typed global whose initializer region
  // returns a null pointer.
  const auto makeNullPtrGlobal = [&](StringRef prefix) -> LLVM::GlobalOp {
    LLVM::GlobalOp globalOp = LLVM::GlobalOp::create(
        builder, moduleLoc, pointerType, /*isConstant=*/false,
        LLVM::Linkage::Internal, names.newName(prefix), Attribute{},
        /*alignment=*/8);
    // Only the initializer body is built inside the region; afterwards the
    // builder continues right after the global.
    OpBuilder::InsertionGuard initializerGuard(builder);
    builder.createBlock(&globalOp.getInitializer());
    Value nullPtr = LLVM::ZeroOp::create(builder, moduleLoc, pointerType);
    LLVM::ReturnOp::create(builder, moduleLoc, nullPtr);
    return globalOp;
  };

  LLVM::GlobalOp contextGlobal = makeNullPtrGlobal("ctx");
  LLVM::GlobalOp solverGlobal = makeNullPtrGlobal("solver");

  return SMTGlobalsHandler(std::move(names), solverGlobal, contextGlobal);
}
|
|
|
|
/// Construct a handler from an already populated namespace and the globals
/// holding the pointers to the Z3 solver ('solver') and Z3 context ('ctx').
///
/// 'names' is taken by rvalue reference and its contents are moved into the
/// handler. Inside the constructor the parameter is a named lvalue, so it
/// has to be moved explicitly; otherwise the namespace would be copied.
SMTGlobalsHandler::SMTGlobalsHandler(Namespace &&names,
                                     mlir::LLVM::GlobalOp solver,
                                     mlir::LLVM::GlobalOp ctx)
    : solver(solver), ctx(ctx), names(std::move(names)) {}
|
|
|
|
/// Construct a handler for globals that already exist: 'solver' and 'ctx' are
/// the globals holding the pointers to the Z3 solver and the Z3 context. The
/// internal namespace is populated with all symbols defined in 'module' such
/// that names generated later on do not collide with them.
SMTGlobalsHandler::SMTGlobalsHandler(ModuleOp module,
                                     mlir::LLVM::GlobalOp solver,
                                     mlir::LLVM::GlobalOp ctx)
    : solver(solver), ctx(ctx) {
  SymbolCache definedSymbols;
  definedSymbols.addDefinitions(module);
  names.add(definedSymbols);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Lowering Pattern Base
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
template <typename OpTy>
|
|
class SMTLoweringPattern : public OpConversionPattern<OpTy> {
|
|
public:
|
|
SMTLoweringPattern(const TypeConverter &typeConverter, MLIRContext *context,
|
|
SMTGlobalsHandler &globals,
|
|
const LowerSMTToZ3LLVMOptions &options)
|
|
: OpConversionPattern<OpTy>(typeConverter, context), globals(globals),
|
|
options(options) {}
|
|
|
|
private:
|
|
Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
|
|
LLVM::GlobalOp global,
|
|
DenseMap<Block *, Value> &cache) const {
|
|
Block *block = builder.getBlock();
|
|
if (auto iter = cache.find(block); iter != cache.end())
|
|
return iter->getSecond();
|
|
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToStart(block);
|
|
Value globalAddr = LLVM::AddressOfOp::create(builder, loc, global);
|
|
return cache[block] = LLVM::LoadOp::create(
|
|
builder, loc, LLVM::LLVMPointerType::get(builder.getContext()),
|
|
globalAddr);
|
|
}
|
|
|
|
protected:
|
|
/// A convenience function to get the pointer to the context from the 'global'
|
|
/// operation. The result is cached for each basic block, i.e., it is assumed
|
|
/// that this function is never called in the same basic block again at a
|
|
/// location (insertion point of the 'builder') not dominating all previous
|
|
/// locations this function was called at.
|
|
Value buildContextPtr(OpBuilder &builder, Location loc) const {
|
|
return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
|
|
}
|
|
|
|
/// A convenience function to get the pointer to the solver from the 'global'
|
|
/// operation. The result is cached for each basic block, i.e., it is assumed
|
|
/// that this function is never called in the same basic block again at a
|
|
/// location (insertion point of the 'builder') not dominating all previous
|
|
/// locations this function was called at.
|
|
Value buildSolverPtr(OpBuilder &builder, Location loc) const {
|
|
return buildGlobalPtrToGlobal(builder, loc, globals.solver,
|
|
globals.solverCache);
|
|
}
|
|
|
|
/// Create a `llvm.call` operation to a function with the given 'name' and
|
|
/// 'type'. If there does not already exist a (external) function with that
|
|
/// name create a matching external function declaration.
|
|
LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
|
|
LLVM::LLVMFunctionType funcType,
|
|
ValueRange args) const {
|
|
auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
|
|
if (!funcOp) {
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
auto module =
|
|
builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
|
builder.setInsertionPointToEnd(module.getBody());
|
|
auto funcOpResult = LLVM::lookupOrCreateFn(
|
|
builder, module, name, funcType.getParams(), funcType.getReturnType(),
|
|
funcType.getVarArg());
|
|
assert(succeeded(funcOpResult) && "expected to lookup or create printf");
|
|
funcOp = funcOpResult.value();
|
|
}
|
|
return LLVM::CallOp::create(builder, loc, funcOp, args);
|
|
}
|
|
|
|
/// Build a global constant for the given string and construct an 'addressof'
|
|
/// operation at the current 'builder' insertion point to get a pointer to it.
|
|
/// Multiple calls with the same string will reuse the same global. It is
|
|
/// guaranteed that the symbol of the global will be unique.
|
|
Value buildString(OpBuilder &builder, Location loc, StringRef str) const {
|
|
auto &global = globals.stringCache[builder.getStringAttr(str)];
|
|
if (!global) {
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
auto module =
|
|
builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
|
|
builder.setInsertionPointToEnd(module.getBody());
|
|
auto arrayTy =
|
|
LLVM::LLVMArrayType::get(builder.getI8Type(), str.size() + 1);
|
|
auto strAttr = builder.getStringAttr(str.str() + '\00');
|
|
global = LLVM::GlobalOp::create(
|
|
builder, loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
|
globals.names.newName("str"), strAttr);
|
|
}
|
|
return LLVM::AddressOfOp::create(builder, loc, global);
|
|
}
|
|
/// Most API functions require a pointer to the the Z3 context object as the
|
|
/// first argument. This helper function prepends this pointer value to the
|
|
/// call for convenience.
|
|
LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
|
|
StringRef name, Type returnType,
|
|
ValueRange args = {}) const {
|
|
auto ctx = buildContextPtr(builder, loc);
|
|
SmallVector<Value> arguments;
|
|
arguments.emplace_back(ctx);
|
|
arguments.append(SmallVector<Value>(args));
|
|
return buildCall(
|
|
builder, loc, name,
|
|
LLVM::LLVMFunctionType::get(
|
|
returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
|
|
arguments);
|
|
}
|
|
|
|
/// Most API functions we need to call return a 'Z3_AST' object which is a
|
|
/// pointer in LLVM. This helper function simplifies calling those API
|
|
/// functions.
|
|
Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
|
|
ValueRange args = {}) const {
|
|
return buildAPICallWithContext(
|
|
builder, loc, name,
|
|
LLVM::LLVMPointerType::get(builder.getContext()), args)
|
|
->getResult(0);
|
|
}
|
|
|
|
/// Build a value representing the SMT sort given with 'type'.
|
|
Value buildSort(OpBuilder &builder, Location loc, Type type) const {
|
|
// NOTE: if a type not handled by this switch is passed, an assertion will
|
|
// be triggered.
|
|
return TypeSwitch<Type, Value>(type)
|
|
.Case([&](smt::IntType ty) {
|
|
return buildPtrAPICall(builder, loc, "Z3_mk_int_sort");
|
|
})
|
|
.Case([&](smt::BitVectorType ty) {
|
|
Value bitwidth = LLVM::ConstantOp::create(
|
|
builder, loc, builder.getI32Type(), ty.getWidth());
|
|
return buildPtrAPICall(builder, loc, "Z3_mk_bv_sort", {bitwidth});
|
|
})
|
|
.Case([&](smt::BoolType ty) {
|
|
return buildPtrAPICall(builder, loc, "Z3_mk_bool_sort");
|
|
})
|
|
.Case([&](smt::SortType ty) {
|
|
Value str = buildString(builder, loc, ty.getIdentifier());
|
|
Value sym =
|
|
buildPtrAPICall(builder, loc, "Z3_mk_string_symbol", {str});
|
|
return buildPtrAPICall(builder, loc, "Z3_mk_uninterpreted_sort",
|
|
{sym});
|
|
})
|
|
.Case([&](smt::ArrayType ty) {
|
|
return buildPtrAPICall(builder, loc, "Z3_mk_array_sort",
|
|
{buildSort(builder, loc, ty.getDomainType()),
|
|
buildSort(builder, loc, ty.getRangeType())});
|
|
});
|
|
}
|
|
|
|
SMTGlobalsHandler &globals;
|
|
const LowerSMTToZ3LLVMOptions &options;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Lowering Patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// The 'smt.declare_fun' operation is used to declare both constants and
|
|
/// functions. The Z3 API, however, uses two different functions. Therefore,
|
|
/// depending on the result type of this operation, one of the following two
|
|
/// API functions is used to create the symbolic value:
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_fresh_const(Z3_context c, Z3_string prefix, Z3_sort ty);
|
|
/// Z3_func_decl Z3_API Z3_mk_fresh_func_decl(
|
|
/// Z3_context c, Z3_string prefix, unsigned domain_size,
|
|
/// Z3_sort const domain[], Z3_sort range);
|
|
/// ```
|
|
struct DeclareFunOpLowering : public SMTLoweringPattern<DeclareFunOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
|
|
// Create the name prefix.
|
|
Value prefix;
|
|
if (adaptor.getNamePrefix())
|
|
prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
|
|
else
|
|
prefix = LLVM::ZeroOp::create(rewriter, loc,
|
|
LLVM::LLVMPointerType::get(getContext()));
|
|
|
|
// Handle the constant value case.
|
|
if (!isa<SMTFuncType>(op.getType())) {
|
|
Value sort = buildSort(rewriter, loc, op.getType());
|
|
Value constDecl =
|
|
buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_const", {prefix, sort});
|
|
rewriter.replaceOp(op, constDecl);
|
|
return success();
|
|
}
|
|
|
|
// Otherwise, we declare a function.
|
|
Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
|
|
auto funcType = cast<SMTFuncType>(op.getResult().getType());
|
|
Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
|
|
|
|
Type arrTy =
|
|
LLVM::LLVMArrayType::get(llvmPtrTy, funcType.getDomainTypes().size());
|
|
|
|
Value domain = LLVM::UndefOp::create(rewriter, loc, arrTy);
|
|
for (auto [i, ty] : llvm::enumerate(funcType.getDomainTypes())) {
|
|
Value sort = buildSort(rewriter, loc, ty);
|
|
domain = LLVM::InsertValueOp::create(rewriter, loc, domain, sort, i);
|
|
}
|
|
|
|
Value one =
|
|
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
|
|
Value domainStorage =
|
|
LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, arrTy, one);
|
|
LLVM::StoreOp::create(rewriter, loc, domain, domainStorage);
|
|
|
|
Value domainSize = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
|
|
Value decl =
|
|
buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_func_decl",
|
|
{prefix, domainSize, domainStorage, rangeSort});
|
|
|
|
rewriter.replaceOp(op, decl);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower the 'smt.apply_func' operation to Z3 API calls of the form:
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_app(Z3_context c, Z3_func_decl d,
|
|
/// unsigned num_args, Z3_ast const args[]);
|
|
/// ```
|
|
struct ApplyFuncOpLowering : public SMTLoweringPattern<ApplyFuncOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
|
|
Type arrTy = LLVM::LLVMArrayType::get(llvmPtrTy, adaptor.getArgs().size());
|
|
|
|
// Create an array of the function arguments.
|
|
Value domain = LLVM::UndefOp::create(rewriter, loc, arrTy);
|
|
for (auto [i, arg] : llvm::enumerate(adaptor.getArgs()))
|
|
domain = LLVM::InsertValueOp::create(rewriter, loc, domain, arg, i);
|
|
|
|
// Store the array on the stack.
|
|
Value one =
|
|
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
|
|
Value domainStorage =
|
|
LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, arrTy, one);
|
|
LLVM::StoreOp::create(rewriter, loc, domain, domainStorage);
|
|
|
|
// Call the API function with a pointer to the function, the number of
|
|
// arguments, and the pointer to the arguments stored on the stack.
|
|
Value domainSize = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI32Type(), adaptor.getArgs().size());
|
|
Value returnVal =
|
|
buildPtrAPICall(rewriter, loc, "Z3_mk_app",
|
|
{adaptor.getFunc(), domainSize, domainStorage});
|
|
rewriter.replaceOp(op, returnVal);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower the `smt.bv.constant` operation to either
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_unsigned_int64(Z3_context c, uint64_t v, Z3_sort ty);
|
|
/// ```
|
|
/// if the bit-vector fits into a 64-bit integer or convert it to a string and
|
|
/// use the sligtly slower but arbitrary precision API function:
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
|
|
/// ```
|
|
/// Note that there is also an API function taking an array of booleans, and
|
|
/// while those are typically compiled to 'i8' in LLVM they don't necessarily
|
|
/// have to (I think).
|
|
struct BVConstantOpLowering : public SMTLoweringPattern<smt::BVConstantOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
unsigned width = op.getType().getWidth();
|
|
auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
|
|
APInt val = adaptor.getValue().getValue();
|
|
|
|
if (width <= 64) {
|
|
Value bvConst = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI64Type(), val.getZExtValue());
|
|
Value res = buildPtrAPICall(rewriter, loc, "Z3_mk_unsigned_int64",
|
|
{bvConst, bvSort});
|
|
rewriter.replaceOp(op, res);
|
|
return success();
|
|
}
|
|
|
|
std::string str;
|
|
llvm::raw_string_ostream stream(str);
|
|
stream << val;
|
|
Value bvString = buildString(rewriter, loc, str);
|
|
Value bvNumeral =
|
|
buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {bvString, bvSort});
|
|
|
|
rewriter.replaceOp(op, bvNumeral);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Some of the Z3 API supports a variadic number of operands for some
|
|
/// operations (in particular if the expansion would lead to a super-linear
|
|
/// increase in operations such as with the ':pairwise' attribute). Those API
|
|
/// calls take an 'unsigned' argument indicating the size of an array of
|
|
/// pointers to the operands.
|
|
template <typename SourceTy>
|
|
struct VariadicSMTPattern : public SMTLoweringPattern<SourceTy> {
|
|
using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
|
|
|
|
VariadicSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
|
|
SMTGlobalsHandler &globals,
|
|
const LowerSMTToZ3LLVMOptions &options,
|
|
StringRef apiFuncName, unsigned minNumArgs)
|
|
: SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
|
|
apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SourceTy op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
if (adaptor.getOperands().size() < minNumArgs)
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
Value numOperands = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI32Type(), op->getNumOperands());
|
|
Value constOne =
|
|
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
|
|
Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
|
|
Type arrTy = LLVM::LLVMArrayType::get(ptrTy, op->getNumOperands());
|
|
Value storage =
|
|
LLVM::AllocaOp::create(rewriter, loc, ptrTy, arrTy, constOne);
|
|
Value array = LLVM::UndefOp::create(rewriter, loc, arrTy);
|
|
|
|
for (auto [i, operand] : llvm::enumerate(adaptor.getOperands()))
|
|
array = LLVM::InsertValueOp::create(rewriter, loc, array, operand,
|
|
ArrayRef<int64_t>{(int64_t)i});
|
|
|
|
LLVM::StoreOp::create(rewriter, loc, array, storage);
|
|
|
|
rewriter.replaceOp(op,
|
|
SMTLoweringPattern<SourceTy>::buildPtrAPICall(
|
|
rewriter, loc, apiFuncName, {numOperands, storage}));
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
StringRef apiFuncName;
|
|
unsigned minNumArgs;
|
|
};
|
|
|
|
/// Lower an SMT operation to a function call with the name 'apiFuncName' with
|
|
/// arguments matching the operands one-to-one.
|
|
template <typename SourceTy>
|
|
struct OneToOneSMTPattern : public SMTLoweringPattern<SourceTy> {
|
|
using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
|
|
|
|
OneToOneSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
|
|
SMTGlobalsHandler &globals,
|
|
const LowerSMTToZ3LLVMOptions &options,
|
|
StringRef apiFuncName, unsigned numOperands)
|
|
: SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
|
|
apiFuncName(apiFuncName), numOperands(numOperands) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SourceTy op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
if (adaptor.getOperands().size() != numOperands)
|
|
return failure();
|
|
|
|
rewriter.replaceOp(
|
|
op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
|
|
rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
StringRef apiFuncName;
|
|
unsigned numOperands;
|
|
};
|
|
|
|
/// A pattern to lower SMT operations with a variadic number of operands
|
|
/// modelling the ':chainable' attribute in SMT to binary operations.
|
|
template <typename SourceTy>
|
|
class LowerChainableSMTPattern : public SMTLoweringPattern<SourceTy> {
|
|
using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
|
|
using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SourceTy op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
if (adaptor.getOperands().size() <= 2)
|
|
return failure();
|
|
|
|
Location loc = op.getLoc();
|
|
SmallVector<Value> elements;
|
|
for (int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
|
|
Value val = SourceTy::create(
|
|
rewriter, loc, op->getResultTypes(),
|
|
ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
|
|
elements.push_back(val);
|
|
}
|
|
rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// A pattern to lower SMT operations with a variadic number of operands
|
|
/// modelling the `:left-assoc` attribute to a sequence of binary operators.
|
|
template <typename SourceTy>
|
|
class LowerLeftAssocSMTPattern : public SMTLoweringPattern<SourceTy> {
|
|
using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
|
|
using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SourceTy op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
if (adaptor.getOperands().size() <= 2)
|
|
return rewriter.notifyMatchFailure(op, "must have at least two operands");
|
|
|
|
Value runner = adaptor.getOperands()[0];
|
|
for (Value val : adaptor.getOperands().drop_front())
|
|
runner = SourceTy::create(rewriter, op.getLoc(), op->getResultTypes(),
|
|
ValueRange{runner, val});
|
|
|
|
rewriter.replaceOp(op, runner);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// The 'smt.solver' operation has a region that corresponds to the lifetime of
|
|
/// the Z3 context and one solver instance created within this context.
|
|
/// To create a context, a Z3 configuration has to be built first and various
|
|
/// configuration parameters can be set before creating a context from it. Once
|
|
/// we have a context, we can create a solver and store a pointer to the context
|
|
/// and the solver in an LLVM global such that operations in the child region
|
|
/// have access to them. While the context created with `Z3_mk_context` takes
|
|
/// care of the reference counting of `Z3_AST` objects, it still requires manual
|
|
/// reference counting of `Z3_solver` objects, therefore, we need to increase
|
|
/// the ref. counter of the solver we get from `Z3_mk_solver` and must decrease
|
|
/// it again once we don't need it anymore. Finally, the configuration object
|
|
/// can be deleted.
|
|
/// ```
|
|
/// Z3_config Z3_API Z3_mk_config(void);
|
|
/// void Z3_API Z3_set_param_value(Z3_config c, Z3_string param_id,
|
|
/// Z3_string param_value);
|
|
/// Z3_context Z3_API Z3_mk_context(Z3_config c);
|
|
/// Z3_solver Z3_API Z3_mk_solver(Z3_context c);
|
|
/// void Z3_API Z3_solver_inc_ref(Z3_context c, Z3_solver s);
|
|
/// void Z3_API Z3_del_config(Z3_config c);
|
|
/// ```
|
|
/// At the end of the solver lifetime, we have to tell the context that we
|
|
/// don't need the solver anymore and delete the context itself.
|
|
/// ```
|
|
/// void Z3_API Z3_solver_dec_ref(Z3_context c, Z3_solver s);
|
|
/// void Z3_API Z3_del_context(Z3_context c);
|
|
/// ```
|
|
/// Note that the solver created here is a combined solver. There might be some
|
|
/// potential for optimization by creating more specialized solvers supported by
|
|
/// the Z3 API according the the kind of operations present in the body region.
|
|
struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SolverOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
auto ptrTy = LLVM::LLVMPointerType::get(getContext());
|
|
auto voidTy = LLVM::LLVMVoidType::get(getContext());
|
|
auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
|
|
auto ptrPtrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy});
|
|
auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
|
|
auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});
|
|
|
|
// Create the configuration.
|
|
Value config = buildCall(rewriter, loc, "Z3_mk_config",
|
|
LLVM::LLVMFunctionType::get(ptrTy, {}), {})
|
|
.getResult();
|
|
|
|
// In debug-mode, we enable proofs such that we can fetch one in the 'unsat'
|
|
// region of each 'smt.check' operation.
|
|
if (options.debug) {
|
|
Value paramKey = buildString(rewriter, loc, "proof");
|
|
Value paramValue = buildString(rewriter, loc, "true");
|
|
buildCall(rewriter, loc, "Z3_set_param_value",
|
|
LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, ptrTy}),
|
|
{config, paramKey, paramValue});
|
|
}
|
|
|
|
// Check if the logic is set anywhere within the solver
|
|
std::optional<StringRef> logic = std::nullopt;
|
|
auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
|
|
if (!setLogicOps.empty()) {
|
|
// We know from before patterns were applied that there is only one
|
|
// set_logic op
|
|
auto setLogicOp = *setLogicOps.begin();
|
|
logic = setLogicOp.getLogic();
|
|
rewriter.eraseOp(setLogicOp);
|
|
}
|
|
|
|
// Create the context and store a pointer to it in the global variable.
|
|
Value ctx = buildCall(rewriter, loc, "Z3_mk_context", ptrToPtrFunc, config)
|
|
.getResult();
|
|
Value ctxAddr =
|
|
LLVM::AddressOfOp::create(rewriter, loc, globals.ctx).getResult();
|
|
LLVM::StoreOp::create(rewriter, loc, ctx, ctxAddr);
|
|
|
|
// Delete the configuration again.
|
|
buildCall(rewriter, loc, "Z3_del_config", ptrToVoidFunc, {config});
|
|
|
|
// Create a solver instance, increase its reference counter, and store a
|
|
// pointer to it in the global variable.
|
|
Value solver;
|
|
if (logic) {
|
|
auto logicStr = buildString(rewriter, loc, logic.value());
|
|
solver = buildCall(rewriter, loc, "Z3_mk_solver_for_logic",
|
|
ptrPtrToPtrFunc, {ctx, logicStr})
|
|
->getResult(0);
|
|
} else {
|
|
solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
|
|
->getResult(0);
|
|
}
|
|
buildCall(rewriter, loc, "Z3_solver_inc_ref", ptrPtrToVoidFunc,
|
|
{ctx, solver});
|
|
Value solverAddr =
|
|
LLVM::AddressOfOp::create(rewriter, loc, globals.solver).getResult();
|
|
LLVM::StoreOp::create(rewriter, loc, solver, solverAddr);
|
|
|
|
// This assumes that no constant hoisting of the like happens inbetween
|
|
// the patterns defined in this pass because once the solver initialization
|
|
// and deallocation calls are inserted and the body region is inlined,
|
|
// canonicalizations and folders applied inbetween lowering patterns might
|
|
// hoist the SMT constants which means they would access uninitialized
|
|
// global variables once they are lowered.
|
|
SmallVector<Type> convertedTypes;
|
|
if (failed(
|
|
typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
|
|
return failure();
|
|
|
|
func::FuncOp funcOp;
|
|
{
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
auto module = op->getParentOfType<ModuleOp>();
|
|
rewriter.setInsertionPointToEnd(module.getBody());
|
|
|
|
funcOp = func::FuncOp::create(
|
|
rewriter, loc, globals.names.newName("solver"),
|
|
rewriter.getFunctionType(adaptor.getInputs().getTypes(),
|
|
convertedTypes));
|
|
rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
|
|
funcOp.end());
|
|
}
|
|
|
|
ValueRange results =
|
|
func::CallOp::create(rewriter, loc, funcOp, adaptor.getInputs())
|
|
->getResults();
|
|
|
|
// At the end of the region, decrease the solver's reference counter and
|
|
// delete the context.
|
|
// NOTE: we cannot use the convenience helper here because we don't want to
|
|
// load the context from the global but use the result from the 'mk_context'
|
|
// call directly for two reasons:
|
|
// * avoid an unnecessary load
|
|
// * the caching mechanism of the context does not work here because it
|
|
// would reuse the loaded context from a earlier solver
|
|
buildCall(rewriter, loc, "Z3_solver_dec_ref", ptrPtrToVoidFunc,
|
|
{ctx, solver});
|
|
buildCall(rewriter, loc, "Z3_del_context", ptrToVoidFunc, ctx);
|
|
|
|
rewriter.replaceOp(op, results);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.assert` operations to Z3 API calls of the form:
|
|
/// ```
|
|
/// void Z3_API Z3_solver_assert(Z3_context c, Z3_solver s, Z3_ast a);
|
|
/// ```
|
|
struct AssertOpLowering : public SMTLoweringPattern<AssertOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
buildAPICallWithContext(
|
|
rewriter, loc, "Z3_solver_assert",
|
|
LLVM::LLVMVoidType::get(getContext()),
|
|
{buildSolverPtr(rewriter, loc), adaptor.getInput()});
|
|
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.reset` operations to Z3 API calls of the form:
|
|
/// ```
|
|
/// void Z3_API Z3_solver_reset(Z3_context c, Z3_solver s);
|
|
/// ```
|
|
struct ResetOpLowering : public SMTLoweringPattern<ResetOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ResetOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
buildAPICallWithContext(rewriter, loc, "Z3_solver_reset",
|
|
LLVM::LLVMVoidType::get(getContext()),
|
|
{buildSolverPtr(rewriter, loc)});
|
|
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.push` operations to (repeated) Z3 API calls of the form:
|
|
/// ```
|
|
/// void Z3_API Z3_solver_push(Z3_context c, Z3_solver s);
|
|
/// ```
|
|
struct PushOpLowering : public SMTLoweringPattern<PushOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
LogicalResult
|
|
matchAndRewrite(PushOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
// SMTLIB allows multiple levels to be pushed with one push command, but the
|
|
// Z3 C API doesn't let you provide a number of levels for push calls so
|
|
// multiple calls have to be created.
|
|
for (uint32_t i = 0; i < op.getCount(); i++)
|
|
buildAPICallWithContext(rewriter, loc, "Z3_solver_push",
|
|
LLVM::LLVMVoidType::get(getContext()),
|
|
{buildSolverPtr(rewriter, loc)});
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.pop` operations to Z3 API calls of the form:
|
|
/// ```
|
|
/// void Z3_API Z3_solver_pop(Z3_context c, Z3_solver s, unsigned n);
|
|
/// ```
|
|
struct PopOpLowering : public SMTLoweringPattern<PopOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
LogicalResult
|
|
matchAndRewrite(PopOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
Value constVal = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI32Type(), op.getCount());
|
|
buildAPICallWithContext(rewriter, loc, "Z3_solver_pop",
|
|
LLVM::LLVMVoidType::get(getContext()),
|
|
{buildSolverPtr(rewriter, loc), constVal});
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.yield` operations to `scf.yield` operations. This not necessary
|
|
/// for the yield in `smt.solver` or in quantifiers since they are deleted
|
|
/// directly by the parent operation, but makes the lowering of the `smt.check`
|
|
/// operation simpler and more convenient since the regions get translated
|
|
/// directly to regions of `scf.if` operations.
|
|
struct YieldOpLowering : public SMTLoweringPattern<YieldOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(YieldOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
if (op->getParentOfType<func::FuncOp>()) {
|
|
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
|
|
return success();
|
|
}
|
|
if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
|
|
return success();
|
|
}
|
|
if (isa<scf::SCFDialect>(op->getParentOp()->getDialect())) {
|
|
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.check` operations to Z3 API calls and control-flow operations.
|
|
/// ```
|
|
/// Z3_lbool Z3_API Z3_solver_check(Z3_context c, Z3_solver s);
|
|
///
|
|
/// typedef enum
|
|
/// {
|
|
/// Z3_L_FALSE = -1, // means unsatisfiable here
|
|
/// Z3_L_UNDEF, // means unknown here
|
|
/// Z3_L_TRUE // means satisfiable here
|
|
/// } Z3_lbool;
|
|
/// ```
|
|
struct CheckOpLowering : public SMTLoweringPattern<CheckOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CheckOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
|
|
auto printfType = LLVM::LLVMFunctionType::get(
|
|
LLVM::LLVMVoidType::get(rewriter.getContext()), {ptrTy}, true);
|
|
|
|
auto getHeaderString = [](const std::string &title) {
|
|
unsigned titleSize = title.size() + 2; // Add a space left and right
|
|
return std::string((80 - titleSize) / 2, '-') + " " + title + " " +
|
|
std::string((80 - titleSize + 1) / 2, '-') + "\n%s\n" +
|
|
std::string(80, '-') + "\n";
|
|
};
|
|
|
|
// Get the pointer to the solver instance.
|
|
Value solver = buildSolverPtr(rewriter, loc);
|
|
|
|
// In debug-mode, print the state of the solver before calling 'check-sat'
|
|
// on it. This prints the asserted SMT expressions.
|
|
if (options.debug) {
|
|
auto solverStringPtr =
|
|
buildPtrAPICall(rewriter, loc, "Z3_solver_to_string", {solver});
|
|
auto solverFormatString =
|
|
buildString(rewriter, loc, getHeaderString("Solver"));
|
|
buildCall(rewriter, op.getLoc(), "printf", printfType,
|
|
{solverFormatString, solverStringPtr});
|
|
}
|
|
|
|
// Convert the result types of the `smt.check` operation.
|
|
SmallVector<Type> resultTypes;
|
|
if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
|
|
return failure();
|
|
|
|
// Call 'check-sat' and check if the assertions are satisfiable.
|
|
Value checkResult =
|
|
buildAPICallWithContext(rewriter, loc, "Z3_solver_check",
|
|
rewriter.getI32Type(), {solver})
|
|
->getResult(0);
|
|
Value constOne =
|
|
LLVM::ConstantOp::create(rewriter, loc, checkResult.getType(), 1);
|
|
Value isSat = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq,
|
|
checkResult, constOne);
|
|
|
|
// Simply inline the 'sat' region into the 'then' region of the 'scf.if'
|
|
auto satIfOp = scf::IfOp::create(rewriter, loc, resultTypes, isSat);
|
|
rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
|
|
satIfOp.getThenRegion().end());
|
|
|
|
// Otherwise, the 'else' block checks if the assertions are unsatisfiable or
|
|
// unknown. The corresponding regions can also be simply inlined into the
|
|
// two branches of this nested if-statement as well.
|
|
rewriter.createBlock(&satIfOp.getElseRegion());
|
|
Value constNegOne =
|
|
LLVM::ConstantOp::create(rewriter, loc, checkResult.getType(), -1);
|
|
Value isUnsat = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq,
|
|
checkResult, constNegOne);
|
|
auto unsatIfOp = scf::IfOp::create(rewriter, loc, resultTypes, isUnsat);
|
|
scf::YieldOp::create(rewriter, loc, unsatIfOp->getResults());
|
|
|
|
rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
|
|
unsatIfOp.getThenRegion().end());
|
|
rewriter.inlineRegionBefore(op.getUnknownRegion(),
|
|
unsatIfOp.getElseRegion(),
|
|
unsatIfOp.getElseRegion().end());
|
|
|
|
rewriter.replaceOp(op, satIfOp->getResults());
|
|
|
|
if (options.debug) {
|
|
// In debug-mode, if the assertions are unsatisfiable we can print the
|
|
// proof.
|
|
rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
|
|
auto proof = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_proof",
|
|
{solver});
|
|
auto stringPtr =
|
|
buildPtrAPICall(rewriter, op.getLoc(), "Z3_ast_to_string", {proof});
|
|
auto formatString =
|
|
buildString(rewriter, op.getLoc(), getHeaderString("Proof"));
|
|
buildCall(rewriter, op.getLoc(), "printf", printfType,
|
|
{formatString, stringPtr});
|
|
|
|
// In debug mode, if the assertions are satisfiable we can print the model
|
|
// (effectively a counter-example).
|
|
rewriter.setInsertionPointToStart(satIfOp.thenBlock());
|
|
auto model = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_model",
|
|
{solver});
|
|
auto modelStringPtr =
|
|
buildPtrAPICall(rewriter, op.getLoc(), "Z3_model_to_string", {model});
|
|
auto modelFormatString =
|
|
buildString(rewriter, op.getLoc(), getHeaderString("Model"));
|
|
buildCall(rewriter, op.getLoc(), "printf", printfType,
|
|
{modelFormatString, modelStringPtr});
|
|
}
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.forall` and `smt.exists` operations to the following Z3 API call.
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_{forall|exists}_const(
|
|
/// Z3_context c,
|
|
/// unsigned weight,
|
|
/// unsigned num_bound,
|
|
/// Z3_app const bound[],
|
|
/// unsigned num_patterns,
|
|
/// Z3_pattern const patterns[],
|
|
/// Z3_ast body
|
|
/// );
|
|
/// ```
|
|
/// All nested regions are inlined into the parent region and the block
|
|
/// arguments are replaced with new `smt.declare_fun` constants that are also
|
|
/// passed to the `bound` argument of above API function. Patterns are created
|
|
/// with the following API function.
|
|
/// ```
|
|
/// Z3_pattern Z3_API Z3_mk_pattern(Z3_context c, unsigned num_patterns,
|
|
/// Z3_ast const terms[]);
|
|
/// ```
|
|
/// Where each operand of the `smt.yield` in a pattern region is a 'term'.
|
|
template <typename QuantifierOp>
|
|
struct QuantifierLowering : public SMTLoweringPattern<QuantifierOp> {
|
|
using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
|
|
using SMTLoweringPattern<QuantifierOp>::typeConverter;
|
|
using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
|
|
using OpAdaptor = typename QuantifierOp::Adaptor;
|
|
|
|
Value createStorageForValueList(ValueRange values, Location loc,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
|
|
Type arrTy = LLVM::LLVMArrayType::get(ptrTy, values.size());
|
|
Value constOne =
|
|
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
|
|
Value storage =
|
|
LLVM::AllocaOp::create(rewriter, loc, ptrTy, arrTy, constOne);
|
|
Value array = LLVM::UndefOp::create(rewriter, loc, arrTy);
|
|
|
|
for (auto [i, val] : llvm::enumerate(values))
|
|
array = LLVM::InsertValueOp::create(rewriter, loc, array, val,
|
|
ArrayRef<int64_t>(i));
|
|
|
|
LLVM::StoreOp::create(rewriter, loc, array, storage);
|
|
|
|
return storage;
|
|
}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
|
|
|
|
// no-pattern attribute not supported yet because the Z3 CAPI allows more
|
|
// fine-grained control where a list of patterns to be banned can be given.
|
|
// This means, the no-pattern attribute is equivalent to providing a list of
|
|
// all possible sub-expressions in the quantifier body to the CAPI.
|
|
if (adaptor.getNoPattern())
|
|
return rewriter.notifyMatchFailure(
|
|
op, "no-pattern attribute not yet supported!");
|
|
|
|
rewriter.setInsertionPoint(op);
|
|
|
|
// Weight attribute
|
|
Value weight = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI32Type(), adaptor.getWeight());
|
|
|
|
// Bound variables
|
|
unsigned numDecls = op.getBody().getNumArguments();
|
|
Value numDeclsVal = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI32Type(), numDecls);
|
|
|
|
// We replace the block arguments with constant symbolic values and inform
|
|
// the quantifier API call which constants it should treat as bound
|
|
// variables. We also need to make sure that we use the exact same SSA
|
|
// values in the pattern regions since we lower constant declaration
|
|
// operation to always produce fresh constants.
|
|
SmallVector<Value> repl;
|
|
for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
|
|
Value newArg;
|
|
if (adaptor.getBoundVarNames().has_value())
|
|
newArg = smt::DeclareFunOp::create(
|
|
rewriter, loc, arg.getType(),
|
|
cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
|
|
else
|
|
newArg = smt::DeclareFunOp::create(rewriter, loc, arg.getType());
|
|
repl.push_back(typeConverter->materializeTargetConversion(
|
|
rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
|
|
}
|
|
|
|
Value boundStorage = createStorageForValueList(repl, loc, rewriter);
|
|
|
|
// Body Expression
|
|
auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
|
|
Value bodyExp = yieldOp.getValues()[0];
|
|
rewriter.setInsertionPointAfterValue(bodyExp);
|
|
bodyExp = typeConverter->materializeTargetConversion(
|
|
rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
|
|
rewriter.eraseOp(yieldOp);
|
|
|
|
rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
|
|
rewriter.setInsertionPoint(op);
|
|
|
|
// Patterns
|
|
unsigned numPatterns = adaptor.getPatterns().size();
|
|
Value numPatternsVal = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI32Type(), numPatterns);
|
|
|
|
Value patternStorage;
|
|
if (numPatterns > 0) {
|
|
SmallVector<Value> patterns;
|
|
for (Region *patternRegion : adaptor.getPatterns()) {
|
|
auto yieldOp =
|
|
cast<smt::YieldOp>(patternRegion->front().getTerminator());
|
|
auto patternTerms = yieldOp.getOperands();
|
|
|
|
rewriter.setInsertionPoint(yieldOp);
|
|
SmallVector<Value> patternList;
|
|
for (auto val : patternTerms)
|
|
patternList.push_back(typeConverter->materializeTargetConversion(
|
|
rewriter, loc, typeConverter->convertType(val.getType()), val));
|
|
|
|
rewriter.eraseOp(yieldOp);
|
|
rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
|
|
|
|
rewriter.setInsertionPoint(op);
|
|
Value numTerms = LLVM::ConstantOp::create(
|
|
rewriter, loc, rewriter.getI32Type(), patternTerms.size());
|
|
Value patternTermStorage =
|
|
createStorageForValueList(patternList, loc, rewriter);
|
|
Value pattern = buildPtrAPICall(rewriter, loc, "Z3_mk_pattern",
|
|
{numTerms, patternTermStorage});
|
|
|
|
patterns.emplace_back(pattern);
|
|
}
|
|
patternStorage = createStorageForValueList(patterns, loc, rewriter);
|
|
} else {
|
|
// If we set the num_patterns parameter to 0, we can just pass a nullptr
|
|
// as storage.
|
|
patternStorage = LLVM::ZeroOp::create(rewriter, loc, ptrTy);
|
|
}
|
|
|
|
StringRef apiCallName = "Z3_mk_forall_const";
|
|
if (std::is_same_v<QuantifierOp, ExistsOp>)
|
|
apiCallName = "Z3_mk_exists_const";
|
|
Value quantifierExp =
|
|
buildPtrAPICall(rewriter, loc, apiCallName,
|
|
{weight, numDeclsVal, boundStorage, numPatternsVal,
|
|
patternStorage, bodyExp});
|
|
|
|
rewriter.replaceOp(op, quantifierExp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.bv.repeat` operations to Z3 API function calls of the form
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_repeat(Z3_context c, unsigned i, Z3_ast t1);
|
|
/// ```
|
|
struct RepeatOpLowering : public SMTLoweringPattern<RepeatOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Value count = LLVM::ConstantOp::create(
|
|
rewriter, op.getLoc(), rewriter.getI32Type(), op.getCount());
|
|
rewriter.replaceOp(op,
|
|
buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_repeat",
|
|
{count, adaptor.getInput()}));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.bv.extract` operations to Z3 API function calls of the following
|
|
/// form, where the output bit-vector has size `n = high - low + 1`. This means,
|
|
/// both the 'high' and 'low' indices are inclusive.
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_extract(Z3_context c, unsigned high, unsigned low,
|
|
/// Z3_ast t1);
|
|
/// ```
|
|
struct ExtractOpLowering : public SMTLoweringPattern<ExtractOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
Value low = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
|
|
adaptor.getLowBit());
|
|
Value high = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
|
|
adaptor.getLowBit() +
|
|
op.getType().getWidth() - 1);
|
|
rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc, "Z3_mk_extract",
|
|
{high, low, adaptor.getInput()}));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.array.broadcast` operations to Z3 API function calls of the form
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_const_array(Z3_context c, Z3_sort domain, Z3_ast v);
|
|
/// ```
|
|
struct ArrayBroadcastOpLowering
|
|
: public SMTLoweringPattern<smt::ArrayBroadcastOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
auto domainSort = buildSort(
|
|
rewriter, op.getLoc(),
|
|
cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
|
|
|
|
rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
|
|
"Z3_mk_const_array",
|
|
{domainSort, adaptor.getValue()}));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower the `smt.constant` operation to one of the following Z3 API function
|
|
/// calls depending on the value of the boolean attribute.
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_true(Z3_context c);
|
|
/// Z3_ast Z3_API Z3_mk_false(Z3_context c);
|
|
/// ```
|
|
struct BoolConstantOpLowering : public SMTLoweringPattern<smt::BoolConstantOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
rewriter.replaceOp(
|
|
op, buildPtrAPICall(rewriter, op.getLoc(),
|
|
adaptor.getValue() ? "Z3_mk_true" : "Z3_mk_false"));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.int.constant` operations to one of the following two Z3 API
|
|
/// function calls depending on whether the storage APInt has a bit-width that
|
|
/// fits in a `uint64_t`.
|
|
/// ```
|
|
/// Z3_sort Z3_API Z3_mk_int_sort(Z3_context c);
|
|
///
|
|
/// Z3_ast Z3_API Z3_mk_int64(Z3_context c, int64_t v, Z3_sort ty);
|
|
///
|
|
/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
|
|
/// Z3_ast Z3_API Z3_mk_unary_minus(Z3_context c, Z3_ast arg);
|
|
/// ```
|
|
struct IntConstantOpLowering : public SMTLoweringPattern<smt::IntConstantOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
Value type = buildPtrAPICall(rewriter, loc, "Z3_mk_int_sort");
|
|
if (adaptor.getValue().getBitWidth() <= 64) {
|
|
Value val = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
|
|
adaptor.getValue().getSExtValue());
|
|
rewriter.replaceOp(
|
|
op, buildPtrAPICall(rewriter, loc, "Z3_mk_int64", {val, type}));
|
|
return success();
|
|
}
|
|
|
|
std::string numeralStr;
|
|
llvm::raw_string_ostream stream(numeralStr);
|
|
stream << adaptor.getValue().abs();
|
|
|
|
Value numeral = buildString(rewriter, loc, numeralStr);
|
|
Value intNumeral =
|
|
buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {numeral, type});
|
|
|
|
if (adaptor.getValue().isNegative())
|
|
intNumeral =
|
|
buildPtrAPICall(rewriter, loc, "Z3_mk_unary_minus", intNumeral);
|
|
|
|
rewriter.replaceOp(op, intNumeral);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.int.cmp` operations to one of the following Z3 API function calls
|
|
/// depending on the predicate.
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
|
|
/// ```
|
|
struct IntCmpOpLowering : public SMTLoweringPattern<IntCmpOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
rewriter.replaceOp(
|
|
op,
|
|
buildPtrAPICall(rewriter, op.getLoc(),
|
|
"Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
|
|
{adaptor.getLhs(), adaptor.getRhs()}));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.int2bv` operations to the following Z3 API function calls.
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_int2bv(Z3_context c, unsigned n, Z3_ast t1);
|
|
/// ```
|
|
struct Int2BVOpLowering : public SMTLoweringPattern<Int2BVOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Int2BVOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Value widthConst =
|
|
LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
|
|
op.getResult().getType().getWidth());
|
|
rewriter.replaceOp(op,
|
|
buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_int2bv",
|
|
{widthConst, adaptor.getInput()}));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.bv2int` operations to the following Z3 API function call.
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_bv2int(Z3_context c, Z3_ast t1, bool is_signed)
|
|
/// ```
|
|
struct BV2IntOpLowering : public SMTLoweringPattern<BV2IntOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(BV2IntOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
// FIXME: ideally we don't want to use i1 here, since bools can sometimes be
|
|
// compiled to wider widths in LLVM
|
|
Value isSignedConst = LLVM::ConstantOp::create(
|
|
rewriter, op->getLoc(), rewriter.getI1Type(), op.getIsSigned());
|
|
rewriter.replaceOp(op,
|
|
buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_bv2int",
|
|
{adaptor.getInput(), isSignedConst}));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower `smt.bv.cmp` operations to one of the following Z3 API function calls,
|
|
/// performing two's complement comparison, depending on the predicate
|
|
/// attribute.
|
|
/// ```
|
|
/// Z3_ast Z3_API Z3_mk_bv{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
|
|
/// ```
|
|
struct BVCmpOpLowering : public SMTLoweringPattern<BVCmpOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
rewriter.replaceOp(
|
|
op, buildPtrAPICall(rewriter, op.getLoc(),
|
|
"Z3_mk_bv" +
|
|
stringifyBVCmpPredicate(op.getPred()).str(),
|
|
{adaptor.getLhs(), adaptor.getRhs()}));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Expand the `smt.int.abs` operation to a `smt.ite` operation.
|
|
struct IntAbsOpLowering : public SMTLoweringPattern<IntAbsOp> {
|
|
using SMTLoweringPattern::SMTLoweringPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
Value zero = IntConstantOp::create(
|
|
rewriter, loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
|
|
Value cmp = IntCmpOp::create(rewriter, loc, IntPredicate::lt,
|
|
adaptor.getInput(), zero);
|
|
Value neg = IntSubOp::create(rewriter, loc, zero, adaptor.getInput());
|
|
rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pass Implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct LowerSMTToZ3LLVMPass
|
|
: public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
|
|
using Base::Base;
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
/// Register type conversions that map each of the SMT dialect types listed
/// below (bool, bit-vector, array, int, function, and uninterpreted sort) to
/// the LLVM pointer type.
void circt::populateSMTToZ3LLVMTypeConverter(TypeConverter &converter) {
  // Shared implementation of all conversions: an LLVM pointer type in the
  // context of the converted type.
  auto asPointer = [](auto smtType) {
    return LLVM::LLVMPointerType::get(smtType.getContext());
  };

  converter.addConversion(
      [asPointer](smt::BoolType type) { return asPointer(type); });
  converter.addConversion(
      [asPointer](smt::BitVectorType type) { return asPointer(type); });
  converter.addConversion(
      [asPointer](smt::ArrayType type) { return asPointer(type); });
  converter.addConversion(
      [asPointer](smt::IntType type) { return asPointer(type); });
  converter.addConversion(
      [asPointer](smt::SMTFuncType type) { return asPointer(type); });
  converter.addConversion(
      [asPointer](smt::SortType type) { return asPointer(type); });
}
|
|
|
|
/// Add all patterns lowering SMT dialect operations to Z3 API calls to
/// `patterns`. Every pattern is constructed with the given type `converter`,
/// the `globals` handler, and the lowering `options`.
void circt::populateSMTToZ3LLVMConversionPatterns(
    RewritePatternSet &patterns, TypeConverter &converter,
    SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options) {
  MLIRContext *ctx = patterns.getContext();

  // Registers a lowering of OP to the variadic API function APINAME which
  // takes an argument count and an argument array.
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)                        \
  patterns.add<VariadicSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/  \
                                       converter, ctx, globals, options,       \
                                       APINAME, MIN_NUM_ARGS);

  // Registers a lowering of OP to the API function APINAME where the operands
  // map one-to-one to the NUM_ARGS call arguments.
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)                          \
  patterns.add<OneToOneSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/  \
                                       converter, ctx, globals, options,       \
                                       APINAME, NUM_ARGS);

  // `smt.distinct` accepts a variadic number of operands (`:pairwise`
  // attribute) and so does the Z3 API function, i.e., it can be lowered
  // directly:
  // ```
  // Z3_ast Z3_API Z3_mk_distinct(Z3_context c, unsigned num_args, Z3_ast const
  // args[])
  // ```
  // The API function requires num_args > 1 which is guaranteed to be satisfied
  // because `smt.distinct` is verified to have > 1 operands.
  ADD_VARIADIC_PATTERN(DistinctOp, "Z3_mk_distinct", 2);

  // `smt.and` accepts a variadic number of operands (`:left-assoc` attribute)
  // and so does the Z3 API function, i.e., it can be lowered directly:
  // ```
  // Z3_ast Z3_API Z3_mk_and(Z3_context c, unsigned num_args, Z3_ast const
  // args[])
  // ```
  // The API function requires num_args > 1. This is not guaranteed by the
  // `smt.and` operation and thus the pattern will not apply when no operand is
  // present. The constant folder of the operation is assumed to fold this to
  // a constant 'true' (neutral element of AND).
  ADD_VARIADIC_PATTERN(AndOp, "Z3_mk_and", 2);

  // `smt.or` accepts a variadic number of operands (`:left-assoc` attribute)
  // and so does the Z3 API function, i.e., it can be lowered directly:
  // ```
  // Z3_ast Z3_API Z3_mk_or(Z3_context c, unsigned num_args, Z3_ast const
  // args[])
  // ```
  // The API function requires num_args > 1. This is not guaranteed by the
  // `smt.or` operation and thus the pattern will not apply when no operand is
  // present. The constant folder of the operation is assumed to fold this to
  // a constant 'false' (neutral element of OR).
  ADD_VARIADIC_PATTERN(OrOp, "Z3_mk_or", 2);

  // `smt.not` maps to:
  // ```
  // Z3_ast Z3_API Z3_mk_not(Z3_context c, Z3_ast a);
  // ```
  ADD_ONE_TO_ONE_PATTERN(NotOp, "Z3_mk_not", 1);

  // `smt.xor` accepts a variadic number of operands (`:left-assoc`
  // attribute), but the Z3 API function only takes two:
  // ```
  // Z3_ast Z3_API Z3_mk_xor(Z3_context c, Z3_ast t1, Z3_ast t2);
  // ```
  // Therefore, the operation is first decomposed into a sequence of binary XOR
  // operations matching the left associative behavior.
  patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(converter, ctx, globals,
                                                options);
  ADD_ONE_TO_ONE_PATTERN(XOrOp, "Z3_mk_xor", 2);

  // `smt.implies` maps to:
  // ```
  // Z3_ast Z3_API Z3_mk_implies(Z3_context c, Z3_ast t1, Z3_ast t2);
  // ```
  ADD_ONE_TO_ONE_PATTERN(ImpliesOp, "Z3_mk_implies", 2);

  // The bit-vector arithmetic and bitwise operations lower to Z3 API functions
  // with essentially matching names and a one-to-one correspondence of
  // operands to call arguments.
  ADD_ONE_TO_ONE_PATTERN(BVNegOp, "Z3_mk_bvneg", 1);
  ADD_ONE_TO_ONE_PATTERN(BVAddOp, "Z3_mk_bvadd", 2);
  ADD_ONE_TO_ONE_PATTERN(BVMulOp, "Z3_mk_bvmul", 2);
  ADD_ONE_TO_ONE_PATTERN(BVURemOp, "Z3_mk_bvurem", 2);
  ADD_ONE_TO_ONE_PATTERN(BVSRemOp, "Z3_mk_bvsrem", 2);
  ADD_ONE_TO_ONE_PATTERN(BVSModOp, "Z3_mk_bvsmod", 2);
  ADD_ONE_TO_ONE_PATTERN(BVUDivOp, "Z3_mk_bvudiv", 2);
  ADD_ONE_TO_ONE_PATTERN(BVSDivOp, "Z3_mk_bvsdiv", 2);
  ADD_ONE_TO_ONE_PATTERN(BVShlOp, "Z3_mk_bvshl", 2);
  ADD_ONE_TO_ONE_PATTERN(BVLShrOp, "Z3_mk_bvlshr", 2);
  ADD_ONE_TO_ONE_PATTERN(BVAShrOp, "Z3_mk_bvashr", 2);
  ADD_ONE_TO_ONE_PATTERN(BVNotOp, "Z3_mk_bvnot", 1);
  ADD_ONE_TO_ONE_PATTERN(BVAndOp, "Z3_mk_bvand", 2);
  ADD_ONE_TO_ONE_PATTERN(BVOrOp, "Z3_mk_bvor", 2);
  ADD_ONE_TO_ONE_PATTERN(BVXOrOp, "Z3_mk_bvxor", 2);

  // `smt.bv.concat` only supports two operands, just like the Z3 API function.
  // ```
  // Z3_ast Z3_API Z3_mk_concat(Z3_context c, Z3_ast t1, Z3_ast t2);
  // ```
  ADD_ONE_TO_ONE_PATTERN(ConcatOp, "Z3_mk_concat", 2);

  // `smt.ite` maps to the following API function, where `t1` must have
  // boolean sort.
  // ```
  // Z3_ast Z3_API Z3_mk_ite(Z3_context c, Z3_ast t1, Z3_ast t2, Z3_ast t3);
  // ```
  ADD_ONE_TO_ONE_PATTERN(IteOp, "Z3_mk_ite", 3);

  // `smt.array.select` maps to the following API function. The operand
  // declaration of the operation matches the order of the call arguments.
  // ```
  // Z3_ast Z3_API Z3_mk_select(Z3_context c, Z3_ast a, Z3_ast i);
  // ```
  // Where `a` is the array expression and `i` is the index expression.
  ADD_ONE_TO_ONE_PATTERN(ArraySelectOp, "Z3_mk_select", 2);

  // `smt.array.store` maps to the following API function. The operand
  // declaration of the operation matches the order of the call arguments.
  // ```
  // Z3_ast Z3_API Z3_mk_store(Z3_context c, Z3_ast a, Z3_ast i, Z3_ast v);
  // ```
  // Where `a` is the array expression, `i` is the index expression, and `v` is
  // the value expression to be stored.
  ADD_ONE_TO_ONE_PATTERN(ArrayStoreOp, "Z3_mk_store", 3);

  // `smt.int.add` maps to the following variadic API function; the pattern is
  // registered with a minimum of two operands.
  // ```
  // Z3_ast Z3_API Z3_mk_add(Z3_context c, unsigned num_args, Z3_ast const
  // args[]);
  // ```
  ADD_VARIADIC_PATTERN(IntAddOp, "Z3_mk_add", 2);

  // `smt.int.mul` maps to the following variadic API function; the pattern is
  // registered with a minimum of two operands.
  // ```
  // Z3_ast Z3_API Z3_mk_mul(Z3_context c, unsigned num_args, Z3_ast const
  // args[]);
  // ```
  ADD_VARIADIC_PATTERN(IntMulOp, "Z3_mk_mul", 2);

  // `smt.int.sub` maps to the following variadic API function.
  // ```
  // Z3_ast Z3_API Z3_mk_sub(Z3_context c, unsigned num_args, Z3_ast const
  // args[]);
  // ```
  // The number of arguments must be greater than zero. Since the `smt.int.sub`
  // operation always has exactly two operands, this trivially holds.
  ADD_VARIADIC_PATTERN(IntSubOp, "Z3_mk_sub", 2);

  // `smt.int.div` maps to:
  // ```
  // Z3_ast Z3_API Z3_mk_div(Z3_context c, Z3_ast arg1, Z3_ast arg2);
  // ```
  ADD_ONE_TO_ONE_PATTERN(IntDivOp, "Z3_mk_div", 2);

  // `smt.int.mod` maps to:
  // ```
  // Z3_ast Z3_API Z3_mk_mod(Z3_context c, Z3_ast arg1, Z3_ast arg2);
  // ```
  ADD_ONE_TO_ONE_PATTERN(IntModOp, "Z3_mk_mod", 2);

#undef ADD_VARIADIC_PATTERN
#undef ADD_ONE_TO_ONE_PATTERN

  // `smt.eq` accepts a variadic number of operands (`:chainable` attribute),
  // but the Z3 API function takes exactly two:
  // ```
  // Z3_ast Z3_API Z3_mk_eq(Z3_context c, Z3_ast l, Z3_ast r)
  // ```
  // As a result, a rewrite pattern first unfolds the chainable operator and
  // the binary operations are then lowered one-to-one to the API function. In
  // this case, this means:
  // ```
  // eq(a,b,c,d) ->
  // and(eq(a,b), eq(b,c), eq(c,d)) ->
  // and(Z3_mk_eq(ctx, a, b), Z3_mk_eq(ctx, b, c), Z3_mk_eq(ctx, c, d))
  // ```
  // The patterns for `smt.and` will then do the remaining work.
  patterns.add<LowerChainableSMTPattern<EqOp>>(converter, ctx, globals,
                                               options);
  patterns.add<OneToOneSMTPattern<EqOp>>(converter, ctx, globals, options,
                                         "Z3_mk_eq", 2);

  // Remaining lowering patterns; see their implementations for details.
  patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
               ResetOpLowering, PushOpLowering, PopOpLowering, CheckOpLowering,
               SolverOpLowering, ApplyFuncOpLowering, YieldOpLowering,
               RepeatOpLowering, ExtractOpLowering, BoolConstantOpLowering,
               IntConstantOpLowering, ArrayBroadcastOpLowering, BVCmpOpLowering,
               IntCmpOpLowering, IntAbsOpLowering, Int2BVOpLowering,
               BV2IntOpLowering, QuantifierLowering<ForallOp>,
               QuantifierLowering<ExistsOp>>(converter, ctx, globals, options);
}
|
|
|
|
void LowerSMTToZ3LLVMPass::runOnOperation() {
|
|
LowerSMTToZ3LLVMOptions options;
|
|
options.debug = debug;
|
|
|
|
// Check that the lowering is possible
|
|
// Specifically, check that the use of set-logic ops is valid for z3
|
|
auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
|
|
-> WalkResult {
|
|
// Check that solver ops only contain one set-logic op and that they're at
|
|
// the start of the body
|
|
auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
|
|
auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
|
|
if (numSetLogicOps > 1) {
|
|
return solverOp.emitError(
|
|
"multiple set-logic operations found in one solver operation - Z3 "
|
|
"only supports setting the logic once");
|
|
}
|
|
if (numSetLogicOps == 1)
|
|
// Check the only ops before the set-logic op are ConstantLike
|
|
for (auto &blockOp : solverOp.getBodyRegion().getOps()) {
|
|
if (isa<smt::SetLogicOp>(blockOp))
|
|
break;
|
|
if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
|
|
return solverOp.emitError("set-logic operation must be the first "
|
|
"non-constant operation in a solver "
|
|
"operation");
|
|
}
|
|
}
|
|
return WalkResult::advance();
|
|
});
|
|
if (setLogicCheck.wasInterrupted())
|
|
return signalPassFailure();
|
|
|
|
// Set up the type converter
|
|
LLVMTypeConverter converter(&getContext());
|
|
populateSMTToZ3LLVMTypeConverter(converter);
|
|
|
|
RewritePatternSet patterns(&getContext());
|
|
|
|
// Populate the func to LLVM conversion patterns for two reasons:
|
|
// * Typically functions are represented using `func.func` and including the
|
|
// patterns to lower them here is more convenient for most lowering
|
|
// pipelines (avoids running another pass).
|
|
// * Already having `llvm.func` in the input or lowering `func.func` before
|
|
// the SMT in the body leads to issues because the SCF conversion patterns
|
|
// don't take the type converter into consideration and thus create blocks
|
|
// with the old types for block arguments. However, the conversion happens
|
|
// top-down and thus are assumed to be converted by the parent function op
|
|
// which at that point would have already been lowered (and the blocks are
|
|
// also not there when doing everything in one pass, i.e.,
|
|
// `populateAnyFunctionOpInterfaceTypeConversionPattern` does not have any
|
|
// effect as well). Are the SCF lowering patterns actually broken and should
|
|
// take a type-converter?
|
|
populateFuncToLLVMConversionPatterns(converter, patterns);
|
|
arith::populateArithToLLVMConversionPatterns(converter, patterns);
|
|
|
|
// Populate SCF to CF and CF to LLVM lowering patterns because we create
|
|
// `scf.if` operations in the lowering patterns for convenience (given the
|
|
// above issue we might want to lower to LLVM directly; or fix upstream?)
|
|
populateSCFToControlFlowConversionPatterns(patterns);
|
|
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
|
|
|
|
// Create the globals to store the context and solver and populate the SMT
|
|
// lowering patterns.
|
|
OpBuilder builder(&getContext());
|
|
auto globals = SMTGlobalsHandler::create(builder, getOperation());
|
|
populateSMTToZ3LLVMConversionPatterns(patterns, converter, globals, options);
|
|
|
|
// Do a full conversion. This assumes that all other dialects have been
|
|
// lowered before this pass already.
|
|
LLVMConversionTarget target(getContext());
|
|
target.addLegalOp<mlir::ModuleOp>();
|
|
target.addLegalOp<scf::YieldOp>();
|
|
|
|
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|