circt/lib/Dialect/LTL/LTLFolds.cpp

144 lines
4.5 KiB
C++

//===- LTLFolds.cpp -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/LTL/LTLOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace circt;
using namespace ltl;
using namespace mlir;
/// Build one vector containing every value of `a` followed by every value of
/// `b`, preserving order. Intended as a helper for the declarative rewrites.
static SmallVector<Value> concatValues(ValueRange a, ValueRange b) {
  // Seed the result with `a`, then tack `b` onto the end.
  SmallVector<Value> combined(a.begin(), a.end());
  combined.append(b.begin(), b.end());
  return combined;
}
/// Expand each value that is produced by a `ConcatOp` into that op's inputs,
/// keeping all other values as they are. Only one level of nesting is expanded
/// per call; relative order is preserved.
static SmallVector<Value> flattenConcats(ValueRange values) {
  SmallVector<Value> result;
  for (Value operand : values) {
    auto nested = operand.getDefiningOp<ConcatOp>();
    if (!nested) {
      // Not the result of a concat: keep the value untouched.
      result.push_back(operand);
      continue;
    }
    // Splice the nested concat's inputs in place of the concat result.
    auto nestedInputs = nested.getInputs();
    result.append(nestedInputs.begin(), nestedInputs.end());
  }
  return result;
}
//===----------------------------------------------------------------------===//
// Declarative Rewrites
//===----------------------------------------------------------------------===//
// Pull the rewrite patterns from the included `.inc` file into the `patterns`
// namespace. Code later in this file refers to `patterns::NestedDelays`,
// `patterns::MoveDelayIntoConcat`, and `patterns::FlattenConcats`.
// NOTE(review): the `.inc` file is presumably generated from the dialect's
// declarative rewrite definitions — confirm against the build rules.
namespace patterns {
#include "circt/Dialect/LTL/LTLFolds.cpp.inc"
} // namespace patterns
//===----------------------------------------------------------------------===//
// AndOp / OrOp / IntersectOp
//===----------------------------------------------------------------------===//
/// Canonicalize an `ltl.and` whose result type is `i1` into a `comb::AndOp`
/// over the same inputs. Returns failure (leaving the op untouched) for any
/// other result type.
LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
  // Anything that is not a plain i1 result is left alone.
  if (op.getType() != rewriter.getI1Type())
    return failure();

  // NOTE(review): the trailing `true` is forwarded to the comb op builder;
  // presumably it is the two-state flag — confirm against the comb builders.
  rewriter.replaceOpWithNewOp<comb::AndOp>(op, op.getInputs(), true);
  return success();
}
/// Canonicalize an `ltl.or` whose result type is `i1` into a `comb::OrOp`
/// over the same inputs. Returns failure (leaving the op untouched) for any
/// other result type.
LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
  // Anything that is not a plain i1 result is left alone.
  if (op.getType() != rewriter.getI1Type())
    return failure();

  // NOTE(review): the trailing `true` is forwarded to the comb op builder;
  // presumably it is the two-state flag — confirm against the comb builders.
  rewriter.replaceOpWithNewOp<comb::OrOp>(op, op.getInputs(), true);
  return success();
}
/// Canonicalize an `ltl.intersect` whose result type is `i1` into a
/// `comb::AndOp` over the same inputs. Returns failure (leaving the op
/// untouched) for any other result type.
LogicalResult IntersectOp::canonicalize(IntersectOp op,
                                        PatternRewriter &rewriter) {
  // Anything that is not a plain i1 result is left alone.
  if (op.getType() != rewriter.getI1Type())
    return failure();

  // NOTE(review): the trailing `true` is forwarded to the comb op builder;
  // presumably it is the two-state flag — confirm against the comb builders.
  rewriter.replaceOpWithNewOp<comb::AndOp>(op, op.getInputs(), true);
  return success();
}
//===----------------------------------------------------------------------===//
// DelayOp
//===----------------------------------------------------------------------===//
/// Fold `delay(s, 0, 0)` to `s` when the input is of sequence type. Returns a
/// null fold result in every other case.
OpFoldResult DelayOp::fold(FoldAdaptor adaptor) {
  // Both the delay and the length must be exactly zero for the op to be a
  // no-op.
  const bool isZeroDelay =
      adaptor.getDelay() == 0 && adaptor.getLength() == 0;
  if (!isZeroDelay)
    return {};

  // Only fold when the input already has sequence type.
  if (!isa<SequenceType>(getInput().getType()))
    return {};
  return getInput();
}
/// Register the canonicalization patterns for `DelayOp`: `NestedDelays`
/// followed by `MoveDelayIntoConcat`, both constructed with the pattern set's
/// own context.
void DelayOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // A single variadic `add` registers the patterns in the listed order.
  results.add<patterns::NestedDelays, patterns::MoveDelayIntoConcat>(
      results.getContext());
}
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
/// Fold a single-input concatenation, `concat(s)`, to `s`. Returns a null
/// fold result for any other number of inputs.
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  auto operands = getInputs();
  if (operands.size() != 1)
    return {};
  return operands[0];
}
/// Register the canonicalization patterns for `ConcatOp`: `FlattenConcats`,
/// constructed with the pattern set's own context.
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  MLIRContext *patternContext = results.getContext();
  results.add<patterns::FlattenConcats>(patternContext);
}
//===----------------------------------------------------------------------===//
// RepeatLikeOps
//===----------------------------------------------------------------------===//
namespace {
/// Folding logic shared by the repeat-style ops in this file.
struct RepeatLikeOp {
  /// Fold `repeat(s, 1, 0)` to `s` when `input` has sequence type.
  ///
  /// \param base  the `base` repetition count of the op.
  /// \param more  the `more` repetition count of the op.
  /// \param input the value being repeated.
  /// \returns `input` if the fold applies, a null fold result otherwise.
  static OpFoldResult fold(uint64_t base, uint64_t more, Value input) {
    const bool isSingleRepetition = base == 1 && more == 0;
    if (!isSingleRepetition || !isa<SequenceType>(input.getType()))
      return {};
    return input;
  }
};
} // namespace
/// Fold a `RepeatOp` through the shared repeat-like logic. The fold is only
/// attempted when the optional `more` attribute is present; otherwise a null
/// fold result is returned.
OpFoldResult RepeatOp::fold(FoldAdaptor adaptor) {
  if (auto more = adaptor.getMore())
    return RepeatLikeOp::fold(adaptor.getBase(), *more, getInput());
  return {};
}
/// Fold a `GoToRepeatOp` by delegating to the shared repeat-like logic with
/// this op's `base`, `more`, and input.
OpFoldResult GoToRepeatOp::fold(FoldAdaptor adaptor) {
  const uint64_t base = adaptor.getBase();
  const uint64_t more = adaptor.getMore();
  return RepeatLikeOp::fold(base, more, getInput());
}
/// Fold a `NonConsecutiveRepeatOp` by delegating to the shared repeat-like
/// logic with this op's `base`, `more`, and input.
OpFoldResult NonConsecutiveRepeatOp::fold(FoldAdaptor adaptor) {
  const uint64_t base = adaptor.getBase();
  const uint64_t more = adaptor.getMore();
  return RepeatLikeOp::fold(base, more, getInput());
}