[mlir][sparse] refactor handling of merger leafs and ops
Using "default:" in the switch statemements that handle all our merger ops has become a bit cumbersome since it is easy to overlook parts of the code that need to handle ops specifically. By enforcing full switch statements without "default:", we get a compiler warning when cases are overlooked. Reviewed By: wrengr Differential Revision: https://reviews.llvm.org/D127263
This commit is contained in:
parent
4badd4d40d
commit
06aa6ec87d
|
|
@ -25,6 +25,7 @@ namespace sparse_tensor {
|
||||||
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
||||||
: kind(k), val(v), op(o) {
|
: kind(k), val(v), op(o) {
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
|
// Leaf.
|
||||||
case kTensor:
|
case kTensor:
|
||||||
assert(x != -1u && y == -1u && !v && !o);
|
assert(x != -1u && y == -1u && !v && !o);
|
||||||
tensor = x;
|
tensor = x;
|
||||||
|
|
@ -36,6 +37,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
||||||
assert(x != -1u && y == -1u && !v && !o);
|
assert(x != -1u && y == -1u && !v && !o);
|
||||||
index = x;
|
index = x;
|
||||||
break;
|
break;
|
||||||
|
// Unary operations.
|
||||||
case kAbsF:
|
case kAbsF:
|
||||||
case kAbsC:
|
case kAbsC:
|
||||||
case kCeilF:
|
case kCeilF:
|
||||||
|
|
@ -86,13 +88,32 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
||||||
children.e0 = x;
|
children.e0 = x;
|
||||||
children.e1 = y;
|
children.e1 = y;
|
||||||
break;
|
break;
|
||||||
case kBinary:
|
// Binary operations.
|
||||||
assert(x != -1u && y != -1u && !v && o);
|
case kMulF:
|
||||||
|
case kMulC:
|
||||||
|
case kMulI:
|
||||||
|
case kDivF:
|
||||||
|
case kDivC:
|
||||||
|
case kDivS:
|
||||||
|
case kDivU:
|
||||||
|
case kAddF:
|
||||||
|
case kAddC:
|
||||||
|
case kAddI:
|
||||||
|
case kSubF:
|
||||||
|
case kSubC:
|
||||||
|
case kSubI:
|
||||||
|
case kAndI:
|
||||||
|
case kOrI:
|
||||||
|
case kXorI:
|
||||||
|
case kShrS:
|
||||||
|
case kShrU:
|
||||||
|
case kShlI:
|
||||||
|
assert(x != -1u && y != -1u && !v && !o);
|
||||||
children.e0 = x;
|
children.e0 = x;
|
||||||
children.e1 = y;
|
children.e1 = y;
|
||||||
break;
|
break;
|
||||||
default:
|
case kBinary:
|
||||||
assert(x != -1u && y != -1u && !v && !o);
|
assert(x != -1u && y != -1u && !v && o);
|
||||||
children.e0 = x;
|
children.e0 = x;
|
||||||
children.e1 = y;
|
children.e1 = y;
|
||||||
break;
|
break;
|
||||||
|
|
@ -280,8 +301,13 @@ bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
|
||||||
|
|
||||||
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
||||||
switch (tensorExps[e].kind) {
|
switch (tensorExps[e].kind) {
|
||||||
|
// Leaf.
|
||||||
case kTensor:
|
case kTensor:
|
||||||
return tensorExps[e].tensor == t;
|
return tensorExps[e].tensor == t;
|
||||||
|
case kInvariant:
|
||||||
|
case kIndex:
|
||||||
|
return false;
|
||||||
|
// Unary operations.
|
||||||
case kAbsF:
|
case kAbsF:
|
||||||
case kAbsC:
|
case kAbsC:
|
||||||
case kCeilF:
|
case kCeilF:
|
||||||
|
|
@ -313,6 +339,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
||||||
case kCRe:
|
case kCRe:
|
||||||
case kBitCast:
|
case kBitCast:
|
||||||
return isSingleCondition(t, tensorExps[e].children.e0);
|
return isSingleCondition(t, tensorExps[e].children.e0);
|
||||||
|
case kBinaryBranch:
|
||||||
|
case kUnary:
|
||||||
|
return false;
|
||||||
|
// Binary operations.
|
||||||
case kDivF: // note: x / c only
|
case kDivF: // note: x / c only
|
||||||
case kDivC:
|
case kDivC:
|
||||||
case kDivS:
|
case kDivS:
|
||||||
|
|
@ -339,7 +369,12 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
||||||
case kAddI:
|
case kAddI:
|
||||||
return isSingleCondition(t, tensorExps[e].children.e0) &&
|
return isSingleCondition(t, tensorExps[e].children.e0) &&
|
||||||
isSingleCondition(t, tensorExps[e].children.e1);
|
isSingleCondition(t, tensorExps[e].children.e1);
|
||||||
default:
|
case kSubF:
|
||||||
|
case kSubC:
|
||||||
|
case kSubI:
|
||||||
|
case kOrI:
|
||||||
|
case kXorI:
|
||||||
|
case kBinary:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -352,12 +387,14 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
||||||
|
|
||||||
static const char *kindToOpSymbol(Kind kind) {
|
static const char *kindToOpSymbol(Kind kind) {
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
|
// Leaf.
|
||||||
case kTensor:
|
case kTensor:
|
||||||
return "tensor";
|
return "tensor";
|
||||||
case kInvariant:
|
case kInvariant:
|
||||||
return "invariant";
|
return "invariant";
|
||||||
case kIndex:
|
case kIndex:
|
||||||
return "index";
|
return "index";
|
||||||
|
// Unary operations.
|
||||||
case kAbsF:
|
case kAbsF:
|
||||||
case kAbsC:
|
case kAbsC:
|
||||||
return "abs";
|
return "abs";
|
||||||
|
|
@ -404,6 +441,7 @@ static const char *kindToOpSymbol(Kind kind) {
|
||||||
return "binary_branch";
|
return "binary_branch";
|
||||||
case kUnary:
|
case kUnary:
|
||||||
return "unary";
|
return "unary";
|
||||||
|
// Binary operations.
|
||||||
case kMulF:
|
case kMulF:
|
||||||
case kMulC:
|
case kMulC:
|
||||||
case kMulI:
|
case kMulI:
|
||||||
|
|
@ -441,6 +479,7 @@ static const char *kindToOpSymbol(Kind kind) {
|
||||||
|
|
||||||
void Merger::dumpExp(unsigned e) const {
|
void Merger::dumpExp(unsigned e) const {
|
||||||
switch (tensorExps[e].kind) {
|
switch (tensorExps[e].kind) {
|
||||||
|
// Leaf.
|
||||||
case kTensor:
|
case kTensor:
|
||||||
if (tensorExps[e].tensor == syntheticTensor)
|
if (tensorExps[e].tensor == syntheticTensor)
|
||||||
llvm::dbgs() << "synthetic_";
|
llvm::dbgs() << "synthetic_";
|
||||||
|
|
@ -454,7 +493,9 @@ void Merger::dumpExp(unsigned e) const {
|
||||||
case kIndex:
|
case kIndex:
|
||||||
llvm::dbgs() << "index_" << tensorExps[e].index;
|
llvm::dbgs() << "index_" << tensorExps[e].index;
|
||||||
break;
|
break;
|
||||||
|
// Unary operations.
|
||||||
case kAbsF:
|
case kAbsF:
|
||||||
|
case kAbsC:
|
||||||
case kCeilF:
|
case kCeilF:
|
||||||
case kFloorF:
|
case kFloorF:
|
||||||
case kSqrtF:
|
case kSqrtF:
|
||||||
|
|
@ -462,10 +503,13 @@ void Merger::dumpExp(unsigned e) const {
|
||||||
case kExpm1F:
|
case kExpm1F:
|
||||||
case kExpm1C:
|
case kExpm1C:
|
||||||
case kLog1pF:
|
case kLog1pF:
|
||||||
|
case kLog1pC:
|
||||||
case kSinF:
|
case kSinF:
|
||||||
|
case kSinC:
|
||||||
case kTanhF:
|
case kTanhF:
|
||||||
case kTanhC:
|
case kTanhC:
|
||||||
case kNegF:
|
case kNegF:
|
||||||
|
case kNegC:
|
||||||
case kNegI:
|
case kNegI:
|
||||||
case kTruncF:
|
case kTruncF:
|
||||||
case kExtF:
|
case kExtF:
|
||||||
|
|
@ -477,11 +521,35 @@ void Merger::dumpExp(unsigned e) const {
|
||||||
case kCastU:
|
case kCastU:
|
||||||
case kCastIdx:
|
case kCastIdx:
|
||||||
case kTruncI:
|
case kTruncI:
|
||||||
|
case kCIm:
|
||||||
|
case kCRe:
|
||||||
case kBitCast:
|
case kBitCast:
|
||||||
|
case kBinaryBranch:
|
||||||
|
case kUnary:
|
||||||
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
|
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
|
||||||
dumpExp(tensorExps[e].children.e0);
|
dumpExp(tensorExps[e].children.e0);
|
||||||
break;
|
break;
|
||||||
default:
|
// Binary operations.
|
||||||
|
case kMulF:
|
||||||
|
case kMulC:
|
||||||
|
case kMulI:
|
||||||
|
case kDivF:
|
||||||
|
case kDivC:
|
||||||
|
case kDivS:
|
||||||
|
case kDivU:
|
||||||
|
case kAddF:
|
||||||
|
case kAddC:
|
||||||
|
case kAddI:
|
||||||
|
case kSubF:
|
||||||
|
case kSubC:
|
||||||
|
case kSubI:
|
||||||
|
case kAndI:
|
||||||
|
case kOrI:
|
||||||
|
case kXorI:
|
||||||
|
case kShrS:
|
||||||
|
case kShrU:
|
||||||
|
case kShlI:
|
||||||
|
case kBinary:
|
||||||
llvm::dbgs() << "(";
|
llvm::dbgs() << "(";
|
||||||
dumpExp(tensorExps[e].children.e0);
|
dumpExp(tensorExps[e].children.e0);
|
||||||
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
|
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
|
||||||
|
|
@ -542,6 +610,7 @@ void Merger::dumpBits(const BitVector &bits) const {
|
||||||
unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
||||||
Kind kind = tensorExps[e].kind;
|
Kind kind = tensorExps[e].kind;
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
|
// Leaf.
|
||||||
case kTensor:
|
case kTensor:
|
||||||
case kInvariant:
|
case kInvariant:
|
||||||
case kIndex: {
|
case kIndex: {
|
||||||
|
|
@ -560,11 +629,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
||||||
latSets[s].push_back(addLat(t, i, e));
|
latSets[s].push_back(addLat(t, i, e));
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
// Unary operations.
|
||||||
case kAbsF:
|
case kAbsF:
|
||||||
case kAbsC:
|
case kAbsC:
|
||||||
case kCeilF:
|
case kCeilF:
|
||||||
case kCIm:
|
|
||||||
case kCRe:
|
|
||||||
case kFloorF:
|
case kFloorF:
|
||||||
case kSqrtF:
|
case kSqrtF:
|
||||||
case kSqrtC:
|
case kSqrtC:
|
||||||
|
|
@ -589,6 +657,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
||||||
case kCastU:
|
case kCastU:
|
||||||
case kCastIdx:
|
case kCastIdx:
|
||||||
case kTruncI:
|
case kTruncI:
|
||||||
|
case kCIm:
|
||||||
|
case kCRe:
|
||||||
case kBitCast:
|
case kBitCast:
|
||||||
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
|
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
|
||||||
// lattice set of the operand through the operator into a new set.
|
// lattice set of the operand through the operator into a new set.
|
||||||
|
|
@ -625,6 +695,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
||||||
unsigned rhs = addExp(kInvariant, absentVal);
|
unsigned rhs = addExp(kInvariant, absentVal);
|
||||||
return takeDisj(kind, child0, buildLattices(rhs, i), unop);
|
return takeDisj(kind, child0, buildLattices(rhs, i), unop);
|
||||||
}
|
}
|
||||||
|
// Binary operations.
|
||||||
case kMulF:
|
case kMulF:
|
||||||
case kMulC:
|
case kMulC:
|
||||||
case kMulI:
|
case kMulI:
|
||||||
|
|
@ -955,16 +1026,17 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
|
||||||
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
||||||
Value v0, Value v1) {
|
Value v0, Value v1) {
|
||||||
switch (tensorExps[e].kind) {
|
switch (tensorExps[e].kind) {
|
||||||
|
// Leaf.
|
||||||
case kTensor:
|
case kTensor:
|
||||||
case kInvariant:
|
case kInvariant:
|
||||||
case kIndex:
|
case kIndex:
|
||||||
llvm_unreachable("unexpected non-op");
|
llvm_unreachable("unexpected non-op");
|
||||||
// Unary ops.
|
// Unary operations.
|
||||||
case kAbsF:
|
case kAbsF:
|
||||||
return rewriter.create<math::AbsOp>(loc, v0);
|
return rewriter.create<math::AbsOp>(loc, v0);
|
||||||
case kAbsC: {
|
case kAbsC: {
|
||||||
auto type = v0.getType().template cast<ComplexType>();
|
auto type = v0.getType().cast<ComplexType>();
|
||||||
auto eltType = type.getElementType().template cast<FloatType>();
|
auto eltType = type.getElementType().cast<FloatType>();
|
||||||
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
|
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
|
||||||
}
|
}
|
||||||
case kCeilF:
|
case kCeilF:
|
||||||
|
|
@ -1021,18 +1093,19 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
||||||
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
|
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
|
||||||
case kTruncI:
|
case kTruncI:
|
||||||
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
|
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
|
||||||
case kCIm:
|
case kCIm: {
|
||||||
|
auto type = v0.getType().cast<ComplexType>();
|
||||||
|
auto eltType = type.getElementType().cast<FloatType>();
|
||||||
|
return rewriter.create<complex::ImOp>(loc, eltType, v0);
|
||||||
|
}
|
||||||
case kCRe: {
|
case kCRe: {
|
||||||
auto type = v0.getType().template cast<ComplexType>();
|
auto type = v0.getType().cast<ComplexType>();
|
||||||
auto eltType = type.getElementType().template cast<FloatType>();
|
auto eltType = type.getElementType().cast<FloatType>();
|
||||||
if (tensorExps[e].kind == kCIm)
|
|
||||||
return rewriter.create<complex::ImOp>(loc, eltType, v0);
|
|
||||||
|
|
||||||
return rewriter.create<complex::ReOp>(loc, eltType, v0);
|
return rewriter.create<complex::ReOp>(loc, eltType, v0);
|
||||||
}
|
}
|
||||||
case kBitCast:
|
case kBitCast:
|
||||||
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
|
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
|
||||||
// Binary ops.
|
// Binary operations.
|
||||||
case kMulF:
|
case kMulF:
|
||||||
return rewriter.create<arith::MulFOp>(loc, v0, v1);
|
return rewriter.create<arith::MulFOp>(loc, v0, v1);
|
||||||
case kMulC:
|
case kMulC:
|
||||||
|
|
@ -1071,8 +1144,7 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
||||||
return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
|
return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
|
||||||
case kShlI:
|
case kShlI:
|
||||||
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
|
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
|
||||||
// Semiring ops with custom logic.
|
case kBinaryBranch: // semi-ring ops with custom logic.
|
||||||
case kBinaryBranch:
|
|
||||||
return insertYieldOp(rewriter, loc,
|
return insertYieldOp(rewriter, loc,
|
||||||
*tensorExps[e].op->getBlock()->getParent(), {v0});
|
*tensorExps[e].op->getBlock()->getParent(), {v0});
|
||||||
case kUnary:
|
case kUnary:
|
||||||
|
|
|
||||||
|
|
@ -136,43 +136,78 @@ protected:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compares expressions for equality. Equality is defined recursively as:
|
/// Compares expressions for equality. Equality is defined recursively as:
|
||||||
/// - Two expressions can only be equal if they have the same Kind.
|
/// - Operations are equal if they have the same kind and children.
|
||||||
/// - Two binary expressions are equal if they have the same Kind and their
|
/// - Leaf tensors are equal if they refer to the same tensor.
|
||||||
/// children are equal.
|
|
||||||
/// - Expressions with Kind invariant or tensor are equal if they have the
|
|
||||||
/// same expression id.
|
|
||||||
bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
|
bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
|
||||||
auto tensorExp = merger.exp(e);
|
auto tensorExp = merger.exp(e);
|
||||||
if (tensorExp.kind != pattern->kind)
|
if (tensorExp.kind != pattern->kind)
|
||||||
return false;
|
return false;
|
||||||
assert(tensorExp.kind != Kind::kInvariant &&
|
|
||||||
"Invariant comparison not yet supported");
|
|
||||||
switch (tensorExp.kind) {
|
switch (tensorExp.kind) {
|
||||||
case Kind::kTensor:
|
// Leaf.
|
||||||
|
case kTensor:
|
||||||
return tensorExp.tensor == pattern->tensorNum;
|
return tensorExp.tensor == pattern->tensorNum;
|
||||||
case Kind::kAbsF:
|
case kInvariant:
|
||||||
case Kind::kCeilF:
|
case kIndex:
|
||||||
case Kind::kFloorF:
|
llvm_unreachable("invariant not handled yet");
|
||||||
case Kind::kNegF:
|
// Unary operations.
|
||||||
case Kind::kNegI:
|
case kAbsF:
|
||||||
|
case kAbsC:
|
||||||
|
case kCeilF:
|
||||||
|
case kFloorF:
|
||||||
|
case kSqrtF:
|
||||||
|
case kSqrtC:
|
||||||
|
case kExpm1F:
|
||||||
|
case kExpm1C:
|
||||||
|
case kLog1pF:
|
||||||
|
case kLog1pC:
|
||||||
|
case kSinF:
|
||||||
|
case kSinC:
|
||||||
|
case kTanhF:
|
||||||
|
case kTanhC:
|
||||||
|
case kNegF:
|
||||||
|
case kNegC:
|
||||||
|
case kNegI:
|
||||||
|
case kTruncF:
|
||||||
|
case kExtF:
|
||||||
|
case kCastFS:
|
||||||
|
case kCastFU:
|
||||||
|
case kCastSF:
|
||||||
|
case kCastUF:
|
||||||
|
case kCastS:
|
||||||
|
case kCastU:
|
||||||
|
case kCastIdx:
|
||||||
|
case kTruncI:
|
||||||
|
case kCIm:
|
||||||
|
case kCRe:
|
||||||
|
case kBitCast:
|
||||||
|
case kBinaryBranch:
|
||||||
|
case kUnary:
|
||||||
|
case kShlI:
|
||||||
|
case kBinary:
|
||||||
return compareExpression(tensorExp.children.e0, pattern->e0);
|
return compareExpression(tensorExp.children.e0, pattern->e0);
|
||||||
case Kind::kMulF:
|
// Binary operations.
|
||||||
case Kind::kMulI:
|
case kMulF:
|
||||||
case Kind::kDivF:
|
case kMulC:
|
||||||
case Kind::kDivS:
|
case kMulI:
|
||||||
case Kind::kDivU:
|
case kDivF:
|
||||||
case Kind::kAddF:
|
case kDivC:
|
||||||
case Kind::kAddI:
|
case kDivS:
|
||||||
case Kind::kSubF:
|
case kDivU:
|
||||||
case Kind::kSubI:
|
case kAddF:
|
||||||
case Kind::kAndI:
|
case kAddC:
|
||||||
case Kind::kOrI:
|
case kAddI:
|
||||||
case Kind::kXorI:
|
case kSubF:
|
||||||
|
case kSubC:
|
||||||
|
case kSubI:
|
||||||
|
case kAndI:
|
||||||
|
case kOrI:
|
||||||
|
case kXorI:
|
||||||
|
case kShrS:
|
||||||
|
case kShrU:
|
||||||
return compareExpression(tensorExp.children.e0, pattern->e0) &&
|
return compareExpression(tensorExp.children.e0, pattern->e0) &&
|
||||||
compareExpression(tensorExp.children.e1, pattern->e1);
|
compareExpression(tensorExp.children.e1, pattern->e1);
|
||||||
default:
|
|
||||||
llvm_unreachable("Unhandled Kind");
|
|
||||||
}
|
}
|
||||||
|
llvm_unreachable("unexpected kind");
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned numTensors;
|
unsigned numTensors;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue