From 65ae143651b51f69e9041f24cb535367abbfddde Mon Sep 17 00:00:00 2001 From: Samuel Coward <83779478+cowardsa@users.noreply.github.com> Date: Wed, 30 Apr 2025 14:42:18 +0100 Subject: [PATCH] Comb Interval Range Analysis and Comb Opt Narrowing pass (#8425) Building on the existing MLIR integer interval range analysis framework, build the interface for the comb dialect. Use the interval range analysis to develop a comb opt narrowing pass, that reduces a comb opt based on the interval range of the operation. Currently this is only supported for addition, subtraction and multiplication, but this could be extended in the future. Future work will leverage the interval analysis to validate the combination of consecutive addition operators into a multi-operand addition. --- include/circt/Dialect/Comb/Comb.td | 1 + include/circt/Dialect/Comb/CombOps.h | 1 + include/circt/Dialect/Comb/CombPasses.h | 8 + include/circt/Dialect/Comb/Combinational.td | 36 +- include/circt/Dialect/Comb/Passes.td | 12 + include/circt/Dialect/HW/HWMiscOps.td | 5 +- include/circt/Dialect/HW/HWOps.h | 1 + lib/Analysis/TestPasses.cpp | 69 ++++ lib/Dialect/Comb/CMakeLists.txt | 2 + .../Comb/InferIntRangeInterfaceImpls.cpp | 309 ++++++++++++++++++ lib/Dialect/Comb/Transforms/CMakeLists.txt | 1 + .../Comb/Transforms/IntRangeOptimizations.cpp | 140 ++++++++ lib/Dialect/HW/CMakeLists.txt | 3 + .../HW/InferIntRangeInterfaceImpls.cpp | 25 ++ test/Analysis/comb-int-range-analysis.mlir | 174 ++++++++++ .../Comb/comb-int-range-narrowing.mlir | 125 +++++++ 16 files changed, 901 insertions(+), 11 deletions(-) create mode 100644 lib/Dialect/Comb/InferIntRangeInterfaceImpls.cpp create mode 100644 lib/Dialect/Comb/Transforms/IntRangeOptimizations.cpp create mode 100644 lib/Dialect/HW/InferIntRangeInterfaceImpls.cpp create mode 100644 test/Analysis/comb-int-range-analysis.mlir create mode 100644 test/Dialect/Comb/comb-int-range-narrowing.mlir diff --git a/include/circt/Dialect/Comb/Comb.td b/include/circt/Dialect/Comb/Comb.td index 8f4d1d6f15..3bd7f90969 100644 --- a/include/circt/Dialect/Comb/Comb.td +++ b/include/circt/Dialect/Comb/Comb.td @@ -15,6 +15,7 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" diff --git a/include/circt/Dialect/Comb/CombOps.h b/include/circt/Dialect/Comb/CombOps.h index 8414c347c4..56bad66912 100644 --- a/include/circt/Dialect/Comb/CombOps.h +++ b/include/circt/Dialect/Comb/CombOps.h @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/include/circt/Dialect/Comb/CombPasses.h b/include/circt/Dialect/Comb/CombPasses.h index c41b577cd7..cdbaf906b9 100644 --- a/include/circt/Dialect/Comb/CombPasses.h +++ b/include/circt/Dialect/Comb/CombPasses.h @@ -18,6 +18,10 @@ #include #include +namespace mlir { +class DataFlowSolver; +} + namespace circt { namespace comb { @@ -26,6 +30,10 @@ namespace comb { #define GEN_PASS_REGISTRATION #include "circt/Dialect/Comb/Passes.h.inc" +/// Add patterns for int range based narrowing. +void populateCombNarrowingPatterns(mlir::RewritePatternSet &patterns, + mlir::DataFlowSolver &solver); + } // namespace comb } // namespace circt diff --git a/include/circt/Dialect/Comb/Combinational.td b/include/circt/Dialect/Comb/Combinational.td index 9321d2912d..9722e5e55c 100644 --- a/include/circt/Dialect/Comb/Combinational.td +++ b/include/circt/Dialect/Comb/Combinational.td @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/IR/EnumAttr.td" // Base class for binary operators. @@ -30,7 +31,9 @@ class BinOp traits = []> : // Binary operator with uniform input/result types. class UTBinOp traits = []> : BinOp { + traits # [SameTypeOperands, SameOperandsAndResultType, + DeclareOpInterfaceMethods]> { let assemblyFormat = "(`bin` $twoState^)? $lhs `,` $rhs attr-dict `:` qualified(type($result))"; } @@ -42,8 +45,10 @@ class VariadicOp traits = []> : } class UTVariadicOp traits = []> : - VariadicOp { + VariadicOp]> { let hasCanonicalizeMethod = true; let hasFolder = true; @@ -76,7 +81,7 @@ let hasFolder = true in { } def AndOp : UTVariadicOp<"and", [Commutative]>; -def OrOp : UTVariadicOp<"or", [Commutative]>; +def OrOp : UTVariadicOp<"or", [Commutative]>; def XorOp : UTVariadicOp<"xor", [Commutative]> { let extraClassDeclaration = [{ /// Return true if this is a two operand xor with an all ones constant as @@ -114,7 +119,10 @@ def ICmpPredicate : I64EnumAttr< ICmpPredicateUGT, ICmpPredicateUGE, ICmpPredicateCEQ, ICmpPredicateCNE, ICmpPredicateWEQ, ICmpPredicateWNE]>; -def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> { +def ICmpOp : CombOp<"icmp", + [Pure, + SameTypeOperands, + DeclareOpInterfaceMethods]> { let summary = "Compare two integer values"; let description = [{ This operation compares two integers using a predicate. If the predicate is @@ -178,7 +186,9 @@ def ParityOp : UnaryI1ReductionOp<"parity">; //===----------------------------------------------------------------------===// // Extract a range of bits from the specified input. -def ExtractOp : CombOp<"extract", [Pure]> { +def ExtractOp : CombOp<"extract", + [Pure, + DeclareOpInterfaceMethods]> { let summary = "Extract a range of bits into a smaller value, lowBit " "specifies the lowest bit included."; @@ -203,7 +213,9 @@ def ExtractOp : CombOp<"extract", [Pure]> { //===----------------------------------------------------------------------===// // Other Operations //===----------------------------------------------------------------------===// -def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> { +def ConcatOp : CombOp<"concat", + [InferTypeOpInterface, Pure, + DeclareOpInterfaceMethods]> { let summary = "Concatenate a variadic list of operands together."; let description = [{ See the comb rationale document for details on operand ordering. @@ -236,7 +248,9 @@ def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> { }]; } -def ReplicateOp : CombOp<"replicate", [Pure]> { +def ReplicateOp : CombOp<"replicate", + [Pure, + DeclareOpInterfaceMethods]> { let summary = "Concatenate the operand a constant number of times"; let arguments = (ins HWIntegerType:$input); @@ -266,8 +280,10 @@ def ReplicateOp : CombOp<"replicate", [Pure]> { } // Select one of two values based on a condition. -def MuxOp : CombOp<"mux", - [Pure, AllTypesMatch<["trueValue", "falseValue", "result"]>]> { +def MuxOp : CombOp<"mux", + [Pure, AllTypesMatch<["trueValue", "falseValue", "result"]>, + DeclareOpInterfaceMethods]> { let summary = "Return one or the other operand depending on a selector bit"; let description = [{ ``` diff --git a/include/circt/Dialect/Comb/Passes.td b/include/circt/Dialect/Comb/Passes.td index 197e8b1b46..8ce3960f3e 100644 --- a/include/circt/Dialect/Comb/Passes.td +++ b/include/circt/Dialect/Comb/Passes.td @@ -22,4 +22,16 @@ def LowerComb : Pass<"lower-comb"> { }]; } +def CombIntRangeNarrowing : Pass<"comb-int-range-narrowing"> { + let summary = "Reduce comb op bitwidth based on integer range analysis."; + let description = [{ + Compute a basic value range analysis, by propagating integer intervals + through the domain. The analysis is limited by a lack of sign-extension + operator in the comb dialect, leading to an over-approximation. + Particularly for signed arithmetic, a single interval is often an + over-approximation, a more precise analysis would require a union of + intervals. + }]; +} + #endif // CIRCT_DIALECT_COMB_PASSES_TD diff --git a/include/circt/Dialect/HW/HWMiscOps.td b/include/circt/Dialect/HW/HWMiscOps.td index bd534a5b75..b8b5e1ae9f 100644 --- a/include/circt/Dialect/HW/HWMiscOps.td +++ b/include/circt/Dialect/HW/HWMiscOps.td @@ -19,11 +19,14 @@ include "circt/Dialect/HW/HWOpInterfaces.td" include "circt/Dialect/HW/HWTypes.td" include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" def ConstantOp : HWOp<"constant", [Pure, ConstantLike, FirstAttrDerivedResultType, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Produce a constant value"; let description = [{ The constant operation produces a constant value of standard integer type diff --git a/include/circt/Dialect/HW/HWOps.h b/include/circt/Dialect/HW/HWOps.h index 75009161c4..02bff24c6d 100644 --- a/include/circt/Dialect/HW/HWOps.h +++ b/include/circt/Dialect/HW/HWOps.h @@ -26,6 +26,7 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/StringExtras.h" diff --git a/lib/Analysis/TestPasses.cpp b/lib/Analysis/TestPasses.cpp index 8c79c09715..c2b3f6bc95 100644 --- a/lib/Analysis/TestPasses.cpp +++ b/lib/Analysis/TestPasses.cpp @@ -18,6 +18,8 @@ #include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" #include "circt/Dialect/HW/HWInstanceGraph.h" #include "circt/Scheduling/Problems.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -29,6 +31,7 @@ using namespace mlir; using namespace mlir::affine; +using namespace mlir::dataflow; using namespace circt; using namespace circt::analysis; using namespace circt::scheduling; @@ -263,6 +266,69 @@ void FIRRTLInstanceInfoPass::runOnOperation() { printModuleInfo(op, iInfo); } +//===----------------------------------------------------------------------===// +// Comb IntRange Analysis +//===----------------------------------------------------------------------===// + +namespace { +struct TestCombIntegerRangeAnalysisPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCombIntegerRangeAnalysisPass) + + void runOnOperation() override; + StringRef getArgument() const override { + return "test-comb-int-range-analysis"; + } + StringRef getDescription() const override { + return "Perform integer range analysis on comb dialect and set results as " + "attributes."; + } +}; +} // namespace + +void TestCombIntegerRangeAnalysisPass::runOnOperation() { + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + // Append the integer range analysis as an operation attribute. + op->walk([&](Operation *op) { + for (auto value : op->getResults()) { + if (auto *range = solver.lookupState(value)) { + // All analyzed comb operations should return a single result. + assert(op->getResults().size() == 1 && + "Expected a single result for the operation analysis"); + assert(!range->getValue().isUninitialized() && + "Expected a valid range for the value"); + auto interval = range->getValue().getValue(); + auto smax = interval.smax(); + auto smaxAttr = + IntegerAttr::get(IntegerType::get(ctx, smax.getBitWidth()), smax); + op->setAttr("smax", smaxAttr); + auto smin = interval.smin(); + auto sminAttr = + IntegerAttr::get(IntegerType::get(ctx, smin.getBitWidth()), smin); + op->setAttr("smin", sminAttr); + auto umax = interval.umax(); + auto umaxAttr = IntegerAttr::get( + IntegerType::get(ctx, umax.getBitWidth(), IntegerType::Unsigned), + umax); + op->setAttr("umax", umaxAttr); + auto umin = interval.umin(); + auto uminAttr = IntegerAttr::get( + IntegerType::get(ctx, umin.getBitWidth(), IntegerType::Unsigned), + umin); + op->setAttr("umin", uminAttr); + } + } + }); +} + //===----------------------------------------------------------------------===// // Pass registration //===----------------------------------------------------------------------===// @@ -285,6 +351,9 @@ void registerAnalysisTestPasses() { registerPass([]() -> std::unique_ptr { return std::make_unique(); }); + registerPass([]() -> std::unique_ptr { + return std::make_unique(); + }); } } // namespace test } // namespace circt diff --git a/lib/Dialect/Comb/CMakeLists.txt b/lib/Dialect/Comb/CMakeLists.txt index 9ecd57f380..d43ddbfc96 100644 --- a/lib/Dialect/Comb/CMakeLists.txt +++ b/lib/Dialect/Comb/CMakeLists.txt @@ -3,6 +3,7 @@ add_circt_dialect_library(CIRCTComb CombOps.cpp CombAnalysis.cpp CombDialect.cpp + InferIntRangeInterfaceImpls.cpp ADDITIONAL_HEADER_DIRS ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/Comb @@ -19,6 +20,7 @@ add_circt_dialect_library(CIRCTComb CIRCTHW MLIRIR MLIRInferTypeOpInterface + MLIRInferIntRangeInterface ) add_dependencies(circt-headers MLIRCombIncGen MLIRCombEnumsIncGen) diff --git a/lib/Dialect/Comb/InferIntRangeInterfaceImpls.cpp b/lib/Dialect/Comb/InferIntRangeInterfaceImpls.cpp new file mode 100644 index 0000000000..d5587bdef8 --- /dev/null +++ b/lib/Dialect/Comb/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,309 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for comb -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of the interval range analysis interface. +// The overflow flags are not set for the comb operations since they is +// no meaningful concept of overflow detection in comb. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Comb/CombOps.h" + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +using namespace mlir; +using namespace mlir::intrange; +using namespace circt; +using namespace circt::comb; +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void comb::AddOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = + inferAdd({resultRange, argRange}, intrange::OverflowFlags::None); + + setResultRange(getResult(), resultRange); +}; + +//===----------------------------------------------------------------------===// +// SubOp +//===----------------------------------------------------------------------===// + +void comb::SubOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferSub(argRanges, intrange::OverflowFlags::None)); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void comb::MulOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = + inferMul({resultRange, argRange}, intrange::OverflowFlags::None); + + setResultRange(getResult(), resultRange); +} + +//===----------------------------------------------------------------------===// +// DivUIOp +//===----------------------------------------------------------------------===// + +void comb::DivUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferDivU(argRanges)); +} + +//===----------------------------------------------------------------------===// +// DivSIOp +//===----------------------------------------------------------------------===// + +void comb::DivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferDivS(argRanges)); +} + +//===----------------------------------------------------------------------===// +// ModUOp +//===----------------------------------------------------------------------===// + +void comb::ModUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferRemU(argRanges)); +} + +//===----------------------------------------------------------------------===// +// ModSOp +//===----------------------------------------------------------------------===// + +void comb::ModSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferRemS(argRanges)); +} +//===----------------------------------------------------------------------===// +// AndOp +//===----------------------------------------------------------------------===// + +void comb::AndOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = inferAnd({resultRange, argRange}); + + setResultRange(getResult(), resultRange); +} + +//===----------------------------------------------------------------------===// +// OrOp +//===----------------------------------------------------------------------===// + +void comb::OrOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = inferOr({resultRange, argRange}); + + setResultRange(getResult(), resultRange); +} + +//===----------------------------------------------------------------------===// +// XorOp +//===----------------------------------------------------------------------===// + +void comb::XorOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = inferXor({resultRange, argRange}); + + setResultRange(getResult(), resultRange); +} + +//===----------------------------------------------------------------------===// +// ShlOp +//===----------------------------------------------------------------------===// + +void comb::ShlOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferShl(argRanges, intrange::OverflowFlags::None)); +} + +//===----------------------------------------------------------------------===// +// ShRUIOp +//===----------------------------------------------------------------------===// + +void comb::ShrUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferShrU(argRanges)); +} + +//===----------------------------------------------------------------------===// +// ShRSIOp +//===----------------------------------------------------------------------===// + +void comb::ShrSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferShrS(argRanges)); +} + +//===----------------------------------------------------------------------===// +// ConcatOp +//===----------------------------------------------------------------------===// + +void comb::ConcatOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + // Compute concat as an unsigned integer of bits + const auto resWidth = getResult().getType().getIntOrFloatBitWidth(); + auto totalWidth = resWidth; + APInt umin = APInt::getZero(resWidth); + APInt umax = APInt::getZero(resWidth); + for (auto [operand, arg] : llvm::zip(getOperands(), argRanges)) { + assert(totalWidth >= operand.getType().getIntOrFloatBitWidth() && + "ConcatOp: total width in interval range calculation is negative"); + totalWidth -= operand.getType().getIntOrFloatBitWidth(); + auto uminUpd = arg.umin().zext(resWidth).ushl_sat(totalWidth); + auto umaxUpd = arg.umax().zext(resWidth).ushl_sat(totalWidth); + umin = umin.uadd_sat(uminUpd); + umax = umax.uadd_sat(umaxUpd); + } + auto urange = ConstantIntRanges::fromUnsigned(umin, umax); + setResultRange(getResult(), urange); +}; + +//===----------------------------------------------------------------------===// +// ExtractOp +//===----------------------------------------------------------------------===// + +void comb::ExtractOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + // Right-shift and truncate (trunaction implicitly handled) + const auto resWidth = getResult().getType().getIntOrFloatBitWidth(); + const auto lowBit = getLowBit(); + auto umin = argRanges[0].umin().lshr(lowBit).trunc(resWidth); + auto umax = argRanges[0].umax().lshr(lowBit).trunc(resWidth); + auto urange = ConstantIntRanges::fromUnsigned(umin, umax); + setResultRange(getResult(), urange); +}; + +//===----------------------------------------------------------------------===// +// ReplicateOp +//===----------------------------------------------------------------------===// + +void comb::ReplicateOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + // Compute replicate as an unsigned integer of bits + const auto operandWidth = getOperand().getType().getIntOrFloatBitWidth(); + const auto resWidth = getResult().getType().getIntOrFloatBitWidth(); + APInt umin = APInt::getZero(resWidth); + APInt umax = APInt::getZero(resWidth); + auto uminIn = argRanges[0].umin().zext(resWidth); + auto umaxIn = argRanges[0].umax().zext(resWidth); + for (unsigned int totalWidth = 0; totalWidth < resWidth; + totalWidth += operandWidth) { + auto uminUpd = uminIn.ushl_sat(totalWidth); + auto umaxUpd = umaxIn.ushl_sat(totalWidth); + umin = umin.uadd_sat(uminUpd); + umax = umax.uadd_sat(umaxUpd); + } + auto urange = ConstantIntRanges::fromUnsigned(umin, umax); + setResultRange(getResult(), urange); +}; + +//===----------------------------------------------------------------------===// +// MuxOp +//===----------------------------------------------------------------------===// + +void comb::MuxOp::inferResultRangesFromOptional( + ArrayRef argRanges, SetIntLatticeFn setResultRange) { + std::optional mbCondVal = + argRanges[0].isUninitialized() + ? std::nullopt + : argRanges[0].getValue().getConstantValue(); + + const IntegerValueRange &trueCase = argRanges[1]; + const IntegerValueRange &falseCase = argRanges[2]; + + if (mbCondVal) { + if (mbCondVal->isZero()) + setResultRange(getResult(), falseCase); + else + setResultRange(getResult(), trueCase); + return; + } + setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase)); +} + +//===----------------------------------------------------------------------===// +// ICmpOp +//===----------------------------------------------------------------------===// + +void comb::ICmpOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + comb::ICmpPredicate combPred = getPredicate(); + + APInt min = APInt::getZero(1); + APInt max = APInt::getAllOnes(1); + + intrange::CmpPredicate pred; + switch (combPred) { + case comb::ICmpPredicate::eq: + pred = intrange::CmpPredicate::eq; + break; + case comb::ICmpPredicate::ne: + pred = intrange::CmpPredicate::ne; + break; + case comb::ICmpPredicate::slt: + pred = intrange::CmpPredicate::slt; + break; + case comb::ICmpPredicate::sle: + pred = intrange::CmpPredicate::sle; + break; + case comb::ICmpPredicate::sgt: + pred = intrange::CmpPredicate::sgt; + break; + case comb::ICmpPredicate::sge: + pred = intrange::CmpPredicate::sge; + break; + case comb::ICmpPredicate::ult: + pred = intrange::CmpPredicate::ult; + break; + case comb::ICmpPredicate::ule: + pred = intrange::CmpPredicate::ule; + break; + case comb::ICmpPredicate::ugt: + pred = intrange::CmpPredicate::ugt; + break; + case comb::ICmpPredicate::uge: + pred = intrange::CmpPredicate::uge; + break; + default: + // These predicates are not supported for integer range analysis + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); + return; + } + + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + std::optional truthValue = intrange::evaluatePred(pred, lhs, rhs); + if (truthValue.has_value() && *truthValue) + min = max; + else if (truthValue.has_value() && !(*truthValue)) + max = min; + + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +} diff --git a/lib/Dialect/Comb/Transforms/CMakeLists.txt b/lib/Dialect/Comb/Transforms/CMakeLists.txt index 9043ccf5de..45b19a1d50 100644 --- a/lib/Dialect/Comb/Transforms/CMakeLists.txt +++ b/lib/Dialect/Comb/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_circt_dialect_library(CIRCTCombTransforms LowerComb.cpp + IntRangeOptimizations.cpp DEPENDS CIRCTCombTransformsIncGen diff --git a/lib/Dialect/Comb/Transforms/IntRangeOptimizations.cpp b/lib/Dialect/Comb/Transforms/IntRangeOptimizations.cpp new file mode 100644 index 0000000000..ec57c15827 --- /dev/null +++ b/lib/Dialect/Comb/Transforms/IntRangeOptimizations.cpp @@ -0,0 +1,140 @@ +//===- IntRangeOptimizations.cpp - Narrow ops in comb ------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/Comb/CombPasses.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" + +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace circt; +using namespace circt::comb; +using namespace mlir; +using namespace mlir::dataflow; + +namespace circt { +namespace comb { +#define GEN_PASS_DEF_COMBINTRANGENARROWING +#include "circt/Dialect/Comb/Passes.h.inc" +} // namespace comb +} // namespace circt + +/// Gather ranges for all the values in `values`. Appends to the existing +/// vector. +static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, + SmallVectorImpl &ranges) { + for (Value val : values) { + auto *maybeInferredRange = + solver.lookupState(val); + if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) + return failure(); + + const ConstantIntRanges &inferredRange = + maybeInferredRange->getValue().getValue(); + ranges.push_back(inferredRange); + } + return success(); +} + +namespace { +template +struct CombOpNarrow : public OpRewritePattern { + CombOpNarrow(MLIRContext *context, DataFlowSolver &s) + : OpRewritePattern(context), solver(s) {} + + LogicalResult matchAndRewrite(CombOpTy op, + PatternRewriter &rewriter) const override { + + auto opWidth = op.getType().getIntOrFloatBitWidth(); + if (op->getNumOperands() != 2 || op->getNumResults() != 1) + return rewriter.notifyMatchFailure( + op, "Only support binary operations with one result"); + + SmallVector ranges; + if (failed(collectRanges(solver, op->getOperands(), ranges))) + return rewriter.notifyMatchFailure(op, "input without specified range"); + if (failed(collectRanges(solver, op->getResults(), ranges))) + return rewriter.notifyMatchFailure(op, "output without specified range"); + + auto removeWidth = ranges[0].umax().countLeadingZeros(); + for (const ConstantIntRanges &range : ranges) { + auto rangeCanRemove = range.umax().countLeadingZeros(); + removeWidth = std::min(removeWidth, rangeCanRemove); + } + if (removeWidth == 0) + return rewriter.notifyMatchFailure(op, "no bits to remove"); + if (removeWidth == opWidth) + return rewriter.notifyMatchFailure( + op, "all bits to remove - replace by zero"); + + // Replace operator by narrower version of itself + Value lhs = op.getOperand(0); + Value rhs = op.getOperand(1); + + Location loc = op.getLoc(); + auto newWidth = opWidth - removeWidth; + // Create a replacement type for the extracted bits + auto replaceType = rewriter.getIntegerType(newWidth); + + // Extract the lsbs from each operand + auto extractLhsOp = + rewriter.create(loc, replaceType, lhs, 0); + auto extractRhsOp = + rewriter.create(loc, replaceType, rhs, 0); + auto narrowOp = rewriter.create(loc, extractLhsOp, extractRhsOp); + + // Concatenate zeros to match the original operator width + auto zero = + rewriter.create(loc, APInt::getZero(removeWidth)); + auto replaceOp = rewriter.create( + loc, op.getType(), ValueRange{zero, narrowOp}); + + rewriter.replaceOp(op, replaceOp); + return success(); + } + +private: + DataFlowSolver &solver; +}; + +struct CombIntRangeNarrowingPass + : comb::impl::CombIntRangeNarrowingBase { + + using CombIntRangeNarrowingBase::CombIntRangeNarrowingBase; + void runOnOperation() override; +}; +} // namespace + +void CombIntRangeNarrowingPass::runOnOperation() { + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + RewritePatternSet patterns(ctx); + populateCombNarrowingPatterns(patterns, solver); + + if (failed(applyPatternsGreedily(op, std::move(patterns)))) + signalPassFailure(); +} + +void comb::populateCombNarrowingPatterns(RewritePatternSet &patterns, + DataFlowSolver &solver) { + patterns.add, CombOpNarrow, + CombOpNarrow>(patterns.getContext(), solver); +} diff --git a/lib/Dialect/HW/CMakeLists.txt b/lib/Dialect/HW/CMakeLists.txt index a53eadf7fe..1ae22d1e85 100644 --- a/lib/Dialect/HW/CMakeLists.txt +++ b/lib/Dialect/HW/CMakeLists.txt @@ -15,6 +15,7 @@ set(CIRCT_HW_Sources ModuleImplementation.cpp InnerSymbolTable.cpp PortConverter.cpp + InferIntRangeInterfaceImpls.cpp ) set(LLVM_OPTIONAL_SOURCES @@ -41,6 +42,8 @@ add_circt_dialect_library(CIRCTHW MLIRIR MLIRInferTypeOpInterface MLIRMemorySlotInterfaces + MLIRInferIntRangeCommon + MLIRInferIntRangeInterface ) add_circt_library(CIRCTHWReductions diff --git a/lib/Dialect/HW/InferIntRangeInterfaceImpls.cpp b/lib/Dialect/HW/InferIntRangeInterfaceImpls.cpp new file mode 100644 index 0000000000..2b06306ab3 --- /dev/null +++ b/lib/Dialect/HW/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,25 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for HW -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWOps.h" + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +using namespace mlir; +using namespace mlir::intrange; +using namespace circt; + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +void hw::ConstantOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), ConstantIntRanges::constant(getValue())); +} diff --git a/test/Analysis/comb-int-range-analysis.mlir b/test/Analysis/comb-int-range-analysis.mlir new file mode 100644 index 0000000000..fc60b7e496 --- /dev/null +++ b/test/Analysis/comb-int-range-analysis.mlir @@ -0,0 +1,174 @@ +// RUN: circt-opt %s --test-comb-int-range-analysis | FileCheck %s + +// CHECK-LABEL: @basic_csa +hw.module @basic_csa(in %a : i1, in %b : i1, in %c : i1, out add_abc : i3) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %false = hw.constant false + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %false, %a {smax = 1 : i2, smin = 0 : i2, umax = 1 : ui2, umin = 0 : ui2} : i1, i1 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %false, %b {smax = 1 : i2, smin = 0 : i2, umax = 1 : ui2, umin = 0 : ui2} : i1, i1 + // CHECK-NEXT: %[[ADD:.+]] = comb.add %[[A_EXT]], %[[B_EXT]] {smax = 1 : i2, smin = -2 : i2, umax = 2 : ui2, umin = 0 : ui2} : i2 + // CHECK-NEXT: %[[ADD_EXT:.+]] = comb.concat %false, %[[ADD]] {smax = 2 : i3, smin = 0 : i3, umax = 2 : ui3, umin = 0 : ui3} : i1, i2 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i2, %c {smax = 1 : i3, smin = 0 : i3, umax = 1 : ui3, umin = 0 : ui3} : i2, i1 + // CHECK-NEXT: %[[ADD1:.+]] = comb.add %[[ADD_EXT]], %[[C_EXT]] {smax = 3 : i3, smin = 0 : i3, umax = 3 : ui3, umin = 0 : ui3} : i3 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %0 = comb.concat %false, %a : i1, i1 + %1 = comb.concat %false, %b : i1, i1 + %2 = comb.add %0, %1 : i2 + %3 = comb.concat %false, %2 : i1, i2 + %4 = comb.concat %c0_i2, %c : i2, i1 + %5 = comb.add %3, %4 : i3 + hw.output %5 : i3 +} + +// CHECK-LABEL: @basic_mux +hw.module @basic_mux(in %a : i3, in %b : i3, in %sel : i1, out y : i4) { + // CHECK-NEXT: %false = hw.constant false {smax = false, smin = false, umax = 0 : ui1, umin = 0 : ui1} + // CHECK-NEXT: %true = hw.constant true {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1} + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %true, %a {smax = -1 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 8 : ui4} : i1, i3 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %false, %b {smax = 7 : i4, smin = 0 : i4, umax = 7 : ui4, umin = 0 : ui4} : i1, i3 + // CHECK-NEXT: %[[MUX:.+]] = comb.mux %sel, %[[A_EXT]], %[[B_EXT]] {smax = 7 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 0 : ui4} : i4 + %false = hw.constant false + %true = hw.constant true + %0 = comb.concat %true, %a : i1, i3 + %1 = comb.concat %false, %b : i1, i3 + %2 = comb.mux %sel, %0, %1 : i4 + hw.output %2 : i4 +} + +// CHECK-LABEL: @basic_fma +hw.module @basic_fma(in %a : i4, in %b : i4, in %c : i4, out d : i9) { + // CHECK-NEXT: %c0_i5 = hw.constant 0 : i5 {smax = 0 : i5, smin = 0 : i5, umax = 0 : ui5, umin = 0 : ui5} + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i5, %a {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i5, %b {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4 + // CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT]], %[[B_EXT]] {smax = 225 : i9, smin = 0 : i9, umax = 225 : ui9, umin = 0 : ui9} : i9 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i5, %c {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4 + // CHECK-NEXT: %[[ADD:.+]] = comb.add %[[MUL]], %[[C_EXT]] {smax = 240 : i9, smin = 0 : i9, umax = 240 : ui9, umin = 0 : ui9} : i9 + %c0_i5 = hw.constant 0 : i5 + %0 = comb.concat %c0_i5, %a : i5, i4 + %1 = comb.concat %c0_i5, %b : i5, i4 + %2 = comb.mul %0, %1 : i9 + %3 = comb.concat %c0_i5, %c : i5, i4 + %4 = comb.add %2, %3 : i9 + hw.output %4 : i9 +} + +// CHECK-LABEL: @const_sub +hw.module @const_sub(in %a : i8, out sub_res : i10) { + // CHECK-NEXT: %c256_i10 = hw.constant 256 : i10 {smax = 256 : i10, smin = 256 : i10, umax = 256 : ui10, umin = 256 : ui10} + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i2, %a {smax = 255 : i10, smin = 0 : i10, umax = 255 : ui10, umin = 0 : ui10} : i2, i8 + // CHECK-NEXT: %[[SUB:.+]] = comb.sub %c256_i10, %[[A_EXT]] {smax = 256 : i10, smin = 1 : i10, umax = 256 : ui10, umin = 1 : ui10} : i10 + %c256_i10 = hw.constant 256 : i10 + %c0_i2 = hw.constant 0 : i2 + %0 = comb.concat %c0_i2, %a : i2, i8 + %1 = comb.sub %c256_i10, %0 : i10 + hw.output %1 : i10 +} + +// CHECK-LABEL: @logical_ops +hw.module @logical_ops(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %false = hw.constant false {smax = false, smin = false, umax = 0 : ui1, umin = 0 : ui1} + // CHECK-NEXT: %c0_i9 = hw.constant 0 : i9 {smax = 0 : i9, smin = 0 : i9, umax = 0 : ui9, umin = 0 : ui9} + // CHECK-NEXT: %c0_i8 = hw.constant 0 : i8 {smax = 0 : i8, smin = 0 : i8, umax = 0 : ui8, umin = 0 : ui8} + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i9, %a {smax = 255 : i17, smin = 0 : i17, umax = 255 : ui17, umin = 0 : ui17} : i9, i8 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i8, %b {smax = 511 : i17, smin = 0 : i17, umax = 511 : ui17, umin = 0 : ui17} : i8, i9 + // CHECK-NEXT: %[[AND:.+]] = comb.and %[[A_EXT]], %[[B_EXT]] {smax = 255 : i17, smin = 0 : i17, umax = 255 : ui17, umin = 0 : ui17} : i17 + // CHECK-NEXT: %[[AND_EXT:.+]] = comb.concat %false, %[[AND]] {smax = 255 : i18, smin = 0 : i18, umax = 255 : ui18, umin = 0 : ui18} : i1, i17 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i8, %c {smax = 1023 : i18, smin = 0 : i18, umax = 1023 : ui18, umin = 0 : ui18} : i8, i10 + // CHECK-NEXT: %[[OR:.+]] = comb.or %[[AND_EXT]], %[[C_EXT]] {smax = 1023 : i18, smin = 0 : i18, umax = 1023 : ui18, umin = 0 : ui18} : i18 + // CHECK-NEXT: %[[D_EXT:.+]] = comb.concat %c0_i2, %d {smax = 65535 : i18, smin = 0 : i18, umax = 65535 : ui18, umin = 0 : ui18} : i2, i16 + // CHECK-NEXT: %[[ADD:.+]] = comb.add %[[OR]], %[[D_EXT]] {smax = 66558 : i18, smin = 0 : i18, umax = 66558 : ui18, umin = 0 : ui18} : i18 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %c0_i9 = hw.constant 0 : i9 + %c0_i8 = hw.constant 0 : i8 + %0 = comb.concat %c0_i9, %a : i9, i8 + %1 = comb.concat %c0_i8, %b : i8, i9 + %2 = comb.and %0, %1 : i17 + %3 = comb.concat %false, %2 : i1, i17 + %4 = comb.concat %c0_i8, %c : i8, i10 + %5 = comb.or %3, %4 : i18 + %6 = comb.concat %c0_i2, %d : i2, i16 + %7 = comb.add %5, %6 : i18 + hw.output %7 : i18 +} + +// CHECK-LABEL: @variadic_ops +hw.module @variadic_ops(in %a : i2, in %b : i2, in %c : i2) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %[[A_EXT2:.+]] = comb.concat %c0_i2, %a {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2 + // CHECK-NEXT: %[[B_EXT2:.+]] = comb.concat %c0_i2, %b {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2 + // CHECK-NEXT: %[[C_EXT2:.+]] = comb.concat %c0_i2, %c {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2 + // CHECK-NEXT: %[[ADD:.+]] = comb.add %[[A_EXT2]], %[[B_EXT2]], %[[C_EXT2]] {smax = 7 : i4, smin = -8 : i4, umax = 9 : ui4, umin = 0 : ui4} : i4 + // CHECK-NEXT: %c0_i3 = hw.constant 0 : i3 {smax = 0 : i3, smin = 0 : i3, umax = 0 : ui3, umin = 0 : ui3} + // CHECK-NEXT: %[[A_EXT3:.+]] = comb.concat %c0_i3, %a {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2 + // CHECK-NEXT: %[[B_EXT3:.+]] = comb.concat %c0_i3, %b {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2 + // CHECK-NEXT: %[[C_EXT3:.+]] = comb.concat %c0_i3, %c {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2 + // CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT3]], %[[B_EXT3]], %[[C_EXT3]] {smax = 15 : i5, smin = -16 : i5, umax = 27 : ui5, umin = 0 : ui5} : i5 + // CHECK-NEXT: %[[AND:.+]] = comb.and %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2 + // CHECK-NEXT: %[[OR:.+]] = comb.or %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2 + // CHECK-NEXT: %[[XOR:.+]] = comb.xor %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2 + // CHECK-NEXT: hw.output + %c0_i2 = hw.constant 0 : i2 + %0 = comb.concat %c0_i2, %a : i2, i2 + %1 = comb.concat %c0_i2, %b : i2, i2 + %2 = comb.concat %c0_i2, %c : i2, i2 + %3 = comb.add %0, %1, %2 : i4 + %c0_i3 = hw.constant 0 : i3 + %4 = comb.concat %c0_i3, %a : i3, i2 + %5 = comb.concat %c0_i3, %b : i3, i2 + %6 = comb.concat %c0_i3, %c : i3, i2 + %7 = comb.mul %4, %5, %6 : i5 + %8 = comb.and %a, %b, %c : i2 + %9 = comb.or %a, %b, %c : i2 + %10 = comb.xor %a, %b, %c : i2 + hw.output +} + +// CHECK-LABEL: @replicate_extract +hw.module @replicate_extract(in %a : i3, in %b : i3, in %sel : i1) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %[[EXT_A:.+]] = comb.extract %a from 1 {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i3) -> i2 + // CHECK-NEXT: %[[REPL_A:.+]] = comb.replicate %[[EXT_A]] {smax = 7 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 0 : ui4} : (i2) -> i4 + // CHECK-NEXT: %[[REPL_SEL:.+]] = comb.replicate %sel {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i1) -> i2 + // CHECK-NEXT: %[[EXT_OUT:.+]] = comb.extract %[[REPL_A]] from 1 {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i4) -> i2 + %c0_i2 = hw.constant 0 : i2 + %0 = comb.extract %a from 1 : (i3) -> i2 + %1 = comb.replicate %0 : (i2) -> i4 + %2 = comb.replicate %sel : (i1) -> i2 + %3 = comb.extract %1 from 1 : (i4) -> i2 + hw.output +} + +// CHECK-LABEL: @comp_predicates +hw.module @comp_predicates(in %a : i3, in %b : i3) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %c0_i3 = hw.constant 0 : i3 {smax = 0 : i3, smin = 0 : i3, umax = 0 : ui3, umin = 0 : ui3} + // CHECK-NEXT: %c-1_i3 = hw.constant -1 : i3 {smax = -1 : i3, smin = -1 : i3, umax = 7 : ui3, umin = 7 : ui3} + // CHECK-NEXT: %[[ULT:.+]] = comb.icmp ult %a, %c-1_i3 {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[ULE:.+]] = comb.icmp ule %a, %c-1_i3 {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1} : i3 + // CHECK-NEXT: %[[UGT:.+]] = comb.icmp ugt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[UGE:.+]] = comb.icmp uge %a, %c0_i3 {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1} : i3 + // CHECK-NEXT: %[[SLT:.+]] = comb.icmp slt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[SLE:.+]] = comb.icmp sle %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[SGT:.+]] = comb.icmp sgt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[SGE:.+]] = comb.icmp sge %a, %c0_i3 {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[EQ:.+]] = comb.icmp eq %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[NE:.+]] = comb.icmp ne %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + %c0_i2 = hw.constant 0 : i2 + %c0_i3 = hw.constant 0 : i3 + %c7_i3 = hw.constant 7 : i3 + %0 = comb.icmp ult %a, %c7_i3 : i3 + %1 = comb.icmp ule %a, %c7_i3 : i3 + %2 = comb.icmp ugt %a, %b : i3 + %3 = comb.icmp uge %a, %c0_i3 : i3 + %4 = comb.icmp slt %a, %b : i3 + %5 = comb.icmp sle %a, %b : i3 + %6 = comb.icmp sgt %a, %b : i3 + %7 = comb.icmp sge %a, %c0_i3 : i3 + %8 = comb.icmp eq %a, %b : i3 + %9 = comb.icmp ne %a, %b : i3 + hw.output +} diff --git a/test/Dialect/Comb/comb-int-range-narrowing.mlir b/test/Dialect/Comb/comb-int-range-narrowing.mlir new file mode 100644 index 0000000000..26c68e5f92 --- /dev/null +++ b/test/Dialect/Comb/comb-int-range-narrowing.mlir @@ -0,0 +1,125 @@ +// RUN: circt-opt %s --comb-int-range-narrowing | FileCheck %s + +// CHECK-LABEL: @basic_csa +hw.module @basic_csa(in %a : i1, in %b : i1, in %c : i1, out add_abc : i3) { + // CHECK-NEXT %c0_i2 = hw.constant 0 : i2 + // CHECK-NEXT %false = hw.constant false + // CHECK-NEXT %[[A_EXT:.+]] = comb.concat %false, %a : i1, i1 + // CHECK-NEXT %[[B_EXT:.+]] = comb.concat %false, %b : i1, i1 + // CHECK-NEXT %[[ADD_2:.+]] = comb.add %[[A_EXT]], %[[B_EXT]] : i2 + // CHECK-NEXT %[[ADD_2_EXT:.+]] = comb.concat %false, %[[ADD_2]] : i1, i2 + // CHECK-NEXT %[[C_EXT:.+]] = comb.concat %c0_i2, %c : i2, i1 + // CHECK-NEXT %[[ADD_2_2:.+]] = comb.extract %[[ADD_2_EXT]] from 0 : (i3) -> i2 + // CHECK-NEXT %[[C_2:.+]] = comb.extract %[[C_EXT]] from 0 : (i3) -> i2 + // CHECK-NEXT %[[ADD_3:.+]] = comb.add %[[ADD_2_2]], %[[C_2]] : i2 + // CHECK-NEXT %[[RES:.+]] = comb.concat %false, %[[ADD_3]] : i1, i2 + // CHECK-NEXT hw.output %[[RES]] : i3 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %0 = comb.concat %false, %a : i1, i1 + %1 = comb.concat %false, %b : i1, i1 + %2 = comb.add %0, %1 : i2 + %3 = comb.concat %false, %2 : i1, i2 + %4 = comb.concat %c0_i2, %c : i2, i1 + %5 = comb.add %3, %4 : i3 + hw.output %5 : i3 +} + +// CHECK-LABEL: @basic_fma +hw.module @basic_fma(in %a : i4, in %b : i4, in %c : i4, out d : i9) { + // CHECK-NEXT: %false = hw.constant false + // CHECK-NEXT: %c0_i5 = hw.constant 0 : i5 + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i5, %a : i5, i4 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i5, %b : i5, i4 + // CHECK-NEXT: %[[A:.+]] = comb.extract %[[A_EXT]] from 0 : (i9) -> i8 + // CHECK-NEXT: %[[B:.+]] = comb.extract %[[B_EXT]] from 0 : (i9) -> i8 + // CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A]], %[[B]] : i8 + // CHECK-NEXT: %[[MUL_EXT:.+]] = comb.concat %false, %[[MUL]] : i1, i8 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i5, %c : i5, i4 + // CHECK-NEXT: %[[MUL_T:.+]] = comb.extract %[[MUL_EXT]] from 0 : (i9) -> i8 + // CHECK-NEXT: %[[C_T:.+]] = comb.extract %[[C_EXT]] from 0 : (i9) -> i8 + // CHECK-NEXT: %[[ADD_OUT:.+]] = comb.add %[[MUL_T]], %[[C_T]] : i8 + // CHECK-NEXT: %[[ADD_OUT_EXT:.+]] = comb.concat %false, %[[ADD_OUT]] : i1, i8 + // CHECK-NEXT: hw.output %[[ADD_OUT_EXT]] : i9 + %c0_i5 = hw.constant 0 : i5 + %0 = comb.concat %c0_i5, %a : i5, i4 + %1 = comb.concat %c0_i5, %b : i5, i4 + %2 = comb.mul %0, %1 : i9 + %3 = comb.concat %c0_i5, %c : i5, i4 + %4 = comb.add %2, %3 : i9 + hw.output %4 : i9 +} + +// CHECK-LABEL: @const_sub +hw.module @const_sub(in %a : i8, out sub_res : i10) { + // CHECK-NEXT: %false = hw.constant false + // CHECK-NEXT: %c-256_i9 = hw.constant -256 : i9 + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i2, %a : i2, i8 + // CHECK-NEXT: %[[A_T:.+]] = comb.extract %[[A_EXT]] from 0 : (i10) -> i9 + // CHECK-NEXT: %[[SUB:.+]] = comb.sub %c-256_i9, %[[A_T]] : i9 + // CHECK-NEXT: %[[RES:.+]] = comb.concat %false, %[[SUB]] : i1, i9 + // CHECK-NEXT: hw.output %[[RES]] : i10 + %c256_i10 = hw.constant 256 : i10 + %c0_i2 = hw.constant 0 : i2 + %0 = comb.concat %c0_i2, %a : i2, i8 + %1 = comb.sub %c256_i10, %0 : i10 + hw.output %1 : i10 +} + +// CHECK-LABEL: @do_nothing +hw.module @do_nothing(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 + // CHECK-NEXT: %[[FALSE:.+]] = hw.constant false + // CHECK-NEXT: %c0_i9 = hw.constant 0 : i9 + // CHECK-NEXT: %c0_i8 = hw.constant 0 : i8 + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i9, %a : i9, i8 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i8, %b : i8, i9 + // CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT]], %[[B_EXT]] : i17 + // CHECK-NEXT: %[[MUL_EXT:.+]] = comb.concat %[[FALSE]], %[[MUL]] : i1, i17 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i8, %c : i8, i10 + // CHECK-NEXT: %[[D_EXT:.+]] = comb.concat %c0_i2, %d : i2, i16 + // CHECK-NEXT: %[[RES:.+]] = comb.add %[[MUL_EXT]], %[[C_EXT]], %[[D_EXT]] : i18 + // CHECK-NEXT: hw.output %[[RES]] : i18 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %c0_i9 = hw.constant 0 : i9 + %c0_i8 = hw.constant 0 : i8 + %0 = comb.concat %c0_i9, %a : i9, i8 + %1 = comb.concat %c0_i8, %b : i8, i9 + %2 = comb.mul %0, %1 : i17 + %3 = comb.concat %false, %2 : i1, i17 + %4 = comb.concat %c0_i8, %c : i8, i10 + %5 = comb.concat %c0_i2, %d : i2, i16 + %6 = comb.add %3, %4, %5 : i18 + hw.output %6 : i18 +} + +hw.module @logical_ops(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) { + // CHECK-NEXT %c0_i7 = hw.constant 0 : i7 + // CHECK-NEXT %[[FALSE:.+]] = hw.constant false + // CHECK-NEXT %c0_i9 = hw.constant 0 : i9 + // CHECK-NEXT %c0_i8 = hw.constant 0 : i8 + // CHECK-NEXT %[[A_EXT:.+]] = comb.concat %c0_i9, %a : i9, i8 + // CHECK-NEXT %[[B_EXT:.+]] = comb.concat %c0_i8, %b : i8, i9 + // CHECK-NEXT %[[AND:.+]] = comb.and %[[A_EXT]], %[[B_EXT]] : i17 + // CHECK-NEXT %[[C_EXT:.+]] = comb.concat %c0_i7, %c : i7, i10 + // CHECK-NEXT %[[OR:.+]] = comb.or %[[AND]], %[[C_EXT]] : i17 + // CHECK-NEXT %[[D_EXT:.+]] = comb.concat %[[FALSE]], %d : i1, i16 + // CHECK-NEXT %[[ADD:.+]] = comb.add %[[OR]], %[[D_EXT]] : i17 + // CHECK-NEXT %[[RES:.+]] = comb.concat %[[FALSE]], %[[ADD]] : i1, i17 + // CHECK-NEXT hw.output %[[RES]] : i18 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %c0_i9 = hw.constant 0 : i9 + %c0_i8 = hw.constant 0 : i8 + %0 = comb.concat %c0_i9, %a : i9, i8 + %1 = comb.concat %c0_i8, %b : i8, i9 + %2 = comb.and %0, %1 : i17 + %3 = comb.concat %false, %2 : i1, i17 + %4 = comb.concat %c0_i8, %c : i8, i10 + %5 = comb.or %3, %4 : i18 + %6 = comb.concat %c0_i2, %d : i2, i16 + %7 = comb.add %5, %6 : i18 + hw.output %7 : i18 +}