[MLIR][Presburger] subtract: support non-div locals

Also added test cases. Also extend support for `computeReprWithOnlyDivLocals` from `IntegerPolyhedron` to `IntegerRelation` and `PresburgerRelation`.

Depends on D128736.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D128737
This commit is contained in:
Arjun P 2022-06-28 20:35:05 +01:00
parent 2695e23ad9
commit dda8b1ceda
6 changed files with 78 additions and 15 deletions

View File

@ -27,6 +27,7 @@ namespace presburger {
class IntegerRelation;
class IntegerPolyhedron;
class PresburgerSet;
class PresburgerRelation;
/// An IntegerRelation represents the set of points from a PresburgerSpace that
/// satisfy a list of affine constraints. Affine constraints can be inequalities
@ -575,6 +576,12 @@ public:
/// this for uniformity with `applyDomain`.
void applyRange(const IntegerRelation &rel);
/// Compute an equivalent representation of the same set, such that all local
/// vars in all disjuncts have division representations. This representation
/// may involve local vars that correspond to divisions, and may also be a
/// union of convex disjuncts.
PresburgerRelation computeReprWithOnlyDivLocals() const;
void print(raw_ostream &os) const;
void dump() const;
@ -760,12 +767,6 @@ public:
/// first added variable.
unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
/// Compute an equivalent representation of the same set, such that all local
/// ids have division representations. This representation may involve
/// local ids that correspond to divisions, and may also be a union of convex
/// disjuncts.
PresburgerSet computeReprWithOnlyDivLocals() const;
/// Compute the symbolic integer lexmin of the polyhedron.
/// This finds, for every assignment to the symbols, the lexicographically
/// minimum value attained by the dimensions. For example, the symbolic lexmin

View File

@ -128,6 +128,12 @@ public:
/// Check whether all local ids in all disjuncts have a div representation.
bool hasOnlyDivLocals() const;
/// Compute an equivalent representation of the same relation, such that all
/// local ids in all disjuncts have division representations. This
/// representation may involve local ids that correspond to divisions, and may
/// also be a union of convex disjuncts.
PresburgerRelation computeReprWithOnlyDivLocals() const;
/// Print the set's internal state.
void print(raw_ostream &os) const;
void dump() const;

View File

@ -570,7 +570,7 @@ public:
/// `symbolDomain` is the set of values of the symbols for which the lexmin
/// will be computed. `symbolDomain` should have a dim var for every symbol in
/// `constraints`, and no other vars.
SymbolicLexSimplex(const IntegerPolyhedron &constraints,
SymbolicLexSimplex(const IntegerRelation &constraints,
const IntegerPolyhedron &symbolDomain)
: SymbolicLexSimplex(constraints,
constraints.getVarKindOffset(VarKind::Symbol),
@ -582,8 +582,7 @@ public:
/// The symbol ids are the range of ids with absolute index
/// [symbolOffset, symbolOffset + symbolDomain.getNumVars())
/// symbolDomain should only have dim ids.
SymbolicLexSimplex(const IntegerPolyhedron &constraints,
unsigned symbolOffset,
SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset,
const IntegerPolyhedron &symbolDomain)
: LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset,
symbolDomain.getNumVars()),

View File

@ -165,16 +165,16 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
removeEqualityRange(counts.getNumEqs(), getNumEqualities());
}
PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const {
PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
// If there are no locals, we're done.
if (getNumLocalVars() == 0)
return PresburgerSet(*this);
return PresburgerRelation(*this);
// Move all the non-div locals to the end, as the current API to
// SymbolicLexMin requires these to form a contiguous range.
//
// Take a copy so we can perform mutations.
IntegerPolyhedron copy = *this;
IntegerRelation copy = *this;
std::vector<MaybeLocalRepr> reprs;
copy.getLocalReprs(reprs);
@ -197,7 +197,7 @@ PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const {
// If there are no non-div locals, we're done.
if (numNonDivLocals == 0)
return PresburgerSet(*this);
return PresburgerRelation(*this);
// We computeSymbolicIntegerLexMin by considering the non-div locals as
// "non-symbols" and considering everything else as "symbols". This will

View File

@ -136,6 +136,17 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
return getNegatedCoeffs(eqCoeffs);
}
PresburgerRelation PresburgerRelation::computeReprWithOnlyDivLocals() const {
if (hasOnlyDivLocals())
return *this;
// The result is just the union of the reprs of the disjuncts.
PresburgerRelation result(getSpace());
for (const IntegerRelation &disjunct : disjuncts)
result.unionInPlace(disjunct.computeReprWithOnlyDivLocals());
return result;
}
/// Return the set difference b \ s.
///
/// In the following, U denotes union, /\ denotes intersection, \ denotes set
@ -174,6 +185,9 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
if (b.isEmptyByGCDTest())
return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
if (!s.hasOnlyDivLocals())
return getSetDifference(b, s.computeReprWithOnlyDivLocals());
// Remove duplicate divs up front here to avoid existing
// divs disappearing in the call to mergeLocalVars below.
b.removeDuplicateDivs();

View File

@ -431,6 +431,10 @@ void expectEqual(const PresburgerSet &s, const PresburgerSet &t) {
EXPECT_TRUE(s.isEqual(t));
}
void expectEqual(const IntegerPolyhedron &s, const IntegerPolyhedron &t) {
EXPECT_TRUE(s.isEqual(t));
}
void expectEmpty(const PresburgerSet &s) { EXPECT_TRUE(s.isIntegerEmpty()); }
TEST(SetTest, divisions) {
@ -505,6 +509,45 @@ TEST(SetTest, divisionsDefByEq) {
expectEqual(evens, PresburgerSet(evensDefByIneq));
}
TEST(SetTest, divisionNonDivLocals) {
// This is a tetrahedron with vertices at
// (1/3, 0, 0), (2/3, 0, 0), (2/3, 0, 1000), and (1000, 1000, 1000).
//
// The only integer point in this is at (1000, 1000, 1000).
// We project this to the xy plane.
IntegerPolyhedron tetrahedron =
parsePolyAndMakeLocals("(x, y, z) : (y >= 0, z - y >= 0, 3000*x - 2998*y "
"- 1000 - z >= 0, -1500*x + 1499*y + 1000 >= 0)",
/*numLocals=*/1);
// This is a triangle with vertices at (1/3, 0), (2/3, 0) and (1000, 1000).
// The only integer point in this is at (1000, 1000).
//
// It also happens to be the projection of the above onto the xy plane.
IntegerPolyhedron triangle = parsePoly("(x,y) : (y >= 0, "
"3000 * x - 2999 * y - 1000 >= 0, "
"-3000 * x + 2998 * y + 2000 >= 0)");
EXPECT_TRUE(triangle.containsPoint({1000, 1000}));
EXPECT_FALSE(triangle.containsPoint({1001, 1001}));
expectEqual(triangle, tetrahedron);
convertSuffixDimsToLocals(triangle, 1);
IntegerPolyhedron line = parsePoly("(x) : (x - 1000 == 0)");
expectEqual(line, triangle);
// Triangle with vertices (0, 0), (5, 0), (15, 5).
// Projected on x, it becomes [0, 13] U {15} as it becomes too narrow towards
// the apex and so does not have have any integer point at x = 14.
// At x = 15, the apex is an integer point.
PresburgerSet triangle2{parsePolyAndMakeLocals("(x,y) : (y >= 0, "
"x - 3*y >= 0, "
"2*y - x + 5 >= 0)",
/*numLocals=*/1)};
PresburgerSet zeroToThirteen{parsePoly("(x) : (13 - x >= 0, x >= 0)")};
PresburgerSet fifteen{parsePoly("(x) : (x - 15 == 0)")};
expectEqual(triangle2.subtract(zeroToThirteen), fifteen);
}
TEST(SetTest, subtractDuplicateDivsRegression) {
// Previously, subtracting sets with duplicate divs might result in crashes
// due to existing divs being removed when merging local ids, due to being
@ -797,7 +840,7 @@ void testComputeReprAtPoints(IntegerPolyhedron poly,
unsigned numToProject) {
poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
poly.getNumDimVars(), VarKind::Local);
PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
PresburgerRelation repr = poly.computeReprWithOnlyDivLocals();
EXPECT_TRUE(repr.hasOnlyDivLocals());
EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
for (const SmallVector<int64_t, 4> &point : points) {
@ -810,7 +853,7 @@ void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected,
unsigned numToProject) {
poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
poly.getNumDimVars(), VarKind::Local);
PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
PresburgerRelation repr = poly.computeReprWithOnlyDivLocals();
EXPECT_TRUE(repr.hasOnlyDivLocals());
EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
EXPECT_TRUE(repr.isEqual(expected));