diff --git a/mlir/include/mlir/Dialect/SDBM/SDBM.h b/mlir/include/mlir/Dialect/SDBM/SDBM.h index 3115805bb5fa..97078465ff12 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBM.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBM.h @@ -31,7 +31,7 @@ namespace mlir { class MLIRContext; class SDBMDialect; class SDBMExpr; -class SDBMPositiveExpr; +class SDBMTermExpr; /// A utility class for SDBM to represent an integer with potentially infinite /// positive value. This uses the largest value of int64_t to represent infinity @@ -130,14 +130,14 @@ private: /// and at(col,row) of the DBM. Depending on the values being finite and /// being subsumed by stripe expressions, this may or may not add elements to /// the lists of equalities and inequalities. - void convertDBMElement(unsigned row, unsigned col, SDBMPositiveExpr rowExpr, - SDBMPositiveExpr colExpr, + void convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr, + SDBMTermExpr colExpr, SmallVectorImpl &inequalities, SmallVectorImpl &equalities); /// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only /// adds new inequalities if the inequality is not trivially true. - void convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr, + void convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr, SmallVectorImpl &inequalities); /// Get the total number of elements in the matrix. diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h index 1e695b68f975..fdb914d54d6f 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h @@ -38,7 +38,7 @@ namespace detail { struct SDBMExprStorage; struct SDBMBinaryExprStorage; struct SDBMDiffExprStorage; -struct SDBMPositiveExprStorage; +struct SDBMTermExprStorage; struct SDBMConstantExprStorage; struct SDBMNegExprStorage; } // namespace detail @@ -176,10 +176,12 @@ public: } }; -/// SDBM positive variable expression can be one of: +/// SDBM term expression can be one of: /// - single variable expression; /// - stripe expression. -class SDBMPositiveExpr : public SDBMVaryingExpr { +/// Stripe expressions are treated as terms since, in the SDBM domain, they are +/// attached to temporary variables and can appear anywhere a variable can. +class SDBMTermExpr : public SDBMVaryingExpr { public: using SDBMVaryingExpr::SDBMVaryingExpr; @@ -209,40 +211,38 @@ public: SDBMConstantExpr getRHS() const; }; -/// SDBM difference expression. Both LHS and RHS are positive variable -/// expressions. +/// SDBM difference expression. Both LHS and RHS are SDBM term expressions. class SDBMDiffExpr : public SDBMVaryingExpr { public: using ImplType = detail::SDBMDiffExprStorage; using SDBMVaryingExpr::SDBMVaryingExpr; /// Obtain or create a difference expression unique'ed in the given context. - static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs); + static SDBMDiffExpr get(SDBMTermExpr lhs, SDBMTermExpr rhs); static bool isClassFor(const SDBMExpr &expr) { return expr.getKind() == SDBMExprKind::Diff; } - SDBMPositiveExpr getLHS() const; - SDBMPositiveExpr getRHS() const; + SDBMTermExpr getLHS() const; + SDBMTermExpr getRHS() const; }; -/// SDBM stripe expression "x # C" where "x" is a positive variable expression, -/// "C" is a constant expression and "#" is the stripe operator defined as: +/// SDBM stripe expression "x # C" where "x" is a term expression, "C" is a +/// constant expression and "#" is the stripe operator defined as: /// x # C = x - x mod C. -class SDBMStripeExpr : public SDBMPositiveExpr { +class SDBMStripeExpr : public SDBMTermExpr { public: using ImplType = detail::SDBMBinaryExprStorage; - using SDBMPositiveExpr::SDBMPositiveExpr; + using SDBMTermExpr::SDBMTermExpr; static bool isClassFor(const SDBMExpr &expr) { return expr.getKind() == SDBMExprKind::Stripe; } - static SDBMStripeExpr get(SDBMPositiveExpr var, - SDBMConstantExpr stripeFactor); + static SDBMStripeExpr get(SDBMTermExpr var, SDBMConstantExpr stripeFactor); - SDBMPositiveExpr getVar() const; + SDBMTermExpr getVar() const; SDBMConstantExpr getStripeFactor() const; }; @@ -250,10 +250,10 @@ public: /// a symbol identifier. When used to define SDBM functions, dimensions are /// interpreted as function arguments while symbols are treated as unknown but /// constant values, hence the name. -class SDBMInputExpr : public SDBMPositiveExpr { +class SDBMInputExpr : public SDBMTermExpr { public: - using ImplType = detail::SDBMPositiveExprStorage; - using SDBMPositiveExpr::SDBMPositiveExpr; + using ImplType = detail::SDBMTermExprStorage; + using SDBMTermExpr::SDBMTermExpr; static bool isClassFor(const SDBMExpr &expr) { return expr.getKind() == SDBMExprKind::DimId || @@ -267,7 +267,7 @@ public: /// when defining functions using SDBM expressions. class SDBMDimExpr : public SDBMInputExpr { public: - using ImplType = detail::SDBMPositiveExprStorage; + using ImplType = detail::SDBMTermExprStorage; using SDBMInputExpr::SDBMInputExpr; /// Obtain or create a dimension expression unique'ed in the given dialect @@ -283,7 +283,7 @@ public: /// defining functions using SDBM expressions. class SDBMSymbolExpr : public SDBMInputExpr { public: - using ImplType = detail::SDBMPositiveExprStorage; + using ImplType = detail::SDBMTermExprStorage; using SDBMInputExpr::SDBMInputExpr; /// Obtain or create a symbol expression unique'ed in the given dialect (which @@ -303,13 +303,13 @@ public: using SDBMVaryingExpr::SDBMVaryingExpr; /// Obtain or create a negation expression unique'ed in the given context. - static SDBMNegExpr get(SDBMPositiveExpr var); + static SDBMNegExpr get(SDBMTermExpr var); static bool isClassFor(const SDBMExpr &expr) { return expr.getKind() == SDBMExprKind::Neg; } - SDBMPositiveExpr getVar() const; + SDBMTermExpr getVar() const; }; /// A visitor class for SDBM expressions. Calls the kind-specific function @@ -352,10 +352,10 @@ protected: void visitNeg(SDBMNegExpr) {} void visitConstant(SDBMConstantExpr) {} - /// Default implementation of visitPositive dispatches to the special + /// Default implementation of visitTerm dispatches to the special /// functions for stripes and other variables. Concrete visitors can override /// it. - Result visitPositive(SDBMPositiveExpr expr) { + Result visitTerm(SDBMTermExpr expr) { auto *derived = static_cast(this); if (expr.getKind() == SDBMExprKind::Stripe) return derived->visitStripe(expr.cast()); @@ -379,8 +379,8 @@ protected: /// override it to visit all variables and negations instead. Result visitVarying(SDBMVaryingExpr expr) { auto *derived = static_cast(this); - if (auto var = expr.dyn_cast()) - return derived->visitPositive(var); + if (auto var = expr.dyn_cast()) + return derived->visitTerm(var); else if (auto neg = expr.dyn_cast()) return derived->visitNeg(neg); else if (auto sum = expr.dyn_cast()) @@ -486,22 +486,20 @@ template <> struct DenseMapInfo { } }; -// SDBMPositiveExpr hash just like pointers. -template <> struct DenseMapInfo { - static mlir::SDBMPositiveExpr getEmptyKey() { +// SDBMTermExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMTermExpr getEmptyKey() { auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::SDBMPositiveExpr( - static_cast(pointer)); + return mlir::SDBMTermExpr(static_cast(pointer)); } - static mlir::SDBMPositiveExpr getTombstoneKey() { + static mlir::SDBMTermExpr getTombstoneKey() { auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::SDBMPositiveExpr( - static_cast(pointer)); + return mlir::SDBMTermExpr(static_cast(pointer)); } - static unsigned getHashValue(mlir::SDBMPositiveExpr expr) { + static unsigned getHashValue(mlir::SDBMTermExpr expr) { return expr.hash_value(); } - static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) { + static bool isEqual(mlir::SDBMTermExpr lhs, mlir::SDBMTermExpr rhs) { return lhs == rhs; } }; diff --git a/mlir/lib/Dialect/SDBM/SDBM.cpp b/mlir/lib/Dialect/SDBM/SDBM.cpp index 5450a61b17b9..b3e648300e50 100644 --- a/mlir/lib/Dialect/SDBM/SDBM.cpp +++ b/mlir/lib/Dialect/SDBM/SDBM.cpp @@ -354,8 +354,8 @@ SDBM SDBM::get(ArrayRef inequalities, ArrayRef equalities) { // If one of the expressions is derived from another using a stripe operation, // check if the inequalities induced by the stripe operation subsume the // inequalities defined in the DBM and if so, elide these inequalities. -void SDBM::convertDBMElement(unsigned row, unsigned col, - SDBMPositiveExpr rowExpr, SDBMPositiveExpr colExpr, +void SDBM::convertDBMElement(unsigned row, unsigned col, SDBMTermExpr rowExpr, + SDBMTermExpr colExpr, SmallVectorImpl &inequalities, SmallVectorImpl &equalities) { using ops_assertions::operator+; @@ -388,13 +388,13 @@ void SDBM::convertDBMElement(unsigned row, unsigned col, SDBMExpr x1Expr, int64_t value) { if (stripeToPoint.count(x0)) { auto stripe = stripeToPoint[x0].cast(); - SDBMPositiveExpr var = stripe.getVar(); + SDBMTermExpr var = stripe.getVar(); if (x1Expr == var && value >= 0) return true; } if (stripeToPoint.count(x1)) { auto stripe = stripeToPoint[x1].cast(); - SDBMPositiveExpr var = stripe.getVar(); + SDBMTermExpr var = stripe.getVar(); if (x0Expr == var && value >= stripe.getStripeFactor().getValue() - 1) return true; } @@ -418,7 +418,7 @@ void SDBM::convertDBMElement(unsigned row, unsigned col, // to -C <= 0. Only construct the inequalities when C is negative, which // are trivially false but necessary for the returned system of inequalities // to indicate that the set it defines is empty. -void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr, +void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMTermExpr expr, SmallVectorImpl &inequalities) { auto selfDifference = at(pos, pos); if (selfDifference.isFinite() && selfDifference < 0) { diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp index a174c8c84f23..f1c02a36312c 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -166,20 +166,20 @@ void SDBMExpr::print(raw_ostream &os) const { visitConstant(expr.getRHS()); } void visitDiff(SDBMDiffExpr expr) { - visitPositive(expr.getLHS()); + visitTerm(expr.getLHS()); prn << " - "; - visitPositive(expr.getRHS()); + visitTerm(expr.getRHS()); } void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); } void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); } void visitStripe(SDBMStripeExpr expr) { - visitPositive(expr.getVar()); + visitTerm(expr.getVar()); prn << " # "; visitConstant(expr.getStripeFactor()); } void visitNeg(SDBMNegExpr expr) { prn << '-'; - visitPositive(expr.getVar()); + visitTerm(expr.getVar()); } void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); } @@ -197,11 +197,9 @@ void SDBMExpr::dump() const { namespace { // Helper class to perform negation of an SDBM expression. struct SDBMNegator : public SDBMVisitor { - // Any positive expression is wrapped into a negation expression. + // Any term expression is wrapped into a negation expression. // -(x) = -x - SDBMExpr visitPositive(SDBMPositiveExpr expr) { - return SDBMNegExpr::get(expr); - } + SDBMExpr visitTerm(SDBMTermExpr expr) { return SDBMNegExpr::get(expr); } // A negation expression is unwrapped. // -(-x) = x SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); } @@ -305,7 +303,7 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { if (auto convertedLHS = visit(x.matched())) { // TODO(ntv): return convertedLHS.stripe(C); return SDBMStripeExpr::get( - convertedLHS.cast(), + convertedLHS.cast(), visit(C.matched()).cast()); } } @@ -328,8 +326,8 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { // difference, supported as a special kind in SDBM. Because AffineExprs // don't have first-class difference kind, check both LHS and RHS for // negation. - auto lhsPos = lhs.dyn_cast(); - auto rhsPos = rhs.dyn_cast(); + auto lhsPos = lhs.dyn_cast(); + auto rhsPos = rhs.dyn_cast(); auto lhsNeg = lhs.dyn_cast(); auto rhsNeg = rhs.dyn_cast(); if (lhsNeg && rhsVar) @@ -347,7 +345,7 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { AffineExprMatcher pattern = (x.floorDiv(C)) * C; if (pattern.match(expr)) { if (SDBMExpr converted = visit(x.matched())) { - if (auto varConverted = converted.dyn_cast()) + if (auto varConverted = converted.dyn_cast()) // TODO(ntv): return varConverted.stripe(C.getConstantValue()); return SDBMStripeExpr::get( varConverted, @@ -369,7 +367,7 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { // The only supported "multiplication" expression is an SDBM is dimension // negation, that is a product of dimension and constant -1. - auto lhsVar = lhs.dyn_cast(); + auto lhsVar = lhs.dyn_cast(); if (lhsVar && rhsConstant.getValue() == -1) return SDBMNegExpr::get(lhsVar); @@ -385,7 +383,7 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { // 'mod' can only be converted to SDBM if its LHS is a variable // and its RHS is a constant. Then it `x mod c = x - x stripe c`. auto rhsConstant = rhs.dyn_cast(); - auto lhsVar = rhs.dyn_cast(); + auto lhsVar = rhs.dyn_cast(); if (!lhsVar || !rhsConstant) return {}; return SDBMDiffExpr::get(lhsVar, @@ -420,7 +418,7 @@ Optional SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { // SDBMDiffExpr //===----------------------------------------------------------------------===// -SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) { +SDBMDiffExpr SDBMDiffExpr::get(SDBMTermExpr lhs, SDBMTermExpr rhs) { assert(lhs && "expected SDBM dimension"); assert(rhs && "expected SDBM dimension"); @@ -429,11 +427,11 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) { /*initFn=*/{}, static_cast(SDBMExprKind::Diff), lhs, rhs); } -SDBMPositiveExpr SDBMDiffExpr::getLHS() const { +SDBMTermExpr SDBMDiffExpr::getLHS() const { return static_cast(impl)->lhs; } -SDBMPositiveExpr SDBMDiffExpr::getRHS() const { +SDBMTermExpr SDBMDiffExpr::getRHS() const { return static_cast(impl)->rhs; } @@ -441,7 +439,7 @@ SDBMPositiveExpr SDBMDiffExpr::getRHS() const { // SDBMStripeExpr //===----------------------------------------------------------------------===// -SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var, +SDBMStripeExpr SDBMStripeExpr::get(SDBMTermExpr var, SDBMConstantExpr stripeFactor) { assert(var && "expected SDBM variable expression"); assert(stripeFactor && "expected non-null stripe factor"); @@ -454,9 +452,9 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var, stripeFactor); } -SDBMPositiveExpr SDBMStripeExpr::getVar() const { +SDBMTermExpr SDBMStripeExpr::getVar() const { if (SDBMVaryingExpr lhs = static_cast(impl)->lhs) - return lhs.cast(); + return lhs.cast(); return {}; } @@ -479,12 +477,12 @@ unsigned SDBMInputExpr::getPosition() const { SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) { assert(dialect && "expected non-null dialect"); - auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) { + auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) { storage->dialect = dialect; }; StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get( + return uniquer.get( assignDialect, static_cast(SDBMExprKind::DimId), position); } @@ -495,12 +493,12 @@ SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) { SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) { assert(dialect && "expected non-null dialect"); - auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) { + auto assignDialect = [dialect](detail::SDBMTermExprStorage *storage) { storage->dialect = dialect; }; StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get( + return uniquer.get( assignDialect, static_cast(SDBMExprKind::SymbolId), position); } @@ -528,7 +526,7 @@ int64_t SDBMConstantExpr::getValue() const { // SDBMNegExpr //===----------------------------------------------------------------------===// -SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) { +SDBMNegExpr SDBMNegExpr::get(SDBMTermExpr var) { assert(var && "expected non-null SDBM variable expression"); StorageUniquer &uniquer = var.getDialect()->getUniquer(); @@ -536,7 +534,7 @@ SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) { /*initFn=*/{}, static_cast(SDBMExprKind::Neg), var); } -SDBMPositiveExpr SDBMNegExpr::getVar() const { +SDBMTermExpr SDBMNegExpr::getVar() const { return static_cast(impl)->dim; } @@ -627,8 +625,7 @@ SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { } // This calls into operator+ for futher simplification in case value == 0. - return SDBMDiffExpr::get(lhs.cast(), - rhs.cast()) + + return SDBMDiffExpr::get(lhs.cast(), rhs.cast()) + value; } @@ -640,7 +637,7 @@ SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { if (constantFactor.getValue() == 1) return expr; - return SDBMStripeExpr::get(expr.cast(), constantFactor); + return SDBMStripeExpr::get(expr.cast(), constantFactor); } } // namespace ops_assertions diff --git a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h index 1721b02dae7d..b202ab5efb4f 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h +++ b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h @@ -64,7 +64,7 @@ struct SDBMBinaryExprStorage : public SDBMExprStorage { // Storage class for SDBM difference expressions. struct SDBMDiffExprStorage : public SDBMExprStorage { - using KeyTy = std::pair; + using KeyTy = std::pair; bool operator==(const KeyTy &key) const { return std::get<0>(key) == lhs && std::get<1>(key) == rhs; @@ -79,8 +79,8 @@ struct SDBMDiffExprStorage : public SDBMExprStorage { return result; } - SDBMPositiveExpr lhs; - SDBMPositiveExpr rhs; + SDBMTermExpr lhs; + SDBMTermExpr rhs; }; // Storage class for SDBM constant expressions. @@ -100,14 +100,14 @@ struct SDBMConstantExprStorage : public SDBMExprStorage { }; // Storage class for SDBM dimension and symbol expressions. -struct SDBMPositiveExprStorage : public SDBMExprStorage { +struct SDBMTermExprStorage : public SDBMExprStorage { using KeyTy = unsigned; bool operator==(const KeyTy &key) const { return position == key; } - static SDBMPositiveExprStorage * + static SDBMTermExprStorage * construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { - auto *result = allocator.allocate(); + auto *result = allocator.allocate(); result->position = key; return result; } @@ -117,7 +117,7 @@ struct SDBMPositiveExprStorage : public SDBMExprStorage { // Storage class for SDBM negation expressions. struct SDBMNegExprStorage : public SDBMExprStorage { - using KeyTy = SDBMPositiveExpr; + using KeyTy = SDBMTermExpr; bool operator==(const KeyTy &key) const { return key == dim; } @@ -129,7 +129,7 @@ struct SDBMNegExprStorage : public SDBMExprStorage { return result; } - SDBMPositiveExpr dim; + SDBMTermExpr dim; }; } // end namespace detail diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp index f6c29e101d53..b6f8080e1056 100644 --- a/mlir/unittests/SDBM/SDBMTest.cpp +++ b/mlir/unittests/SDBM/SDBMTest.cpp @@ -173,7 +173,7 @@ TEST(SDBMExpr, Dim) { auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); // Dimensions are not Symbols. @@ -195,7 +195,7 @@ TEST(SDBMExpr, Symbol) { auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); // Dimensions are not Symbols. @@ -228,7 +228,7 @@ TEST(SDBMExpr, Stripe) { // Hierarchy is okay. auto generic = static_cast(expr); EXPECT_TRUE(generic.isa()); - EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); EXPECT_TRUE(generic.isa()); }