Comb Interval Range Analysis and Comb Opt Narrowing pass (#8425)

Building on the existing MLIR integer interval range analysis framework, build the interface for the comb dialect. Use the interval range analysis to develop a comb opt narrowing pass, that reduces a comb opt based on the interval range of the operation. Currently this is only supported for addition, subtraction and multiplication, but this could be extended in the future. Future work will leverage the interval analysis to validate the combination of consecutive addition operators into a multi-operand addition.
This commit is contained in:
Samuel Coward 2025-04-30 14:42:18 +01:00 committed by GitHub
parent 9a326368bb
commit 65ae143651
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 901 additions and 11 deletions

View File

@ -15,6 +15,7 @@
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/IR/OpAsmInterface.td" include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/SymbolInterfaces.td"

View File

@ -20,6 +20,7 @@
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -18,6 +18,10 @@
#include <memory> #include <memory>
#include <optional> #include <optional>
namespace mlir {
class DataFlowSolver;
}
namespace circt { namespace circt {
namespace comb { namespace comb {
@ -26,6 +30,10 @@ namespace comb {
#define GEN_PASS_REGISTRATION #define GEN_PASS_REGISTRATION
#include "circt/Dialect/Comb/Passes.h.inc" #include "circt/Dialect/Comb/Passes.h.inc"
/// Add patterns for int range based narrowing.
void populateCombNarrowingPatterns(mlir::RewritePatternSet &patterns,
mlir::DataFlowSolver &solver);
} // namespace comb } // namespace comb
} // namespace circt } // namespace circt

View File

@ -15,6 +15,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/IR/EnumAttr.td" include "mlir/IR/EnumAttr.td"
// Base class for binary operators. // Base class for binary operators.
@ -30,7 +31,9 @@ class BinOp<string mnemonic, list<Trait> traits = []> :
// Binary operator with uniform input/result types. // Binary operator with uniform input/result types.
class UTBinOp<string mnemonic, list<Trait> traits = []> : class UTBinOp<string mnemonic, list<Trait> traits = []> :
BinOp<mnemonic, BinOp<mnemonic,
traits # [SameTypeOperands, SameOperandsAndResultType]> { traits # [SameTypeOperands, SameOperandsAndResultType,
DeclareOpInterfaceMethods<InferIntRangeInterface,
["inferResultRanges"]>]> {
let assemblyFormat = "(`bin` $twoState^)? $lhs `,` $rhs attr-dict `:` qualified(type($result))"; let assemblyFormat = "(`bin` $twoState^)? $lhs `,` $rhs attr-dict `:` qualified(type($result))";
} }
@ -42,8 +45,10 @@ class VariadicOp<string mnemonic, list<Trait> traits = []> :
} }
class UTVariadicOp<string mnemonic, list<Trait> traits = []> : class UTVariadicOp<string mnemonic, list<Trait> traits = []> :
VariadicOp<mnemonic, VariadicOp<mnemonic,
traits # [SameTypeOperands, SameOperandsAndResultType]> { traits # [SameTypeOperands, SameOperandsAndResultType,
DeclareOpInterfaceMethods<InferIntRangeInterface,
["inferResultRanges"]>]> {
let hasCanonicalizeMethod = true; let hasCanonicalizeMethod = true;
let hasFolder = true; let hasFolder = true;
@ -76,7 +81,7 @@ let hasFolder = true in {
} }
def AndOp : UTVariadicOp<"and", [Commutative]>; def AndOp : UTVariadicOp<"and", [Commutative]>;
def OrOp : UTVariadicOp<"or", [Commutative]>; def OrOp : UTVariadicOp<"or", [Commutative]>;
def XorOp : UTVariadicOp<"xor", [Commutative]> { def XorOp : UTVariadicOp<"xor", [Commutative]> {
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// Return true if this is a two operand xor with an all ones constant as /// Return true if this is a two operand xor with an all ones constant as
@ -114,7 +119,10 @@ def ICmpPredicate : I64EnumAttr<
ICmpPredicateUGT, ICmpPredicateUGE, ICmpPredicateCEQ, ICmpPredicateCNE, ICmpPredicateUGT, ICmpPredicateUGE, ICmpPredicateCEQ, ICmpPredicateCNE,
ICmpPredicateWEQ, ICmpPredicateWNE]>; ICmpPredicateWEQ, ICmpPredicateWNE]>;
def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> { def ICmpOp : CombOp<"icmp",
[Pure,
SameTypeOperands,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Compare two integer values"; let summary = "Compare two integer values";
let description = [{ let description = [{
This operation compares two integers using a predicate. If the predicate is This operation compares two integers using a predicate. If the predicate is
@ -178,7 +186,9 @@ def ParityOp : UnaryI1ReductionOp<"parity">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Extract a range of bits from the specified input. // Extract a range of bits from the specified input.
def ExtractOp : CombOp<"extract", [Pure]> { def ExtractOp : CombOp<"extract",
[Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Extract a range of bits into a smaller value, lowBit " let summary = "Extract a range of bits into a smaller value, lowBit "
"specifies the lowest bit included."; "specifies the lowest bit included.";
@ -203,7 +213,9 @@ def ExtractOp : CombOp<"extract", [Pure]> {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Other Operations // Other Operations
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> { def ConcatOp : CombOp<"concat",
[InferTypeOpInterface, Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Concatenate a variadic list of operands together."; let summary = "Concatenate a variadic list of operands together.";
let description = [{ let description = [{
See the comb rationale document for details on operand ordering. See the comb rationale document for details on operand ordering.
@ -236,7 +248,9 @@ def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> {
}]; }];
} }
def ReplicateOp : CombOp<"replicate", [Pure]> { def ReplicateOp : CombOp<"replicate",
[Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Concatenate the operand a constant number of times"; let summary = "Concatenate the operand a constant number of times";
let arguments = (ins HWIntegerType:$input); let arguments = (ins HWIntegerType:$input);
@ -266,8 +280,10 @@ def ReplicateOp : CombOp<"replicate", [Pure]> {
} }
// Select one of two values based on a condition. // Select one of two values based on a condition.
def MuxOp : CombOp<"mux", def MuxOp : CombOp<"mux",
[Pure, AllTypesMatch<["trueValue", "falseValue", "result"]>]> { [Pure, AllTypesMatch<["trueValue", "falseValue", "result"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface,
["inferResultRangesFromOptional"]>]> {
let summary = "Return one or the other operand depending on a selector bit"; let summary = "Return one or the other operand depending on a selector bit";
let description = [{ let description = [{
``` ```

View File

@ -22,4 +22,16 @@ def LowerComb : Pass<"lower-comb"> {
}]; }];
} }
def CombIntRangeNarrowing : Pass<"comb-int-range-narrowing"> {
let summary = "Reduce comb op bitwidth based on integer range analysis.";
let description = [{
Compute a basic value range analysis, by propagating integer intervals
through the domain. The analysis is limited by a lack of sign-extension
operator in the comb dialect, leading to an over-approximation.
Particularly for signed arithmetic, a single interval is often an
over-approximation, a more precise analysis would require a union of
intervals.
}];
}
#endif // CIRCT_DIALECT_COMB_PASSES_TD #endif // CIRCT_DIALECT_COMB_PASSES_TD

View File

@ -19,11 +19,14 @@ include "circt/Dialect/HW/HWOpInterfaces.td"
include "circt/Dialect/HW/HWTypes.td" include "circt/Dialect/HW/HWTypes.td"
include "mlir/IR/OpAsmInterface.td" include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
def ConstantOp def ConstantOp
: HWOp<"constant", [Pure, ConstantLike, FirstAttrDerivedResultType, : HWOp<"constant", [Pure, ConstantLike, FirstAttrDerivedResultType,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface,
["inferResultRanges"]>]> {
let summary = "Produce a constant value"; let summary = "Produce a constant value";
let description = [{ let description = [{
The constant operation produces a constant value of standard integer type The constant operation produces a constant value of standard integer type

View File

@ -26,6 +26,7 @@
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"

View File

@ -18,6 +18,8 @@
#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" #include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h"
#include "circt/Dialect/HW/HWInstanceGraph.h" #include "circt/Dialect/HW/HWInstanceGraph.h"
#include "circt/Scheduling/Problems.h" #include "circt/Scheduling/Problems.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
@ -29,6 +31,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::affine; using namespace mlir::affine;
using namespace mlir::dataflow;
using namespace circt; using namespace circt;
using namespace circt::analysis; using namespace circt::analysis;
using namespace circt::scheduling; using namespace circt::scheduling;
@ -263,6 +266,69 @@ void FIRRTLInstanceInfoPass::runOnOperation() {
printModuleInfo(op, iInfo); printModuleInfo(op, iInfo);
} }
//===----------------------------------------------------------------------===//
// Comb IntRange Analysis
//===----------------------------------------------------------------------===//
namespace {
struct TestCombIntegerRangeAnalysisPass
: public PassWrapper<TestCombIntegerRangeAnalysisPass,
OperationPass<mlir::ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCombIntegerRangeAnalysisPass)
void runOnOperation() override;
StringRef getArgument() const override {
return "test-comb-int-range-analysis";
}
StringRef getDescription() const override {
return "Perform integer range analysis on comb dialect and set results as "
"attributes.";
}
};
} // namespace
void TestCombIntegerRangeAnalysisPass::runOnOperation() {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
// Append the integer range analysis as an operation attribute.
op->walk([&](Operation *op) {
for (auto value : op->getResults()) {
if (auto *range = solver.lookupState<IntegerValueRangeLattice>(value)) {
// All analyzed comb operations should return a single result.
assert(op->getResults().size() == 1 &&
"Expected a single result for the operation analysis");
assert(!range->getValue().isUninitialized() &&
"Expected a valid range for the value");
auto interval = range->getValue().getValue();
auto smax = interval.smax();
auto smaxAttr =
IntegerAttr::get(IntegerType::get(ctx, smax.getBitWidth()), smax);
op->setAttr("smax", smaxAttr);
auto smin = interval.smin();
auto sminAttr =
IntegerAttr::get(IntegerType::get(ctx, smin.getBitWidth()), smin);
op->setAttr("smin", sminAttr);
auto umax = interval.umax();
auto umaxAttr = IntegerAttr::get(
IntegerType::get(ctx, umax.getBitWidth(), IntegerType::Unsigned),
umax);
op->setAttr("umax", umaxAttr);
auto umin = interval.umin();
auto uminAttr = IntegerAttr::get(
IntegerType::get(ctx, umin.getBitWidth(), IntegerType::Unsigned),
umin);
op->setAttr("umin", uminAttr);
}
}
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -285,6 +351,9 @@ void registerAnalysisTestPasses() {
registerPass([]() -> std::unique_ptr<Pass> { registerPass([]() -> std::unique_ptr<Pass> {
return std::make_unique<FIRRTLInstanceInfoPass>(); return std::make_unique<FIRRTLInstanceInfoPass>();
}); });
registerPass([]() -> std::unique_ptr<Pass> {
return std::make_unique<TestCombIntegerRangeAnalysisPass>();
});
} }
} // namespace test } // namespace test
} // namespace circt } // namespace circt

View File

@ -3,6 +3,7 @@ add_circt_dialect_library(CIRCTComb
CombOps.cpp CombOps.cpp
CombAnalysis.cpp CombAnalysis.cpp
CombDialect.cpp CombDialect.cpp
InferIntRangeInterfaceImpls.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/Comb ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/Comb
@ -19,6 +20,7 @@ add_circt_dialect_library(CIRCTComb
CIRCTHW CIRCTHW
MLIRIR MLIRIR
MLIRInferTypeOpInterface MLIRInferTypeOpInterface
MLIRInferIntRangeInterface
) )
add_dependencies(circt-headers MLIRCombIncGen MLIRCombEnumsIncGen) add_dependencies(circt-headers MLIRCombIncGen MLIRCombEnumsIncGen)

View File

@ -0,0 +1,309 @@
//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for comb -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of the interval range analysis interface.
// The overflow flags are not set for the comb operations since they is
// no meaningful concept of overflow detection in comb.
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Comb/CombOps.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
using namespace mlir;
using namespace mlir::intrange;
using namespace circt;
using namespace circt::comb;
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
void comb::AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
auto resultRange = argRanges[0];
for (auto argRange : argRanges.drop_front())
resultRange =
inferAdd({resultRange, argRange}, intrange::OverflowFlags::None);
setResultRange(getResult(), resultRange);
};
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
void comb::SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(),
inferSub(argRanges, intrange::OverflowFlags::None));
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
void comb::MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
auto resultRange = argRanges[0];
for (auto argRange : argRanges.drop_front())
resultRange =
inferMul({resultRange, argRange}, intrange::OverflowFlags::None);
setResultRange(getResult(), resultRange);
}
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
void comb::DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferDivU(argRanges));
}
//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//
void comb::DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferDivS(argRanges));
}
//===----------------------------------------------------------------------===//
// ModUOp
//===----------------------------------------------------------------------===//
void comb::ModUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferRemU(argRanges));
}
//===----------------------------------------------------------------------===//
// ModSOp
//===----------------------------------------------------------------------===//
void comb::ModSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferRemS(argRanges));
}
//===----------------------------------------------------------------------===//
// AndOp
//===----------------------------------------------------------------------===//
void comb::AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
auto resultRange = argRanges[0];
for (auto argRange : argRanges.drop_front())
resultRange = inferAnd({resultRange, argRange});
setResultRange(getResult(), resultRange);
}
//===----------------------------------------------------------------------===//
// OrOp
//===----------------------------------------------------------------------===//
void comb::OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
auto resultRange = argRanges[0];
for (auto argRange : argRanges.drop_front())
resultRange = inferOr({resultRange, argRange});
setResultRange(getResult(), resultRange);
}
//===----------------------------------------------------------------------===//
// XorOp
//===----------------------------------------------------------------------===//
void comb::XorOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
auto resultRange = argRanges[0];
for (auto argRange : argRanges.drop_front())
resultRange = inferXor({resultRange, argRange});
setResultRange(getResult(), resultRange);
}
//===----------------------------------------------------------------------===//
// ShlOp
//===----------------------------------------------------------------------===//
void comb::ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(),
inferShl(argRanges, intrange::OverflowFlags::None));
}
//===----------------------------------------------------------------------===//
// ShRUIOp
//===----------------------------------------------------------------------===//
void comb::ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferShrU(argRanges));
}
//===----------------------------------------------------------------------===//
// ShRSIOp
//===----------------------------------------------------------------------===//
void comb::ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), inferShrS(argRanges));
}
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
void comb::ConcatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
// Compute concat as an unsigned integer of bits
const auto resWidth = getResult().getType().getIntOrFloatBitWidth();
auto totalWidth = resWidth;
APInt umin = APInt::getZero(resWidth);
APInt umax = APInt::getZero(resWidth);
for (auto [operand, arg] : llvm::zip(getOperands(), argRanges)) {
assert(totalWidth >= operand.getType().getIntOrFloatBitWidth() &&
"ConcatOp: total width in interval range calculation is negative");
totalWidth -= operand.getType().getIntOrFloatBitWidth();
auto uminUpd = arg.umin().zext(resWidth).ushl_sat(totalWidth);
auto umaxUpd = arg.umax().zext(resWidth).ushl_sat(totalWidth);
umin = umin.uadd_sat(uminUpd);
umax = umax.uadd_sat(umaxUpd);
}
auto urange = ConstantIntRanges::fromUnsigned(umin, umax);
setResultRange(getResult(), urange);
};
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
void comb::ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
// Right-shift and truncate (trunaction implicitly handled)
const auto resWidth = getResult().getType().getIntOrFloatBitWidth();
const auto lowBit = getLowBit();
auto umin = argRanges[0].umin().lshr(lowBit).trunc(resWidth);
auto umax = argRanges[0].umax().lshr(lowBit).trunc(resWidth);
auto urange = ConstantIntRanges::fromUnsigned(umin, umax);
setResultRange(getResult(), urange);
};
//===----------------------------------------------------------------------===//
// ReplicateOp
//===----------------------------------------------------------------------===//
void comb::ReplicateOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
// Compute replicate as an unsigned integer of bits
const auto operandWidth = getOperand().getType().getIntOrFloatBitWidth();
const auto resWidth = getResult().getType().getIntOrFloatBitWidth();
APInt umin = APInt::getZero(resWidth);
APInt umax = APInt::getZero(resWidth);
auto uminIn = argRanges[0].umin().zext(resWidth);
auto umaxIn = argRanges[0].umax().zext(resWidth);
for (unsigned int totalWidth = 0; totalWidth < resWidth;
totalWidth += operandWidth) {
auto uminUpd = uminIn.ushl_sat(totalWidth);
auto umaxUpd = umaxIn.ushl_sat(totalWidth);
umin = umin.uadd_sat(uminUpd);
umax = umax.uadd_sat(umaxUpd);
}
auto urange = ConstantIntRanges::fromUnsigned(umin, umax);
setResultRange(getResult(), urange);
};
//===----------------------------------------------------------------------===//
// MuxOp
//===----------------------------------------------------------------------===//
void comb::MuxOp::inferResultRangesFromOptional(
ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
std::optional<APInt> mbCondVal =
argRanges[0].isUninitialized()
? std::nullopt
: argRanges[0].getValue().getConstantValue();
const IntegerValueRange &trueCase = argRanges[1];
const IntegerValueRange &falseCase = argRanges[2];
if (mbCondVal) {
if (mbCondVal->isZero())
setResultRange(getResult(), falseCase);
else
setResultRange(getResult(), trueCase);
return;
}
setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
}
//===----------------------------------------------------------------------===//
// ICmpOp
//===----------------------------------------------------------------------===//
void comb::ICmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
comb::ICmpPredicate combPred = getPredicate();
APInt min = APInt::getZero(1);
APInt max = APInt::getAllOnes(1);
intrange::CmpPredicate pred;
switch (combPred) {
case comb::ICmpPredicate::eq:
pred = intrange::CmpPredicate::eq;
break;
case comb::ICmpPredicate::ne:
pred = intrange::CmpPredicate::ne;
break;
case comb::ICmpPredicate::slt:
pred = intrange::CmpPredicate::slt;
break;
case comb::ICmpPredicate::sle:
pred = intrange::CmpPredicate::sle;
break;
case comb::ICmpPredicate::sgt:
pred = intrange::CmpPredicate::sgt;
break;
case comb::ICmpPredicate::sge:
pred = intrange::CmpPredicate::sge;
break;
case comb::ICmpPredicate::ult:
pred = intrange::CmpPredicate::ult;
break;
case comb::ICmpPredicate::ule:
pred = intrange::CmpPredicate::ule;
break;
case comb::ICmpPredicate::ugt:
pred = intrange::CmpPredicate::ugt;
break;
case comb::ICmpPredicate::uge:
pred = intrange::CmpPredicate::uge;
break;
default:
// These predicates are not supported for integer range analysis
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
return;
}
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
if (truthValue.has_value() && *truthValue)
min = max;
else if (truthValue.has_value() && !(*truthValue))
max = min;
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
}

View File

@ -1,5 +1,6 @@
add_circt_dialect_library(CIRCTCombTransforms add_circt_dialect_library(CIRCTCombTransforms
LowerComb.cpp LowerComb.cpp
IntRangeOptimizations.cpp
DEPENDS DEPENDS
CIRCTCombTransformsIncGen CIRCTCombTransformsIncGen

View File

@ -0,0 +1,140 @@
//===- IntRangeOptimizations.cpp - Narrow ops in comb ------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Comb/CombPasses.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace circt;
using namespace circt::comb;
using namespace mlir;
using namespace mlir::dataflow;
namespace circt {
namespace comb {
#define GEN_PASS_DEF_COMBINTRANGENARROWING
#include "circt/Dialect/Comb/Passes.h.inc"
} // namespace comb
} // namespace circt
/// Gather ranges for all the values in `values`. Appends to the existing
/// vector.
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
SmallVectorImpl<ConstantIntRanges> &ranges) {
for (Value val : values) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(val);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return failure();
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();
ranges.push_back(inferredRange);
}
return success();
}
namespace {
template <typename CombOpTy>
struct CombOpNarrow : public OpRewritePattern<CombOpTy> {
CombOpNarrow(MLIRContext *context, DataFlowSolver &s)
: OpRewritePattern<CombOpTy>(context), solver(s) {}
LogicalResult matchAndRewrite(CombOpTy op,
PatternRewriter &rewriter) const override {
auto opWidth = op.getType().getIntOrFloatBitWidth();
if (op->getNumOperands() != 2 || op->getNumResults() != 1)
return rewriter.notifyMatchFailure(
op, "Only support binary operations with one result");
SmallVector<ConstantIntRanges> ranges;
if (failed(collectRanges(solver, op->getOperands(), ranges)))
return rewriter.notifyMatchFailure(op, "input without specified range");
if (failed(collectRanges(solver, op->getResults(), ranges)))
return rewriter.notifyMatchFailure(op, "output without specified range");
auto removeWidth = ranges[0].umax().countLeadingZeros();
for (const ConstantIntRanges &range : ranges) {
auto rangeCanRemove = range.umax().countLeadingZeros();
removeWidth = std::min(removeWidth, rangeCanRemove);
}
if (removeWidth == 0)
return rewriter.notifyMatchFailure(op, "no bits to remove");
if (removeWidth == opWidth)
return rewriter.notifyMatchFailure(
op, "all bits to remove - replace by zero");
// Replace operator by narrower version of itself
Value lhs = op.getOperand(0);
Value rhs = op.getOperand(1);
Location loc = op.getLoc();
auto newWidth = opWidth - removeWidth;
// Create a replacement type for the extracted bits
auto replaceType = rewriter.getIntegerType(newWidth);
// Extract the lsbs from each operand
auto extractLhsOp =
rewriter.create<comb::ExtractOp>(loc, replaceType, lhs, 0);
auto extractRhsOp =
rewriter.create<comb::ExtractOp>(loc, replaceType, rhs, 0);
auto narrowOp = rewriter.create<CombOpTy>(loc, extractLhsOp, extractRhsOp);
// Concatenate zeros to match the original operator width
auto zero =
rewriter.create<hw::ConstantOp>(loc, APInt::getZero(removeWidth));
auto replaceOp = rewriter.create<comb::ConcatOp>(
loc, op.getType(), ValueRange{zero, narrowOp});
rewriter.replaceOp(op, replaceOp);
return success();
}
private:
DataFlowSolver &solver;
};
struct CombIntRangeNarrowingPass
: comb::impl::CombIntRangeNarrowingBase<CombIntRangeNarrowingPass> {
using CombIntRangeNarrowingBase::CombIntRangeNarrowingBase;
void runOnOperation() override;
};
} // namespace
void CombIntRangeNarrowingPass::runOnOperation() {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
RewritePatternSet patterns(ctx);
populateCombNarrowingPatterns(patterns, solver);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
signalPassFailure();
}
void comb::populateCombNarrowingPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver) {
patterns.add<CombOpNarrow<comb::AddOp>, CombOpNarrow<comb::MulOp>,
CombOpNarrow<comb::SubOp>>(patterns.getContext(), solver);
}

View File

@ -15,6 +15,7 @@ set(CIRCT_HW_Sources
ModuleImplementation.cpp ModuleImplementation.cpp
InnerSymbolTable.cpp InnerSymbolTable.cpp
PortConverter.cpp PortConverter.cpp
InferIntRangeInterfaceImpls.cpp
) )
set(LLVM_OPTIONAL_SOURCES set(LLVM_OPTIONAL_SOURCES
@ -41,6 +42,8 @@ add_circt_dialect_library(CIRCTHW
MLIRIR MLIRIR
MLIRInferTypeOpInterface MLIRInferTypeOpInterface
MLIRMemorySlotInterfaces MLIRMemorySlotInterfaces
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
) )
add_circt_library(CIRCTHWReductions add_circt_library(CIRCTHWReductions

View File

@ -0,0 +1,25 @@
//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for HW -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
using namespace mlir;
using namespace mlir::intrange;
using namespace circt;
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
void hw::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), ConstantIntRanges::constant(getValue()));
}

View File

@ -0,0 +1,174 @@
// RUN: circt-opt %s --test-comb-int-range-analysis | FileCheck %s
// CHECK-LABEL: @basic_csa
hw.module @basic_csa(in %a : i1, in %b : i1, in %c : i1, out add_abc : i3) {
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2}
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %false, %a {smax = 1 : i2, smin = 0 : i2, umax = 1 : ui2, umin = 0 : ui2} : i1, i1
// CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %false, %b {smax = 1 : i2, smin = 0 : i2, umax = 1 : ui2, umin = 0 : ui2} : i1, i1
// CHECK-NEXT: %[[ADD:.+]] = comb.add %[[A_EXT]], %[[B_EXT]] {smax = 1 : i2, smin = -2 : i2, umax = 2 : ui2, umin = 0 : ui2} : i2
// CHECK-NEXT: %[[ADD_EXT:.+]] = comb.concat %false, %[[ADD]] {smax = 2 : i3, smin = 0 : i3, umax = 2 : ui3, umin = 0 : ui3} : i1, i2
// CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i2, %c {smax = 1 : i3, smin = 0 : i3, umax = 1 : ui3, umin = 0 : ui3} : i2, i1
// CHECK-NEXT: %[[ADD1:.+]] = comb.add %[[ADD_EXT]], %[[C_EXT]] {smax = 3 : i3, smin = 0 : i3, umax = 3 : ui3, umin = 0 : ui3} : i3
%c0_i2 = hw.constant 0 : i2
%false = hw.constant false
%0 = comb.concat %false, %a : i1, i1
%1 = comb.concat %false, %b : i1, i1
%2 = comb.add %0, %1 : i2
%3 = comb.concat %false, %2 : i1, i2
%4 = comb.concat %c0_i2, %c : i2, i1
%5 = comb.add %3, %4 : i3
hw.output %5 : i3
}
// CHECK-LABEL: @basic_mux
hw.module @basic_mux(in %a : i3, in %b : i3, in %sel : i1, out y : i4) {
// CHECK-NEXT: %false = hw.constant false {smax = false, smin = false, umax = 0 : ui1, umin = 0 : ui1}
// CHECK-NEXT: %true = hw.constant true {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1}
// CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %true, %a {smax = -1 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 8 : ui4} : i1, i3
// CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %false, %b {smax = 7 : i4, smin = 0 : i4, umax = 7 : ui4, umin = 0 : ui4} : i1, i3
// CHECK-NEXT: %[[MUX:.+]] = comb.mux %sel, %[[A_EXT]], %[[B_EXT]] {smax = 7 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 0 : ui4} : i4
%false = hw.constant false
%true = hw.constant true
%0 = comb.concat %true, %a : i1, i3
%1 = comb.concat %false, %b : i1, i3
%2 = comb.mux %sel, %0, %1 : i4
hw.output %2 : i4
}
// CHECK-LABEL: @basic_fma
hw.module @basic_fma(in %a : i4, in %b : i4, in %c : i4, out d : i9) {
// CHECK-NEXT: %c0_i5 = hw.constant 0 : i5 {smax = 0 : i5, smin = 0 : i5, umax = 0 : ui5, umin = 0 : ui5}
// CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i5, %a {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4
// CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i5, %b {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4
// CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT]], %[[B_EXT]] {smax = 225 : i9, smin = 0 : i9, umax = 225 : ui9, umin = 0 : ui9} : i9
// CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i5, %c {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4
// CHECK-NEXT: %[[ADD:.+]] = comb.add %[[MUL]], %[[C_EXT]] {smax = 240 : i9, smin = 0 : i9, umax = 240 : ui9, umin = 0 : ui9} : i9
%c0_i5 = hw.constant 0 : i5
%0 = comb.concat %c0_i5, %a : i5, i4
%1 = comb.concat %c0_i5, %b : i5, i4
%2 = comb.mul %0, %1 : i9
%3 = comb.concat %c0_i5, %c : i5, i4
%4 = comb.add %2, %3 : i9
hw.output %4 : i9
}
// CHECK-LABEL: @const_sub
hw.module @const_sub(in %a : i8, out sub_res : i10) {
// CHECK-NEXT: %c256_i10 = hw.constant 256 : i10 {smax = 256 : i10, smin = 256 : i10, umax = 256 : ui10, umin = 256 : ui10}
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2}
// CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i2, %a {smax = 255 : i10, smin = 0 : i10, umax = 255 : ui10, umin = 0 : ui10} : i2, i8
// CHECK-NEXT: %[[SUB:.+]] = comb.sub %c256_i10, %[[A_EXT]] {smax = 256 : i10, smin = 1 : i10, umax = 256 : ui10, umin = 1 : ui10} : i10
%c256_i10 = hw.constant 256 : i10
%c0_i2 = hw.constant 0 : i2
%0 = comb.concat %c0_i2, %a : i2, i8
%1 = comb.sub %c256_i10, %0 : i10
hw.output %1 : i10
}
// CHECK-LABEL: @logical_ops
hw.module @logical_ops(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) {
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2}
// CHECK-NEXT: %false = hw.constant false {smax = false, smin = false, umax = 0 : ui1, umin = 0 : ui1}
// CHECK-NEXT: %c0_i9 = hw.constant 0 : i9 {smax = 0 : i9, smin = 0 : i9, umax = 0 : ui9, umin = 0 : ui9}
// CHECK-NEXT: %c0_i8 = hw.constant 0 : i8 {smax = 0 : i8, smin = 0 : i8, umax = 0 : ui8, umin = 0 : ui8}
// CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i9, %a {smax = 255 : i17, smin = 0 : i17, umax = 255 : ui17, umin = 0 : ui17} : i9, i8
// CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i8, %b {smax = 511 : i17, smin = 0 : i17, umax = 511 : ui17, umin = 0 : ui17} : i8, i9
// CHECK-NEXT: %[[AND:.+]] = comb.and %[[A_EXT]], %[[B_EXT]] {smax = 255 : i17, smin = 0 : i17, umax = 255 : ui17, umin = 0 : ui17} : i17
// CHECK-NEXT: %[[AND_EXT:.+]] = comb.concat %false, %[[AND]] {smax = 255 : i18, smin = 0 : i18, umax = 255 : ui18, umin = 0 : ui18} : i1, i17
// CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i8, %c {smax = 1023 : i18, smin = 0 : i18, umax = 1023 : ui18, umin = 0 : ui18} : i8, i10
// CHECK-NEXT: %[[OR:.+]] = comb.or %[[AND_EXT]], %[[C_EXT]] {smax = 1023 : i18, smin = 0 : i18, umax = 1023 : ui18, umin = 0 : ui18} : i18
// CHECK-NEXT: %[[D_EXT:.+]] = comb.concat %c0_i2, %d {smax = 65535 : i18, smin = 0 : i18, umax = 65535 : ui18, umin = 0 : ui18} : i2, i16
// CHECK-NEXT: %[[ADD:.+]] = comb.add %[[OR]], %[[D_EXT]] {smax = 66558 : i18, smin = 0 : i18, umax = 66558 : ui18, umin = 0 : ui18} : i18
%c0_i2 = hw.constant 0 : i2
%false = hw.constant false
%c0_i9 = hw.constant 0 : i9
%c0_i8 = hw.constant 0 : i8
%0 = comb.concat %c0_i9, %a : i9, i8
%1 = comb.concat %c0_i8, %b : i8, i9
%2 = comb.and %0, %1 : i17
%3 = comb.concat %false, %2 : i1, i17
%4 = comb.concat %c0_i8, %c : i8, i10
%5 = comb.or %3, %4 : i18
%6 = comb.concat %c0_i2, %d : i2, i16
%7 = comb.add %5, %6 : i18
hw.output %7 : i18
}
// CHECK-LABEL: @variadic_ops
hw.module @variadic_ops(in %a : i2, in %b : i2, in %c : i2) {
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2}
// CHECK-NEXT: %[[A_EXT2:.+]] = comb.concat %c0_i2, %a {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2
// CHECK-NEXT: %[[B_EXT2:.+]] = comb.concat %c0_i2, %b {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2
// CHECK-NEXT: %[[C_EXT2:.+]] = comb.concat %c0_i2, %c {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2
// CHECK-NEXT: %[[ADD:.+]] = comb.add %[[A_EXT2]], %[[B_EXT2]], %[[C_EXT2]] {smax = 7 : i4, smin = -8 : i4, umax = 9 : ui4, umin = 0 : ui4} : i4
// CHECK-NEXT: %c0_i3 = hw.constant 0 : i3 {smax = 0 : i3, smin = 0 : i3, umax = 0 : ui3, umin = 0 : ui3}
// CHECK-NEXT: %[[A_EXT3:.+]] = comb.concat %c0_i3, %a {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2
// CHECK-NEXT: %[[B_EXT3:.+]] = comb.concat %c0_i3, %b {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2
// CHECK-NEXT: %[[C_EXT3:.+]] = comb.concat %c0_i3, %c {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2
// CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT3]], %[[B_EXT3]], %[[C_EXT3]] {smax = 15 : i5, smin = -16 : i5, umax = 27 : ui5, umin = 0 : ui5} : i5
// CHECK-NEXT: %[[AND:.+]] = comb.and %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2
// CHECK-NEXT: %[[OR:.+]] = comb.or %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2
// CHECK-NEXT: %[[XOR:.+]] = comb.xor %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2
// CHECK-NEXT: hw.output
%c0_i2 = hw.constant 0 : i2
%0 = comb.concat %c0_i2, %a : i2, i2
%1 = comb.concat %c0_i2, %b : i2, i2
%2 = comb.concat %c0_i2, %c : i2, i2
%3 = comb.add %0, %1, %2 : i4
%c0_i3 = hw.constant 0 : i3
%4 = comb.concat %c0_i3, %a : i3, i2
%5 = comb.concat %c0_i3, %b : i3, i2
%6 = comb.concat %c0_i3, %c : i3, i2
%7 = comb.mul %4, %5, %6 : i5
%8 = comb.and %a, %b, %c : i2
%9 = comb.or %a, %b, %c : i2
%10 = comb.xor %a, %b, %c : i2
hw.output
}
// CHECK-LABEL: @replicate_extract
hw.module @replicate_extract(in %a : i3, in %b : i3, in %sel : i1) {
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2}
// CHECK-NEXT: %[[EXT_A:.+]] = comb.extract %a from 1 {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i3) -> i2
// CHECK-NEXT: %[[REPL_A:.+]] = comb.replicate %[[EXT_A]] {smax = 7 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 0 : ui4} : (i2) -> i4
// CHECK-NEXT: %[[REPL_SEL:.+]] = comb.replicate %sel {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i1) -> i2
// CHECK-NEXT: %[[EXT_OUT:.+]] = comb.extract %[[REPL_A]] from 1 {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i4) -> i2
%c0_i2 = hw.constant 0 : i2
%0 = comb.extract %a from 1 : (i3) -> i2
%1 = comb.replicate %0 : (i2) -> i4
%2 = comb.replicate %sel : (i1) -> i2
%3 = comb.extract %1 from 1 : (i4) -> i2
hw.output
}
// CHECK-LABEL: @comp_predicates
hw.module @comp_predicates(in %a : i3, in %b : i3) {
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2}
// CHECK-NEXT: %c0_i3 = hw.constant 0 : i3 {smax = 0 : i3, smin = 0 : i3, umax = 0 : ui3, umin = 0 : ui3}
// CHECK-NEXT: %c-1_i3 = hw.constant -1 : i3 {smax = -1 : i3, smin = -1 : i3, umax = 7 : ui3, umin = 7 : ui3}
// CHECK-NEXT: %[[ULT:.+]] = comb.icmp ult %a, %c-1_i3 {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3
// CHECK-NEXT: %[[ULE:.+]] = comb.icmp ule %a, %c-1_i3 {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1} : i3
// CHECK-NEXT: %[[UGT:.+]] = comb.icmp ugt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3
// CHECK-NEXT: %[[UGE:.+]] = comb.icmp uge %a, %c0_i3 {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1} : i3
// CHECK-NEXT: %[[SLT:.+]] = comb.icmp slt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3
// CHECK-NEXT: %[[SLE:.+]] = comb.icmp sle %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3
// CHECK-NEXT: %[[SGT:.+]] = comb.icmp sgt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3
// CHECK-NEXT: %[[SGE:.+]] = comb.icmp sge %a, %c0_i3 {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3
// CHECK-NEXT: %[[EQ:.+]] = comb.icmp eq %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3
// CHECK-NEXT: %[[NE:.+]] = comb.icmp ne %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3
%c0_i2 = hw.constant 0 : i2
%c0_i3 = hw.constant 0 : i3
%c7_i3 = hw.constant 7 : i3
%0 = comb.icmp ult %a, %c7_i3 : i3
%1 = comb.icmp ule %a, %c7_i3 : i3
%2 = comb.icmp ugt %a, %b : i3
%3 = comb.icmp uge %a, %c0_i3 : i3
%4 = comb.icmp slt %a, %b : i3
%5 = comb.icmp sle %a, %b : i3
%6 = comb.icmp sgt %a, %b : i3
%7 = comb.icmp sge %a, %c0_i3 : i3
%8 = comb.icmp eq %a, %b : i3
%9 = comb.icmp ne %a, %b : i3
hw.output
}

View File

@ -0,0 +1,125 @@
// RUN: circt-opt %s --comb-int-range-narrowing | FileCheck %s
// CHECK-LABEL: @basic_csa
hw.module @basic_csa(in %a : i1, in %b : i1, in %c : i1, out add_abc : i3) {
// CHECK-NEXT %c0_i2 = hw.constant 0 : i2
// CHECK-NEXT %false = hw.constant false
// CHECK-NEXT %[[A_EXT:.+]] = comb.concat %false, %a : i1, i1
// CHECK-NEXT %[[B_EXT:.+]] = comb.concat %false, %b : i1, i1
// CHECK-NEXT %[[ADD_2:.+]] = comb.add %[[A_EXT]], %[[B_EXT]] : i2
// CHECK-NEXT %[[ADD_2_EXT:.+]] = comb.concat %false, %[[ADD_2]] : i1, i2
// CHECK-NEXT %[[C_EXT:.+]] = comb.concat %c0_i2, %c : i2, i1
// CHECK-NEXT %[[ADD_2_2:.+]] = comb.extract %[[ADD_2_EXT]] from 0 : (i3) -> i2
// CHECK-NEXT %[[C_2:.+]] = comb.extract %[[C_EXT]] from 0 : (i3) -> i2
// CHECK-NEXT %[[ADD_3:.+]] = comb.add %[[ADD_2_2]], %[[C_2]] : i2
// CHECK-NEXT %[[RES:.+]] = comb.concat %false, %[[ADD_3]] : i1, i2
// CHECK-NEXT hw.output %[[RES]] : i3
%c0_i2 = hw.constant 0 : i2
%false = hw.constant false
%0 = comb.concat %false, %a : i1, i1
%1 = comb.concat %false, %b : i1, i1
%2 = comb.add %0, %1 : i2
%3 = comb.concat %false, %2 : i1, i2
%4 = comb.concat %c0_i2, %c : i2, i1
%5 = comb.add %3, %4 : i3
hw.output %5 : i3
}
// CHECK-LABEL: @basic_fma
hw.module @basic_fma(in %a : i4, in %b : i4, in %c : i4, out d : i9) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %c0_i5 = hw.constant 0 : i5
// CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i5, %a : i5, i4
// CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i5, %b : i5, i4
// CHECK-NEXT: %[[A:.+]] = comb.extract %[[A_EXT]] from 0 : (i9) -> i8
// CHECK-NEXT: %[[B:.+]] = comb.extract %[[B_EXT]] from 0 : (i9) -> i8
// CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A]], %[[B]] : i8
// CHECK-NEXT: %[[MUL_EXT:.+]] = comb.concat %false, %[[MUL]] : i1, i8
// CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i5, %c : i5, i4
// CHECK-NEXT: %[[MUL_T:.+]] = comb.extract %[[MUL_EXT]] from 0 : (i9) -> i8
// CHECK-NEXT: %[[C_T:.+]] = comb.extract %[[C_EXT]] from 0 : (i9) -> i8
// CHECK-NEXT: %[[ADD_OUT:.+]] = comb.add %[[MUL_T]], %[[C_T]] : i8
// CHECK-NEXT: %[[ADD_OUT_EXT:.+]] = comb.concat %false, %[[ADD_OUT]] : i1, i8
// CHECK-NEXT: hw.output %[[ADD_OUT_EXT]] : i9
%c0_i5 = hw.constant 0 : i5
%0 = comb.concat %c0_i5, %a : i5, i4
%1 = comb.concat %c0_i5, %b : i5, i4
%2 = comb.mul %0, %1 : i9
%3 = comb.concat %c0_i5, %c : i5, i4
%4 = comb.add %2, %3 : i9
hw.output %4 : i9
}
// CHECK-LABEL: @const_sub
hw.module @const_sub(in %a : i8, out sub_res : i10) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %c-256_i9 = hw.constant -256 : i9
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2
// CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i2, %a : i2, i8
// CHECK-NEXT: %[[A_T:.+]] = comb.extract %[[A_EXT]] from 0 : (i10) -> i9
// CHECK-NEXT: %[[SUB:.+]] = comb.sub %c-256_i9, %[[A_T]] : i9
// CHECK-NEXT: %[[RES:.+]] = comb.concat %false, %[[SUB]] : i1, i9
// CHECK-NEXT: hw.output %[[RES]] : i10
%c256_i10 = hw.constant 256 : i10
%c0_i2 = hw.constant 0 : i2
%0 = comb.concat %c0_i2, %a : i2, i8
%1 = comb.sub %c256_i10, %0 : i10
hw.output %1 : i10
}
// CHECK-LABEL: @do_nothing
hw.module @do_nothing(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) {
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2
// CHECK-NEXT: %[[FALSE:.+]] = hw.constant false
// CHECK-NEXT: %c0_i9 = hw.constant 0 : i9
// CHECK-NEXT: %c0_i8 = hw.constant 0 : i8
// CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i9, %a : i9, i8
// CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i8, %b : i8, i9
// CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT]], %[[B_EXT]] : i17
// CHECK-NEXT: %[[MUL_EXT:.+]] = comb.concat %[[FALSE]], %[[MUL]] : i1, i17
// CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i8, %c : i8, i10
// CHECK-NEXT: %[[D_EXT:.+]] = comb.concat %c0_i2, %d : i2, i16
// CHECK-NEXT: %[[RES:.+]] = comb.add %[[MUL_EXT]], %[[C_EXT]], %[[D_EXT]] : i18
// CHECK-NEXT: hw.output %[[RES]] : i18
%c0_i2 = hw.constant 0 : i2
%false = hw.constant false
%c0_i9 = hw.constant 0 : i9
%c0_i8 = hw.constant 0 : i8
%0 = comb.concat %c0_i9, %a : i9, i8
%1 = comb.concat %c0_i8, %b : i8, i9
%2 = comb.mul %0, %1 : i17
%3 = comb.concat %false, %2 : i1, i17
%4 = comb.concat %c0_i8, %c : i8, i10
%5 = comb.concat %c0_i2, %d : i2, i16
%6 = comb.add %3, %4, %5 : i18
hw.output %6 : i18
}
hw.module @logical_ops(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) {
// CHECK-NEXT %c0_i7 = hw.constant 0 : i7
// CHECK-NEXT %[[FALSE:.+]] = hw.constant false
// CHECK-NEXT %c0_i9 = hw.constant 0 : i9
// CHECK-NEXT %c0_i8 = hw.constant 0 : i8
// CHECK-NEXT %[[A_EXT:.+]] = comb.concat %c0_i9, %a : i9, i8
// CHECK-NEXT %[[B_EXT:.+]] = comb.concat %c0_i8, %b : i8, i9
// CHECK-NEXT %[[AND:.+]] = comb.and %[[A_EXT]], %[[B_EXT]] : i17
// CHECK-NEXT %[[C_EXT:.+]] = comb.concat %c0_i7, %c : i7, i10
// CHECK-NEXT %[[OR:.+]] = comb.or %[[AND]], %[[C_EXT]] : i17
// CHECK-NEXT %[[D_EXT:.+]] = comb.concat %[[FALSE]], %d : i1, i16
// CHECK-NEXT %[[ADD:.+]] = comb.add %[[OR]], %[[D_EXT]] : i17
// CHECK-NEXT %[[RES:.+]] = comb.concat %[[FALSE]], %[[ADD]] : i1, i17
// CHECK-NEXT hw.output %[[RES]] : i18
%c0_i2 = hw.constant 0 : i2
%false = hw.constant false
%c0_i9 = hw.constant 0 : i9
%c0_i8 = hw.constant 0 : i8
%0 = comb.concat %c0_i9, %a : i9, i8
%1 = comb.concat %c0_i8, %b : i8, i9
%2 = comb.and %0, %1 : i17
%3 = comb.concat %false, %2 : i1, i17
%4 = comb.concat %c0_i8, %c : i8, i10
%5 = comb.or %3, %4 : i18
%6 = comb.concat %c0_i2, %d : i2, i16
%7 = comb.add %5, %6 : i18
hw.output %7 : i18
}