[MLIR][Presburger] subtract: support non-div locals
Also added test cases. Also extend support for `computeReprWithOnlyDivLocals` from `IntegerPolyhedron` to `IntegerRelation` and `PresburgerRelation`. Depends on D128736. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D128737
This commit is contained in:
		
							parent
							
								
									2695e23ad9
								
							
						
					
					
						commit
						dda8b1ceda
					
				| 
						 | 
					@ -27,6 +27,7 @@ namespace presburger {
 | 
				
			||||||
class IntegerRelation;
 | 
					class IntegerRelation;
 | 
				
			||||||
class IntegerPolyhedron;
 | 
					class IntegerPolyhedron;
 | 
				
			||||||
class PresburgerSet;
 | 
					class PresburgerSet;
 | 
				
			||||||
 | 
					class PresburgerRelation;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// An IntegerRelation represents the set of points from a PresburgerSpace that
 | 
					/// An IntegerRelation represents the set of points from a PresburgerSpace that
 | 
				
			||||||
/// satisfy a list of affine constraints. Affine constraints can be inequalities
 | 
					/// satisfy a list of affine constraints. Affine constraints can be inequalities
 | 
				
			||||||
| 
						 | 
					@ -575,6 +576,12 @@ public:
 | 
				
			||||||
  /// this for uniformity with `applyDomain`.
 | 
					  /// this for uniformity with `applyDomain`.
 | 
				
			||||||
  void applyRange(const IntegerRelation &rel);
 | 
					  void applyRange(const IntegerRelation &rel);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Compute an equivalent representation of the same set, such that all local
 | 
				
			||||||
 | 
					  /// vars in all disjuncts have division representations. This representation
 | 
				
			||||||
 | 
					  /// may involve local vars that correspond to divisions, and may also be a
 | 
				
			||||||
 | 
					  /// union of convex disjuncts.
 | 
				
			||||||
 | 
					  PresburgerRelation computeReprWithOnlyDivLocals() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void print(raw_ostream &os) const;
 | 
					  void print(raw_ostream &os) const;
 | 
				
			||||||
  void dump() const;
 | 
					  void dump() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -760,12 +767,6 @@ public:
 | 
				
			||||||
  /// first added variable.
 | 
					  /// first added variable.
 | 
				
			||||||
  unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
 | 
					  unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Compute an equivalent representation of the same set, such that all local
 | 
					 | 
				
			||||||
  /// ids have division representations. This representation may involve
 | 
					 | 
				
			||||||
  /// local ids that correspond to divisions, and may also be a union of convex
 | 
					 | 
				
			||||||
  /// disjuncts.
 | 
					 | 
				
			||||||
  PresburgerSet computeReprWithOnlyDivLocals() const;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /// Compute the symbolic integer lexmin of the polyhedron.
 | 
					  /// Compute the symbolic integer lexmin of the polyhedron.
 | 
				
			||||||
  /// This finds, for every assignment to the symbols, the lexicographically
 | 
					  /// This finds, for every assignment to the symbols, the lexicographically
 | 
				
			||||||
  /// minimum value attained by the dimensions. For example, the symbolic lexmin
 | 
					  /// minimum value attained by the dimensions. For example, the symbolic lexmin
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -128,6 +128,12 @@ public:
 | 
				
			||||||
  /// Check whether all local ids in all disjuncts have a div representation.
 | 
					  /// Check whether all local ids in all disjuncts have a div representation.
 | 
				
			||||||
  bool hasOnlyDivLocals() const;
 | 
					  bool hasOnlyDivLocals() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Compute an equivalent representation of the same relation, such that all
 | 
				
			||||||
 | 
					  /// local ids in all disjuncts have division representations. This
 | 
				
			||||||
 | 
					  /// representation may involve local ids that correspond to divisions, and may
 | 
				
			||||||
 | 
					  /// also be a union of convex disjuncts.
 | 
				
			||||||
 | 
					  PresburgerRelation computeReprWithOnlyDivLocals() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Print the set's internal state.
 | 
					  /// Print the set's internal state.
 | 
				
			||||||
  void print(raw_ostream &os) const;
 | 
					  void print(raw_ostream &os) const;
 | 
				
			||||||
  void dump() const;
 | 
					  void dump() const;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -570,7 +570,7 @@ public:
 | 
				
			||||||
  /// `symbolDomain` is the set of values of the symbols for which the lexmin
 | 
					  /// `symbolDomain` is the set of values of the symbols for which the lexmin
 | 
				
			||||||
  /// will be computed. `symbolDomain` should have a dim var for every symbol in
 | 
					  /// will be computed. `symbolDomain` should have a dim var for every symbol in
 | 
				
			||||||
  /// `constraints`, and no other vars.
 | 
					  /// `constraints`, and no other vars.
 | 
				
			||||||
  SymbolicLexSimplex(const IntegerPolyhedron &constraints,
 | 
					  SymbolicLexSimplex(const IntegerRelation &constraints,
 | 
				
			||||||
                     const IntegerPolyhedron &symbolDomain)
 | 
					                     const IntegerPolyhedron &symbolDomain)
 | 
				
			||||||
      : SymbolicLexSimplex(constraints,
 | 
					      : SymbolicLexSimplex(constraints,
 | 
				
			||||||
                           constraints.getVarKindOffset(VarKind::Symbol),
 | 
					                           constraints.getVarKindOffset(VarKind::Symbol),
 | 
				
			||||||
| 
						 | 
					@ -582,8 +582,7 @@ public:
 | 
				
			||||||
  /// The symbol ids are the range of ids with absolute index
 | 
					  /// The symbol ids are the range of ids with absolute index
 | 
				
			||||||
  /// [symbolOffset, symbolOffset + symbolDomain.getNumVars())
 | 
					  /// [symbolOffset, symbolOffset + symbolDomain.getNumVars())
 | 
				
			||||||
  /// symbolDomain should only have dim ids.
 | 
					  /// symbolDomain should only have dim ids.
 | 
				
			||||||
  SymbolicLexSimplex(const IntegerPolyhedron &constraints,
 | 
					  SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset,
 | 
				
			||||||
                     unsigned symbolOffset,
 | 
					 | 
				
			||||||
                     const IntegerPolyhedron &symbolDomain)
 | 
					                     const IntegerPolyhedron &symbolDomain)
 | 
				
			||||||
      : LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset,
 | 
					      : LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset,
 | 
				
			||||||
                       symbolDomain.getNumVars()),
 | 
					                       symbolDomain.getNumVars()),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -165,16 +165,16 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
 | 
				
			||||||
  removeEqualityRange(counts.getNumEqs(), getNumEqualities());
 | 
					  removeEqualityRange(counts.getNumEqs(), getNumEqualities());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const {
 | 
					PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
 | 
				
			||||||
  // If there are no locals, we're done.
 | 
					  // If there are no locals, we're done.
 | 
				
			||||||
  if (getNumLocalVars() == 0)
 | 
					  if (getNumLocalVars() == 0)
 | 
				
			||||||
    return PresburgerSet(*this);
 | 
					    return PresburgerRelation(*this);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Move all the non-div locals to the end, as the current API to
 | 
					  // Move all the non-div locals to the end, as the current API to
 | 
				
			||||||
  // SymbolicLexMin requires these to form a contiguous range.
 | 
					  // SymbolicLexMin requires these to form a contiguous range.
 | 
				
			||||||
  //
 | 
					  //
 | 
				
			||||||
  // Take a copy so we can perform mutations.
 | 
					  // Take a copy so we can perform mutations.
 | 
				
			||||||
  IntegerPolyhedron copy = *this;
 | 
					  IntegerRelation copy = *this;
 | 
				
			||||||
  std::vector<MaybeLocalRepr> reprs;
 | 
					  std::vector<MaybeLocalRepr> reprs;
 | 
				
			||||||
  copy.getLocalReprs(reprs);
 | 
					  copy.getLocalReprs(reprs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -197,7 +197,7 @@ PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // If there are no non-div locals, we're done.
 | 
					  // If there are no non-div locals, we're done.
 | 
				
			||||||
  if (numNonDivLocals == 0)
 | 
					  if (numNonDivLocals == 0)
 | 
				
			||||||
    return PresburgerSet(*this);
 | 
					    return PresburgerRelation(*this);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // We computeSymbolicIntegerLexMin by considering the non-div locals as
 | 
					  // We computeSymbolicIntegerLexMin by considering the non-div locals as
 | 
				
			||||||
  // "non-symbols" and considering everything else as "symbols". This will
 | 
					  // "non-symbols" and considering everything else as "symbols". This will
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -136,6 +136,17 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
 | 
				
			||||||
  return getNegatedCoeffs(eqCoeffs);
 | 
					  return getNegatedCoeffs(eqCoeffs);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					PresburgerRelation PresburgerRelation::computeReprWithOnlyDivLocals() const {
 | 
				
			||||||
 | 
					  if (hasOnlyDivLocals())
 | 
				
			||||||
 | 
					    return *this;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // The result is just the union of the reprs of the disjuncts.
 | 
				
			||||||
 | 
					  PresburgerRelation result(getSpace());
 | 
				
			||||||
 | 
					  for (const IntegerRelation &disjunct : disjuncts)
 | 
				
			||||||
 | 
					    result.unionInPlace(disjunct.computeReprWithOnlyDivLocals());
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// Return the set difference b \ s.
 | 
					/// Return the set difference b \ s.
 | 
				
			||||||
///
 | 
					///
 | 
				
			||||||
/// In the following, U denotes union, /\ denotes intersection, \ denotes set
 | 
					/// In the following, U denotes union, /\ denotes intersection, \ denotes set
 | 
				
			||||||
| 
						 | 
					@ -174,6 +185,9 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
 | 
				
			||||||
  if (b.isEmptyByGCDTest())
 | 
					  if (b.isEmptyByGCDTest())
 | 
				
			||||||
    return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
 | 
					    return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!s.hasOnlyDivLocals())
 | 
				
			||||||
 | 
					    return getSetDifference(b, s.computeReprWithOnlyDivLocals());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Remove duplicate divs up front here to avoid existing
 | 
					  // Remove duplicate divs up front here to avoid existing
 | 
				
			||||||
  // divs disappearing in the call to mergeLocalVars below.
 | 
					  // divs disappearing in the call to mergeLocalVars below.
 | 
				
			||||||
  b.removeDuplicateDivs();
 | 
					  b.removeDuplicateDivs();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -431,6 +431,10 @@ void expectEqual(const PresburgerSet &s, const PresburgerSet &t) {
 | 
				
			||||||
  EXPECT_TRUE(s.isEqual(t));
 | 
					  EXPECT_TRUE(s.isEqual(t));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void expectEqual(const IntegerPolyhedron &s, const IntegerPolyhedron &t) {
 | 
				
			||||||
 | 
					  EXPECT_TRUE(s.isEqual(t));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void expectEmpty(const PresburgerSet &s) { EXPECT_TRUE(s.isIntegerEmpty()); }
 | 
					void expectEmpty(const PresburgerSet &s) { EXPECT_TRUE(s.isIntegerEmpty()); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(SetTest, divisions) {
 | 
					TEST(SetTest, divisions) {
 | 
				
			||||||
| 
						 | 
					@ -505,6 +509,45 @@ TEST(SetTest, divisionsDefByEq) {
 | 
				
			||||||
  expectEqual(evens, PresburgerSet(evensDefByIneq));
 | 
					  expectEqual(evens, PresburgerSet(evensDefByIneq));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(SetTest, divisionNonDivLocals) {
 | 
				
			||||||
 | 
					  // This is a tetrahedron with vertices at
 | 
				
			||||||
 | 
					  // (1/3, 0, 0), (2/3, 0, 0), (2/3, 0, 1000), and (1000, 1000, 1000).
 | 
				
			||||||
 | 
					  //
 | 
				
			||||||
 | 
					  // The only integer point in this is at (1000, 1000, 1000).
 | 
				
			||||||
 | 
					  // We project this to the xy plane.
 | 
				
			||||||
 | 
					  IntegerPolyhedron tetrahedron =
 | 
				
			||||||
 | 
					      parsePolyAndMakeLocals("(x, y, z) : (y >= 0, z - y >= 0, 3000*x - 2998*y "
 | 
				
			||||||
 | 
					                             "- 1000 - z >= 0, -1500*x + 1499*y + 1000 >= 0)",
 | 
				
			||||||
 | 
					                             /*numLocals=*/1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // This is a triangle with vertices at (1/3, 0), (2/3, 0) and (1000, 1000).
 | 
				
			||||||
 | 
					  // The only integer point in this is at (1000, 1000).
 | 
				
			||||||
 | 
					  //
 | 
				
			||||||
 | 
					  // It also happens to be the projection of the above onto the xy plane.
 | 
				
			||||||
 | 
					  IntegerPolyhedron triangle = parsePoly("(x,y) : (y >= 0, "
 | 
				
			||||||
 | 
					                                         "3000 * x - 2999 * y - 1000 >= 0, "
 | 
				
			||||||
 | 
					                                         "-3000 * x + 2998 * y + 2000 >= 0)");
 | 
				
			||||||
 | 
					  EXPECT_TRUE(triangle.containsPoint({1000, 1000}));
 | 
				
			||||||
 | 
					  EXPECT_FALSE(triangle.containsPoint({1001, 1001}));
 | 
				
			||||||
 | 
					  expectEqual(triangle, tetrahedron);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  convertSuffixDimsToLocals(triangle, 1);
 | 
				
			||||||
 | 
					  IntegerPolyhedron line = parsePoly("(x) : (x - 1000 == 0)");
 | 
				
			||||||
 | 
					  expectEqual(line, triangle);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Triangle with vertices (0, 0), (5, 0), (15, 5).
 | 
				
			||||||
 | 
					  // Projected on x, it becomes [0, 13] U {15} as it becomes too narrow towards
 | 
				
			||||||
 | 
					  // the apex and so does not have have any integer point at x = 14.
 | 
				
			||||||
 | 
					  // At x = 15, the apex is an integer point.
 | 
				
			||||||
 | 
					  PresburgerSet triangle2{parsePolyAndMakeLocals("(x,y) : (y >= 0, "
 | 
				
			||||||
 | 
					                                                 "x - 3*y >= 0, "
 | 
				
			||||||
 | 
					                                                 "2*y - x + 5 >= 0)",
 | 
				
			||||||
 | 
					                                                 /*numLocals=*/1)};
 | 
				
			||||||
 | 
					  PresburgerSet zeroToThirteen{parsePoly("(x) : (13 - x >= 0, x >= 0)")};
 | 
				
			||||||
 | 
					  PresburgerSet fifteen{parsePoly("(x) : (x - 15 == 0)")};
 | 
				
			||||||
 | 
					  expectEqual(triangle2.subtract(zeroToThirteen), fifteen);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(SetTest, subtractDuplicateDivsRegression) {
 | 
					TEST(SetTest, subtractDuplicateDivsRegression) {
 | 
				
			||||||
  // Previously, subtracting sets with duplicate divs might result in crashes
 | 
					  // Previously, subtracting sets with duplicate divs might result in crashes
 | 
				
			||||||
  // due to existing divs being removed when merging local ids, due to being
 | 
					  // due to existing divs being removed when merging local ids, due to being
 | 
				
			||||||
| 
						 | 
					@ -797,7 +840,7 @@ void testComputeReprAtPoints(IntegerPolyhedron poly,
 | 
				
			||||||
                             unsigned numToProject) {
 | 
					                             unsigned numToProject) {
 | 
				
			||||||
  poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
 | 
					  poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
 | 
				
			||||||
                      poly.getNumDimVars(), VarKind::Local);
 | 
					                      poly.getNumDimVars(), VarKind::Local);
 | 
				
			||||||
  PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
 | 
					  PresburgerRelation repr = poly.computeReprWithOnlyDivLocals();
 | 
				
			||||||
  EXPECT_TRUE(repr.hasOnlyDivLocals());
 | 
					  EXPECT_TRUE(repr.hasOnlyDivLocals());
 | 
				
			||||||
  EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
 | 
					  EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
 | 
				
			||||||
  for (const SmallVector<int64_t, 4> &point : points) {
 | 
					  for (const SmallVector<int64_t, 4> &point : points) {
 | 
				
			||||||
| 
						 | 
					@ -810,7 +853,7 @@ void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected,
 | 
				
			||||||
                     unsigned numToProject) {
 | 
					                     unsigned numToProject) {
 | 
				
			||||||
  poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
 | 
					  poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
 | 
				
			||||||
                      poly.getNumDimVars(), VarKind::Local);
 | 
					                      poly.getNumDimVars(), VarKind::Local);
 | 
				
			||||||
  PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
 | 
					  PresburgerRelation repr = poly.computeReprWithOnlyDivLocals();
 | 
				
			||||||
  EXPECT_TRUE(repr.hasOnlyDivLocals());
 | 
					  EXPECT_TRUE(repr.hasOnlyDivLocals());
 | 
				
			||||||
  EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
 | 
					  EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
 | 
				
			||||||
  EXPECT_TRUE(repr.isEqual(expected));
 | 
					  EXPECT_TRUE(repr.isEqual(expected));
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue