diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h index 2d0fd0cf4702..d1bc70974078 100644 --- a/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h @@ -47,6 +47,7 @@ class SDBMConstantExpr; class SDBMDialect; class SDBMDimExpr; class SDBMSymbolExpr; +class SDBMTermExpr; /// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side /// expression for the SDBM framework. SDBM expressions are a subset of affine @@ -206,6 +207,13 @@ class SDBMDirectExpr : public SDBMVaryingExpr { public: using SDBMVaryingExpr::SDBMVaryingExpr; + /// If this is a sum expression, return its variable part, otherwise return + /// self. + SDBMTermExpr getTerm(); + + /// If this is a sum expression, return its constant part, otherwise return 0. + int64_t getConstant(); + static bool isClassFor(const SDBMExpr &expr) { return expr.getKind() == SDBMExprKind::DimId || expr.getKind() == SDBMExprKind::SymbolId || diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp index 96b6491776e9..a54d41bdf087 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -354,32 +354,16 @@ static SDBMExpr addConstant(SDBMVaryingExpr lhs, int64_t constant) { // Build a difference expression given a direct expression and a negation // expression. static SDBMExpr buildDiffExpr(SDBMDirectExpr lhs, SDBMNegExpr rhs) { - SDBMTermExpr lhsTerm, rhsTerm; - int lhsConstant = 0; - int64_t rhsConstant = 0; - - if (auto lhsSum = lhs.dyn_cast()) { - lhsConstant = lhsSum.getRHS().getValue(); - lhsTerm = lhsSum.getLHS(); - } else { - lhsTerm = lhs.cast(); - } - - if (auto rhsNegatedSum = rhs.getVar().dyn_cast()) { - rhsTerm = rhsNegatedSum.getLHS(); - rhsConstant = rhsNegatedSum.getRHS().getValue(); - } else { - rhsTerm = rhs.getVar().cast(); - } - // Fold (x + C) - (x + D) = C - D. - if (lhsTerm == rhsTerm) - return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant - rhsConstant); + if (lhs.getTerm() == rhs.getVar().getTerm()) + return SDBMConstantExpr::get( + lhs.getDialect(), lhs.getConstant() - rhs.getVar().getConstant()); return SDBMDiffExpr::get( - addConstantAndSink(lhs, -rhsConstant, /*negated=*/false, + addConstantAndSink(lhs, -rhs.getVar().getConstant(), + /*negated=*/false, [](SDBMDirectExpr e) { return e; }), - rhsTerm); + rhs.getVar().getTerm()); } // Try folding an expression (lhs + rhs) where at least one of the operands @@ -400,18 +384,38 @@ static SDBMExpr foldSumDiff(SDBMExpr lhs, SDBMExpr rhs) { // If a subexpression appears in a diff expression on the LHS(RHS) of a // sum expression where it also appears on the RHS(LHS) with the opposite // sign, we can simplify it away and obtain the SDBM form. - // x - (x - C) = -(x - C) + x = C - // (x - C) - x = -x + (x - C) = -C auto lhsDiff = lhs.dyn_cast(); auto rhsDiff = rhs.dyn_cast(); - if (lhsNeg && rhsDiff && lhsNeg.getVar() == rhsDiff.getLHS()) - return SDBMNegExpr::get(rhsDiff.getRHS()); - if (lhsDirect && rhsDiff && lhsDirect == rhsDiff.getRHS()) - return rhsDiff.getLHS(); - if (lhsDiff && rhsNeg && lhsDiff.getLHS() == rhsNeg.getVar()) - return SDBMNegExpr::get(lhsDiff.getRHS()); - if (rhsDirect && lhsDiff && rhsDirect == lhsDiff.getRHS()) - return lhsDiff.getLHS(); + + // -(x + A) + ((x + B) - y) = -(y + (A - B)) + if (lhsNeg && rhsDiff && + lhsNeg.getVar().getTerm() == rhsDiff.getLHS().getTerm()) { + int64_t constant = + lhsNeg.getVar().getConstant() - rhsDiff.getLHS().getConstant(); + // RHS of the diff is a term expression, its sum with a constant is a direct + // expression. + return SDBMNegExpr::get( + addConstant(rhsDiff.getRHS(), constant).cast()); + } + + // (x + A) + ((y + B) - x) = (y + B) + A. + if (lhsDirect && rhsDiff && lhsDirect.getTerm() == rhsDiff.getRHS()) + return addConstant(rhsDiff.getLHS(), lhsDirect.getConstant()); + + // ((x + A) - y) + (-(x + B)) = -(y + (B - A)). + if (lhsDiff && rhsNeg && + lhsDiff.getLHS().getTerm() == rhsNeg.getVar().getTerm()) { + int64_t constant = + rhsNeg.getVar().getConstant() - lhsDiff.getLHS().getConstant(); + // RHS of the diff is a term expression, its sum with a constant is a direct + // expression. + return SDBMNegExpr::get( + addConstant(lhsDiff.getRHS(), constant).cast()); + } + + // ((x + A) - y) + (y + B) = (x + A) + B. + if (rhsDirect && lhsDiff && rhsDirect.getTerm() == lhsDiff.getRHS()) + return addConstant(lhsDiff.getLHS(), rhsDirect.getConstant()); return {}; } @@ -544,6 +548,22 @@ SDBMTermExpr SDBMDiffExpr::getRHS() const { return static_cast(impl)->rhs; } +//===----------------------------------------------------------------------===// +// SDBMDirectExpr +//===----------------------------------------------------------------------===// + +SDBMTermExpr SDBMDirectExpr::getTerm() { + if (auto sum = dyn_cast()) + return sum.getLHS(); + return cast(); +} + +int64_t SDBMDirectExpr::getConstant() { + if (auto sum = dyn_cast()) + return sum.getRHS().getValue(); + return 0; +} + //===----------------------------------------------------------------------===// // SDBMStripeExpr //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp index 13941cdffd3d..9042981ab851 100644 --- a/mlir/unittests/SDBM/SDBMTest.cpp +++ b/mlir/unittests/SDBM/SDBMTest.cpp @@ -100,6 +100,20 @@ TEST(SDBMOperators, AddFolding) { EXPECT_EQ(diffOfDiffs.getValue(), 0); } +TEST(SDBMOperators, AddNegativeTerms) { + const int64_t A = 7; + const int64_t B = -5; + auto x = SDBMDimExpr::get(dialect(), 0); + auto y = SDBMDimExpr::get(dialect(), 1); + + // Check the simplification patterns in addition where one of the variables is + // cancelled out and the result remains an SDBM. + EXPECT_EQ(-(x + A) + ((x + B) - y), -(y + (A - B))); + EXPECT_EQ((x + A) + ((y + B) - x), (y + B) + A); + EXPECT_EQ(((x + A) - y) + (-(x + B)), -(y + (B - A))); + EXPECT_EQ(((x + A) - y) + (y + B), (x + A) + B); +} + TEST(SDBMOperators, Diff) { auto expr = dim(0) - dim(1); auto diffExpr = expr.dyn_cast();