[mlir][sparse] refactor handling of merger leafs and ops

Using "default:" in the switch statemements that handle all our
merger ops has become a bit cumbersome since it is easy to overlook
parts of the code that need to handle ops specifically. By enforcing
full switch statements without "default:", we get a compiler warning
when cases are overlooked.

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D127263
This commit is contained in:
Aart Bik 2022-06-07 15:51:17 -07:00
parent 4badd4d40d
commit 06aa6ec87d
2 changed files with 154 additions and 47 deletions

View File

@ -25,6 +25,7 @@ namespace sparse_tensor {
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o) TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
: kind(k), val(v), op(o) { : kind(k), val(v), op(o) {
switch (kind) { switch (kind) {
// Leaf.
case kTensor: case kTensor:
assert(x != -1u && y == -1u && !v && !o); assert(x != -1u && y == -1u && !v && !o);
tensor = x; tensor = x;
@ -36,6 +37,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
assert(x != -1u && y == -1u && !v && !o); assert(x != -1u && y == -1u && !v && !o);
index = x; index = x;
break; break;
// Unary operations.
case kAbsF: case kAbsF:
case kAbsC: case kAbsC:
case kCeilF: case kCeilF:
@ -86,13 +88,32 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
children.e0 = x; children.e0 = x;
children.e1 = y; children.e1 = y;
break; break;
case kBinary: // Binary operations.
assert(x != -1u && y != -1u && !v && o); case kMulF:
case kMulC:
case kMulI:
case kDivF:
case kDivC:
case kDivS:
case kDivU:
case kAddF:
case kAddC:
case kAddI:
case kSubF:
case kSubC:
case kSubI:
case kAndI:
case kOrI:
case kXorI:
case kShrS:
case kShrU:
case kShlI:
assert(x != -1u && y != -1u && !v && !o);
children.e0 = x; children.e0 = x;
children.e1 = y; children.e1 = y;
break; break;
default: case kBinary:
assert(x != -1u && y != -1u && !v && !o); assert(x != -1u && y != -1u && !v && o);
children.e0 = x; children.e0 = x;
children.e1 = y; children.e1 = y;
break; break;
@ -280,8 +301,13 @@ bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
bool Merger::isSingleCondition(unsigned t, unsigned e) const { bool Merger::isSingleCondition(unsigned t, unsigned e) const {
switch (tensorExps[e].kind) { switch (tensorExps[e].kind) {
// Leaf.
case kTensor: case kTensor:
return tensorExps[e].tensor == t; return tensorExps[e].tensor == t;
case kInvariant:
case kIndex:
return false;
// Unary operations.
case kAbsF: case kAbsF:
case kAbsC: case kAbsC:
case kCeilF: case kCeilF:
@ -313,6 +339,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kCRe: case kCRe:
case kBitCast: case kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0); return isSingleCondition(t, tensorExps[e].children.e0);
case kBinaryBranch:
case kUnary:
return false;
// Binary operations.
case kDivF: // note: x / c only case kDivF: // note: x / c only
case kDivC: case kDivC:
case kDivS: case kDivS:
@ -339,7 +369,12 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kAddI: case kAddI:
return isSingleCondition(t, tensorExps[e].children.e0) && return isSingleCondition(t, tensorExps[e].children.e0) &&
isSingleCondition(t, tensorExps[e].children.e1); isSingleCondition(t, tensorExps[e].children.e1);
default: case kSubF:
case kSubC:
case kSubI:
case kOrI:
case kXorI:
case kBinary:
return false; return false;
} }
} }
@ -352,12 +387,14 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
static const char *kindToOpSymbol(Kind kind) { static const char *kindToOpSymbol(Kind kind) {
switch (kind) { switch (kind) {
// Leaf.
case kTensor: case kTensor:
return "tensor"; return "tensor";
case kInvariant: case kInvariant:
return "invariant"; return "invariant";
case kIndex: case kIndex:
return "index"; return "index";
// Unary operations.
case kAbsF: case kAbsF:
case kAbsC: case kAbsC:
return "abs"; return "abs";
@ -404,6 +441,7 @@ static const char *kindToOpSymbol(Kind kind) {
return "binary_branch"; return "binary_branch";
case kUnary: case kUnary:
return "unary"; return "unary";
// Binary operations.
case kMulF: case kMulF:
case kMulC: case kMulC:
case kMulI: case kMulI:
@ -441,6 +479,7 @@ static const char *kindToOpSymbol(Kind kind) {
void Merger::dumpExp(unsigned e) const { void Merger::dumpExp(unsigned e) const {
switch (tensorExps[e].kind) { switch (tensorExps[e].kind) {
// Leaf.
case kTensor: case kTensor:
if (tensorExps[e].tensor == syntheticTensor) if (tensorExps[e].tensor == syntheticTensor)
llvm::dbgs() << "synthetic_"; llvm::dbgs() << "synthetic_";
@ -454,7 +493,9 @@ void Merger::dumpExp(unsigned e) const {
case kIndex: case kIndex:
llvm::dbgs() << "index_" << tensorExps[e].index; llvm::dbgs() << "index_" << tensorExps[e].index;
break; break;
// Unary operations.
case kAbsF: case kAbsF:
case kAbsC:
case kCeilF: case kCeilF:
case kFloorF: case kFloorF:
case kSqrtF: case kSqrtF:
@ -462,10 +503,13 @@ void Merger::dumpExp(unsigned e) const {
case kExpm1F: case kExpm1F:
case kExpm1C: case kExpm1C:
case kLog1pF: case kLog1pF:
case kLog1pC:
case kSinF: case kSinF:
case kSinC:
case kTanhF: case kTanhF:
case kTanhC: case kTanhC:
case kNegF: case kNegF:
case kNegC:
case kNegI: case kNegI:
case kTruncF: case kTruncF:
case kExtF: case kExtF:
@ -477,11 +521,35 @@ void Merger::dumpExp(unsigned e) const {
case kCastU: case kCastU:
case kCastIdx: case kCastIdx:
case kTruncI: case kTruncI:
case kCIm:
case kCRe:
case kBitCast: case kBitCast:
case kBinaryBranch:
case kUnary:
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
dumpExp(tensorExps[e].children.e0); dumpExp(tensorExps[e].children.e0);
break; break;
default: // Binary operations.
case kMulF:
case kMulC:
case kMulI:
case kDivF:
case kDivC:
case kDivS:
case kDivU:
case kAddF:
case kAddC:
case kAddI:
case kSubF:
case kSubC:
case kSubI:
case kAndI:
case kOrI:
case kXorI:
case kShrS:
case kShrU:
case kShlI:
case kBinary:
llvm::dbgs() << "("; llvm::dbgs() << "(";
dumpExp(tensorExps[e].children.e0); dumpExp(tensorExps[e].children.e0);
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " "; llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@ -542,6 +610,7 @@ void Merger::dumpBits(const BitVector &bits) const {
unsigned Merger::buildLattices(unsigned e, unsigned i) { unsigned Merger::buildLattices(unsigned e, unsigned i) {
Kind kind = tensorExps[e].kind; Kind kind = tensorExps[e].kind;
switch (kind) { switch (kind) {
// Leaf.
case kTensor: case kTensor:
case kInvariant: case kInvariant:
case kIndex: { case kIndex: {
@ -560,11 +629,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
latSets[s].push_back(addLat(t, i, e)); latSets[s].push_back(addLat(t, i, e));
return s; return s;
} }
// Unary operations.
case kAbsF: case kAbsF:
case kAbsC: case kAbsC:
case kCeilF: case kCeilF:
case kCIm:
case kCRe:
case kFloorF: case kFloorF:
case kSqrtF: case kSqrtF:
case kSqrtC: case kSqrtC:
@ -589,6 +657,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kCastU: case kCastU:
case kCastIdx: case kCastIdx:
case kTruncI: case kTruncI:
case kCIm:
case kCRe:
case kBitCast: case kBitCast:
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
// lattice set of the operand through the operator into a new set. // lattice set of the operand through the operator into a new set.
@ -625,6 +695,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
unsigned rhs = addExp(kInvariant, absentVal); unsigned rhs = addExp(kInvariant, absentVal);
return takeDisj(kind, child0, buildLattices(rhs, i), unop); return takeDisj(kind, child0, buildLattices(rhs, i), unop);
} }
// Binary operations.
case kMulF: case kMulF:
case kMulC: case kMulC:
case kMulI: case kMulI:
@ -955,16 +1026,17 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
Value v0, Value v1) { Value v0, Value v1) {
switch (tensorExps[e].kind) { switch (tensorExps[e].kind) {
// Leaf.
case kTensor: case kTensor:
case kInvariant: case kInvariant:
case kIndex: case kIndex:
llvm_unreachable("unexpected non-op"); llvm_unreachable("unexpected non-op");
// Unary ops. // Unary operations.
case kAbsF: case kAbsF:
return rewriter.create<math::AbsOp>(loc, v0); return rewriter.create<math::AbsOp>(loc, v0);
case kAbsC: { case kAbsC: {
auto type = v0.getType().template cast<ComplexType>(); auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().template cast<FloatType>(); auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::AbsOp>(loc, eltType, v0); return rewriter.create<complex::AbsOp>(loc, eltType, v0);
} }
case kCeilF: case kCeilF:
@ -1021,18 +1093,19 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0); return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
case kTruncI: case kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0); return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case kCIm: case kCIm: {
auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::ImOp>(loc, eltType, v0);
}
case kCRe: { case kCRe: {
auto type = v0.getType().template cast<ComplexType>(); auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().template cast<FloatType>(); auto eltType = type.getElementType().cast<FloatType>();
if (tensorExps[e].kind == kCIm)
return rewriter.create<complex::ImOp>(loc, eltType, v0);
return rewriter.create<complex::ReOp>(loc, eltType, v0); return rewriter.create<complex::ReOp>(loc, eltType, v0);
} }
case kBitCast: case kBitCast:
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0); return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
// Binary ops. // Binary operations.
case kMulF: case kMulF:
return rewriter.create<arith::MulFOp>(loc, v0, v1); return rewriter.create<arith::MulFOp>(loc, v0, v1);
case kMulC: case kMulC:
@ -1071,8 +1144,7 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::ShRUIOp>(loc, v0, v1); return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
case kShlI: case kShlI:
return rewriter.create<arith::ShLIOp>(loc, v0, v1); return rewriter.create<arith::ShLIOp>(loc, v0, v1);
// Semiring ops with custom logic. case kBinaryBranch: // semi-ring ops with custom logic.
case kBinaryBranch:
return insertYieldOp(rewriter, loc, return insertYieldOp(rewriter, loc,
*tensorExps[e].op->getBlock()->getParent(), {v0}); *tensorExps[e].op->getBlock()->getParent(), {v0});
case kUnary: case kUnary:

View File

@ -136,43 +136,78 @@ protected:
} }
/// Compares expressions for equality. Equality is defined recursively as: /// Compares expressions for equality. Equality is defined recursively as:
/// - Two expressions can only be equal if they have the same Kind. /// - Operations are equal if they have the same kind and children.
/// - Two binary expressions are equal if they have the same Kind and their /// - Leaf tensors are equal if they refer to the same tensor.
/// children are equal.
/// - Expressions with Kind invariant or tensor are equal if they have the
/// same expression id.
bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) { bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
auto tensorExp = merger.exp(e); auto tensorExp = merger.exp(e);
if (tensorExp.kind != pattern->kind) if (tensorExp.kind != pattern->kind)
return false; return false;
assert(tensorExp.kind != Kind::kInvariant &&
"Invariant comparison not yet supported");
switch (tensorExp.kind) { switch (tensorExp.kind) {
case Kind::kTensor: // Leaf.
case kTensor:
return tensorExp.tensor == pattern->tensorNum; return tensorExp.tensor == pattern->tensorNum;
case Kind::kAbsF: case kInvariant:
case Kind::kCeilF: case kIndex:
case Kind::kFloorF: llvm_unreachable("invariant not handled yet");
case Kind::kNegF: // Unary operations.
case Kind::kNegI: case kAbsF:
case kAbsC:
case kCeilF:
case kFloorF:
case kSqrtF:
case kSqrtC:
case kExpm1F:
case kExpm1C:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
case kTanhC:
case kNegF:
case kNegC:
case kNegI:
case kTruncF:
case kExtF:
case kCastFS:
case kCastFU:
case kCastSF:
case kCastUF:
case kCastS:
case kCastU:
case kCastIdx:
case kTruncI:
case kCIm:
case kCRe:
case kBitCast:
case kBinaryBranch:
case kUnary:
case kShlI:
case kBinary:
return compareExpression(tensorExp.children.e0, pattern->e0); return compareExpression(tensorExp.children.e0, pattern->e0);
case Kind::kMulF: // Binary operations.
case Kind::kMulI: case kMulF:
case Kind::kDivF: case kMulC:
case Kind::kDivS: case kMulI:
case Kind::kDivU: case kDivF:
case Kind::kAddF: case kDivC:
case Kind::kAddI: case kDivS:
case Kind::kSubF: case kDivU:
case Kind::kSubI: case kAddF:
case Kind::kAndI: case kAddC:
case Kind::kOrI: case kAddI:
case Kind::kXorI: case kSubF:
case kSubC:
case kSubI:
case kAndI:
case kOrI:
case kXorI:
case kShrS:
case kShrU:
return compareExpression(tensorExp.children.e0, pattern->e0) && return compareExpression(tensorExp.children.e0, pattern->e0) &&
compareExpression(tensorExp.children.e1, pattern->e1); compareExpression(tensorExp.children.e1, pattern->e1);
default:
llvm_unreachable("Unhandled Kind");
} }
llvm_unreachable("unexpected kind");
} }
unsigned numTensors; unsigned numTensors;