[llvm][STLExtras] Move the algorithm `interleave*` methods from MLIR to LLVM

These have proved incredibly useful for interleaving values between a range w.r.t to streams. After this revision, the mlir/Support/STLExtras.h is empty. A followup revision will remove it from the tree.

Differential Revision: https://reviews.llvm.org/D78067
This commit is contained in:
River Riddle 2020-04-14 14:53:28 -07:00
parent 204c3b5516
commit 2f21a57966
40 changed files with 219 additions and 220 deletions

View File

@ -248,14 +248,7 @@ class fir_AllocatableOp<string mnemonic, list<OpTrait> traits =[]> :
p << getOperationName() << ' ' << getAttr(inType());
if (hasLenParams()) {
// print the LEN parameters to a derived type in parens
p << '(';
p.printOperands(getLenParams());
p << " : ";
mlir::interleaveComma(getLenParams(), p.getStream(),
[&](const auto &opnd) {
p.printType(opnd.getType());
});
p << ')';
p << '(' << getLenParams() << " : " << getLenParams().getTypes() << ')';
}
// print the shape of the allocation (if any); all must be index type
for (auto sh : getShapeOperands()) {

View File

@ -50,6 +50,10 @@ namespace detail {
template <typename RangeT>
using IterOfRange = decltype(std::begin(std::declval<RangeT &>()));
template <typename RangeT>
using ValueOfRange = typename std::remove_reference<decltype(
*std::begin(std::declval<RangeT &>()))>::type;
} // end namespace detail
//===----------------------------------------------------------------------===//
@ -1674,6 +1678,69 @@ void replace(Container &Cont, typename Container::iterator ContIt,
replace(Cont, ContIt, ContEnd, R.begin(), R.end());
}
/// An STL-style algorithm similar to std::for_each that applies a second
/// functor between every pair of elements.
///
/// This provides the control flow logic to, for example, print a
/// comma-separated list:
/// \code
/// interleave(names.begin(), names.end(),
/// [&](StringRef name) { os << name; },
/// [&] { os << ", "; });
/// \endcode
template <typename ForwardIterator, typename UnaryFunctor,
typename NullaryFunctor,
typename = typename std::enable_if<
!std::is_constructible<StringRef, UnaryFunctor>::value &&
!std::is_constructible<StringRef, NullaryFunctor>::value>::type>
inline void interleave(ForwardIterator begin, ForwardIterator end,
UnaryFunctor each_fn, NullaryFunctor between_fn) {
if (begin == end)
return;
each_fn(*begin);
++begin;
for (; begin != end; ++begin) {
between_fn();
each_fn(*begin);
}
}
template <typename Container, typename UnaryFunctor, typename NullaryFunctor,
typename = typename std::enable_if<
!std::is_constructible<StringRef, UnaryFunctor>::value &&
!std::is_constructible<StringRef, NullaryFunctor>::value>::type>
inline void interleave(const Container &c, UnaryFunctor each_fn,
NullaryFunctor between_fn) {
interleave(c.begin(), c.end(), each_fn, between_fn);
}
/// Overload of interleave for the common case of string separator.
template <typename Container, typename UnaryFunctor, typename StreamT,
typename T = detail::ValueOfRange<Container>>
inline void interleave(const Container &c, StreamT &os, UnaryFunctor each_fn,
const StringRef &separator) {
interleave(c.begin(), c.end(), each_fn, [&] { os << separator; });
}
template <typename Container, typename StreamT,
typename T = detail::ValueOfRange<Container>>
inline void interleave(const Container &c, StreamT &os,
const StringRef &separator) {
interleave(
c, os, [&](const T &a) { os << a; }, separator);
}
template <typename Container, typename UnaryFunctor, typename StreamT,
typename T = detail::ValueOfRange<Container>>
inline void interleaveComma(const Container &c, StreamT &os,
UnaryFunctor each_fn) {
interleave(c, os, each_fn, ", ");
}
template <typename Container, typename StreamT,
typename T = detail::ValueOfRange<Container>>
inline void interleaveComma(const Container &c, StreamT &os) {
interleaveComma(c, os, [&](const T &a) { os << a; });
}
//===----------------------------------------------------------------------===//
// Extra additions to <memory>
//===----------------------------------------------------------------------===//

View File

@ -61,7 +61,7 @@ everything to the LLVM dialect.
```c++
mlir::ConversionTarget target(getContext());
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
target.addLegalDialect<mlir::LLVMDialect>();
target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
```

View File

@ -319,7 +319,7 @@ void ToyDialect::printType(mlir::Type type,
// Print the struct type according to the parser format.
printer << "struct<";
mlir::interleaveComma(structType.getElementTypes(), printer);
llvm::interleaveComma(structType.getElementTypes(), printer);
printer << '>';
}
```

View File

@ -127,12 +127,12 @@ void printLitHelper(ExprAST *litOrNum) {
// Print the dimension for this literal first
llvm::errs() << "<";
mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
mlir::interleaveComma(literal->getValues(), llvm::errs(),
llvm::interleaveComma(literal->getValues(), llvm::errs(),
[&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
@ -194,7 +194,7 @@ void ASTDumper::dump(PrintExprAST *node) {
/// Print type: only the shape is printed in between '<' and '>'
void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
mlir::interleaveComma(type.shape, llvm::errs());
llvm::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
@ -205,7 +205,7 @@ void ASTDumper::dump(PrototypeAST *node) {
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
mlir::interleaveComma(node->getArgs(), llvm::errs(),
llvm::interleaveComma(node->getArgs(), llvm::errs(),
[](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}

View File

@ -127,12 +127,12 @@ void printLitHelper(ExprAST *litOrNum) {
// Print the dimension for this literal first
llvm::errs() << "<";
mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
mlir::interleaveComma(literal->getValues(), llvm::errs(),
llvm::interleaveComma(literal->getValues(), llvm::errs(),
[&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
@ -194,7 +194,7 @@ void ASTDumper::dump(PrintExprAST *node) {
/// Print type: only the shape is printed in between '<' and '>'
void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
mlir::interleaveComma(type.shape, llvm::errs());
llvm::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
@ -205,7 +205,7 @@ void ASTDumper::dump(PrototypeAST *node) {
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
mlir::interleaveComma(node->getArgs(), llvm::errs(),
llvm::interleaveComma(node->getArgs(), llvm::errs(),
[](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}

View File

@ -127,12 +127,12 @@ void printLitHelper(ExprAST *litOrNum) {
// Print the dimension for this literal first
llvm::errs() << "<";
mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
mlir::interleaveComma(literal->getValues(), llvm::errs(),
llvm::interleaveComma(literal->getValues(), llvm::errs(),
[&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
@ -194,7 +194,7 @@ void ASTDumper::dump(PrintExprAST *node) {
/// Print type: only the shape is printed in between '<' and '>'
void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
mlir::interleaveComma(type.shape, llvm::errs());
llvm::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
@ -205,7 +205,7 @@ void ASTDumper::dump(PrototypeAST *node) {
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
mlir::interleaveComma(node->getArgs(), llvm::errs(),
llvm::interleaveComma(node->getArgs(), llvm::errs(),
[](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}

View File

@ -127,12 +127,12 @@ void printLitHelper(ExprAST *litOrNum) {
// Print the dimension for this literal first
llvm::errs() << "<";
mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
mlir::interleaveComma(literal->getValues(), llvm::errs(),
llvm::interleaveComma(literal->getValues(), llvm::errs(),
[&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
@ -194,7 +194,7 @@ void ASTDumper::dump(PrintExprAST *node) {
/// Print type: only the shape is printed in between '<' and '>'
void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
mlir::interleaveComma(type.shape, llvm::errs());
llvm::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
@ -205,7 +205,7 @@ void ASTDumper::dump(PrototypeAST *node) {
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
mlir::interleaveComma(node->getArgs(), llvm::errs(),
llvm::interleaveComma(node->getArgs(), llvm::errs(),
[](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}

View File

@ -127,12 +127,12 @@ void printLitHelper(ExprAST *litOrNum) {
// Print the dimension for this literal first
llvm::errs() << "<";
mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
mlir::interleaveComma(literal->getValues(), llvm::errs(),
llvm::interleaveComma(literal->getValues(), llvm::errs(),
[&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
@ -194,7 +194,7 @@ void ASTDumper::dump(PrintExprAST *node) {
/// Print type: only the shape is printed in between '<' and '>'
void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
mlir::interleaveComma(type.shape, llvm::errs());
llvm::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
@ -205,7 +205,7 @@ void ASTDumper::dump(PrototypeAST *node) {
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
mlir::interleaveComma(node->getArgs(), llvm::errs(),
llvm::interleaveComma(node->getArgs(), llvm::errs(),
[](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}

View File

@ -127,12 +127,12 @@ void printLitHelper(ExprAST *litOrNum) {
// Print the dimension for this literal first
llvm::errs() << "<";
mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
mlir::interleaveComma(literal->getValues(), llvm::errs(),
llvm::interleaveComma(literal->getValues(), llvm::errs(),
[&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
@ -194,7 +194,7 @@ void ASTDumper::dump(PrintExprAST *node) {
/// Print type: only the shape is printed in between '<' and '>'
void ASTDumper::dump(const VarType &type) {
llvm::errs() << "<";
mlir::interleaveComma(type.shape, llvm::errs());
llvm::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
@ -205,7 +205,7 @@ void ASTDumper::dump(PrototypeAST *node) {
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
mlir::interleaveComma(node->getArgs(), llvm::errs(),
llvm::interleaveComma(node->getArgs(), llvm::errs(),
[](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}

View File

@ -537,7 +537,7 @@ void ToyDialect::printType(mlir::Type type,
// Print the struct type according to the parser format.
printer << "struct<";
mlir::interleaveComma(structType.getElementTypes(), printer);
llvm::interleaveComma(structType.getElementTypes(), printer);
printer << '>';
}

View File

@ -130,12 +130,12 @@ void printLitHelper(ExprAST *litOrNum) {
// Print the dimension for this literal first
llvm::errs() << "<";
mlir::interleaveComma(literal->getDims(), llvm::errs());
llvm::interleaveComma(literal->getDims(), llvm::errs());
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
mlir::interleaveComma(literal->getValues(), llvm::errs(),
llvm::interleaveComma(literal->getValues(), llvm::errs(),
[&](auto &elt) { printLitHelper(elt.get()); });
llvm::errs() << "]";
}
@ -210,7 +210,7 @@ void ASTDumper::dump(const VarType &type) {
if (!type.name.empty())
llvm::errs() << type.name;
else
mlir::interleaveComma(type.shape, llvm::errs());
llvm::interleaveComma(type.shape, llvm::errs());
llvm::errs() << ">";
}
@ -221,7 +221,7 @@ void ASTDumper::dump(PrototypeAST *node) {
llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
indent();
llvm::errs() << "Params: [";
mlir::interleaveComma(node->getArgs(), llvm::errs(),
llvm::interleaveComma(node->getArgs(), llvm::errs(),
[](auto &arg) { llvm::errs() << arg->getName(); });
llvm::errs() << "]\n";
}

View File

@ -232,9 +232,8 @@ public:
/// is ','.
template <typename T, template <typename> class Container>
Diagnostic &appendRange(const Container<T> &c, const char *delim = ", ") {
interleave(
c, [&](const detail::ValueOfRange<Container<T>> &a) { *this << a; },
[&]() { *this << delim; });
llvm::interleave(
c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; });
return *this;
}

View File

@ -117,7 +117,7 @@ public:
(*types.begin()).template isa<FunctionType>();
if (wrapped)
os << '(';
interleaveComma(types, *this);
llvm::interleaveComma(types, *this);
if (wrapped)
os << ')';
}
@ -131,7 +131,7 @@ public:
void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
auto &os = getStream();
os << "(";
interleaveComma(inputs, *this);
llvm::interleaveComma(inputs, *this);
os << ")";
printArrowTypeList(results);
}
@ -199,11 +199,11 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
template <typename ValueRangeT>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
const ValueTypeRange<ValueRangeT> &types) {
interleaveComma(types, p);
llvm::interleaveComma(types, p);
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
interleaveComma(types, p);
llvm::interleaveComma(types, p);
return p;
}

View File

@ -194,7 +194,7 @@ public:
auto printElementFn = [&](const DataType &value) {
printValue(os, this->getParser(), value);
};
interleave(*this, os, printElementFn, ",");
llvm::interleave(*this, os, printElementFn, ",");
}
/// Copy the value from the given option into this one.

View File

@ -19,75 +19,6 @@
namespace mlir {
namespace detail {
template <typename RangeT>
using ValueOfRange = typename std::remove_reference<decltype(
*std::begin(std::declval<RangeT &>()))>::type;
} // end namespace detail
/// An STL-style algorithm similar to std::for_each that applies a second
/// functor between every pair of elements.
///
/// This provides the control flow logic to, for example, print a
/// comma-separated list:
/// \code
/// interleave(names.begin(), names.end(),
/// [&](StringRef name) { os << name; },
/// [&] { os << ", "; });
/// \endcode
template <typename ForwardIterator, typename UnaryFunctor,
typename NullaryFunctor,
typename = typename std::enable_if<
!std::is_constructible<StringRef, UnaryFunctor>::value &&
!std::is_constructible<StringRef, NullaryFunctor>::value>::type>
inline void interleave(ForwardIterator begin, ForwardIterator end,
UnaryFunctor each_fn, NullaryFunctor between_fn) {
if (begin == end)
return;
each_fn(*begin);
++begin;
for (; begin != end; ++begin) {
between_fn();
each_fn(*begin);
}
}
template <typename Container, typename UnaryFunctor, typename NullaryFunctor,
typename = typename std::enable_if<
!std::is_constructible<StringRef, UnaryFunctor>::value &&
!std::is_constructible<StringRef, NullaryFunctor>::value>::type>
inline void interleave(const Container &c, UnaryFunctor each_fn,
NullaryFunctor between_fn) {
interleave(c.begin(), c.end(), each_fn, between_fn);
}
/// Overload of interleave for the common case of string separator.
template <typename Container, typename UnaryFunctor, typename raw_ostream,
typename T = detail::ValueOfRange<Container>>
inline void interleave(const Container &c, raw_ostream &os,
UnaryFunctor each_fn, const StringRef &separator) {
interleave(c.begin(), c.end(), each_fn, [&] { os << separator; });
}
template <typename Container, typename raw_ostream,
typename T = detail::ValueOfRange<Container>>
inline void interleave(const Container &c, raw_ostream &os,
const StringRef &separator) {
interleave(
c, os, [&](const T &a) { os << a; }, separator);
}
template <typename Container, typename UnaryFunctor, typename raw_ostream,
typename T = detail::ValueOfRange<Container>>
inline void interleaveComma(const Container &c, raw_ostream &os,
UnaryFunctor each_fn) {
interleave(c, os, each_fn, ", ");
}
template <typename Container, typename raw_ostream,
typename T = detail::ValueOfRange<Container>>
inline void interleaveComma(const Container &c, raw_ostream &os) {
interleaveComma(c, os, [&](const T &a) { os << a; });
}
} // end namespace mlir
#endif // MLIR_SUPPORT_STLEXTRAS_H

View File

@ -2304,7 +2304,7 @@ static void print(OpAsmPrinter &p, AffineParallelOp op) {
}
if (!elideSteps) {
p << " step (";
interleaveComma(steps, p);
llvm::interleaveComma(steps, p);
p << ')';
}
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,

View File

@ -610,8 +610,8 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
return;
p << ' ' << keyword << '(';
interleaveComma(values, p,
[&p](BlockArgument v) { p << v << " : " << v.getType(); });
llvm::interleaveComma(
values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
p << ')';
}

View File

@ -1056,7 +1056,7 @@ static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
appendMangledType(ss, memref.getElementType());
} else if (auto vec = t.dyn_cast<VectorType>()) {
ss << "vector";
interleave(
llvm::interleave(
vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
appendMangledType(ss, vec.getElementType());
} else if (t.isSignlessIntOrIndexOrFloat()) {
@ -1074,7 +1074,7 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
llvm::raw_string_ostream ss(name);
ss << "_";
auto types = op->getOperandTypes();
interleave(
llvm::interleave(
types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
[&]() { ss << "_"; });
return ss.str();

View File

@ -107,7 +107,7 @@ static void print(OpAsmPrinter &p, ForOp op) {
auto regionArgs = op.getRegionIterArgs();
auto operands = op.getIterOperands();
mlir::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
p << std::get<0>(it) << " = " << std::get<1>(it);
});
p << ")";

View File

@ -354,7 +354,7 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
ArrayRef<double> scales = type.getScales();
ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
out << "{";
interleave(
llvm::interleave(
llvm::seq<size_t>(0, scales.size()), out,
[&](size_t index) {
printQuantParams(scales[index], zeroPoints[index], out);

View File

@ -587,12 +587,12 @@ static void print(StructType type, DialectAsmPrinter &os) {
auto eachFn = [&os](spirv::Decoration decoration) {
os << stringifyDecoration(decoration);
};
interleaveComma(decorations, os, eachFn);
llvm::interleaveComma(decorations, os, eachFn);
os << "]";
}
};
interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
os << ">";
}
@ -856,11 +856,11 @@ static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
auto &os = printer.getStream();
printer << spirv::VerCapExtAttr::getKindName() << "<"
<< spirv::stringifyVersion(triple.getVersion()) << ", [";
interleaveComma(triple.getCapabilities(), os, [&](spirv::Capability cap) {
os << spirv::stringifyCapability(cap);
});
llvm::interleaveComma(
triple.getCapabilities(), os,
[&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
printer << "], [";
interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
os << attr.cast<StringAttr>().getValue();
});
printer << "]>";

View File

@ -1064,7 +1064,7 @@ static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
if (auto weights = branchOp.branch_weights()) {
printer << " [";
interleaveComma(weights->getValue(), printer, [&](Attribute a) {
llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
printer << a.cast<IntegerAttr>().getInt();
});
printer << "]";
@ -1465,7 +1465,7 @@ static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
auto interfaceVars = entryPointOp.interface().getValue();
if (!interfaceVars.empty()) {
printer << ", ";
interleaveComma(interfaceVars, printer);
llvm::interleaveComma(interfaceVars, printer);
}
}
@ -1521,7 +1521,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
if (!values.size())
return;
printer << ", ";
interleaveComma(values, printer, [&](Attribute a) {
llvm::interleaveComma(values, printer, [&](Attribute a) {
printer << a.cast<IntegerAttr>().getInt();
});
}

View File

@ -1512,7 +1512,7 @@ static void print(OpAsmPrinter &p, TupleOp op) {
p.printOperands(op.getOperands());
p.printOptionalAttrDict(op.getAttrs());
p << " : ";
interleaveComma(op.getOperation()->getOperandTypes(), p);
llvm::interleaveComma(op.getOperation()->getOperandTypes(), p);
}
static LogicalResult verify(TupleOp op) { return success(); }

View File

@ -933,7 +933,7 @@ public:
template <typename Container, typename UnaryFunctor>
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
mlir::interleaveComma(c, os, each_fn);
llvm::interleaveComma(c, os, each_fn);
}
/// This enum describes the different kinds of elision for the type of an

View File

@ -216,11 +216,12 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty();
if (needsParens)
os << '(';
interleaveComma(llvm::zip(types, attrs), os,
[&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) {
p.printType(std::get<0>(t));
p.printOptionalAttrDict(std::get<1>(t));
});
llvm::interleaveComma(
llvm::zip(types, attrs), os,
[&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) {
p.printType(std::get<0>(t));
p.printOptionalAttrDict(std::get<1>(t));
});
if (needsParens)
os << ')';
}

View File

@ -52,11 +52,12 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
void Pass::printAsTextualPipeline(raw_ostream &os) {
// Special case for adaptors to use the 'op_name(sub_passes)' format.
if (auto *adaptor = getAdaptorPassBase(this)) {
interleaveComma(adaptor->getPassManagers(), os, [&](OpPassManager &pm) {
os << pm.getOpName() << "(";
pm.printAsTextualPipeline(os);
os << ")";
});
llvm::interleaveComma(adaptor->getPassManagers(), os,
[&](OpPassManager &pm) {
os << pm.getOpName() << "(";
pm.printAsTextualPipeline(os);
os << ")";
});
return;
}
// Otherwise, print the pass argument followed by its options. If the pass
@ -295,9 +296,10 @@ void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
impl->passes, [](const std::unique_ptr<Pass> &pass) {
return !isa<VerifierPass>(pass);
});
interleaveComma(filteredPasses, os, [&](const std::unique_ptr<Pass> &pass) {
pass->printAsTextualPipeline(os);
});
llvm::interleaveComma(filteredPasses, os,
[&](const std::unique_ptr<Pass> &pass) {
pass->printAsTextualPipeline(os);
});
}
//===----------------------------------------------------------------------===//
@ -358,7 +360,7 @@ void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) {
std::string OpToOpPassAdaptorBase::getName() {
std::string name = "Pipeline Collection : [";
llvm::raw_string_ostream os(name);
interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
os << '\'' << pm.getOpName() << '\'';
});
os << ']';

View File

@ -184,7 +184,7 @@ void detail::PassOptions::print(raw_ostream &os) {
// Interleave the options with ' '.
os << '{';
interleave(
llvm::interleave(
orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
os << '}';
}

View File

@ -1250,7 +1250,7 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
auto &os = rewriterImpl.logger;
os.getOStream() << "\n";
os.startLine() << "* Pattern : '" << pattern->getRootKind() << " -> (";
interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
llvm::interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
os.getOStream() << ")' {\n";
os.indent();
});

View File

@ -65,7 +65,7 @@ std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
}
// Print resultant types
interleaveComma(op->getResultTypes(), os);
llvm::interleaveComma(op->getResultTypes(), os);
os << "\n";
// A value used to elide large container attribute.

View File

@ -120,7 +120,7 @@ void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
opInst->emitRemark("NOT MATCHED");
} else {
outs << "\nmatched: " << *opInst << " with shape ratio: ";
interleaveComma(MutableArrayRef<int64_t>(*ratio), outs);
llvm::interleaveComma(MutableArrayRef<int64_t>(*ratio), outs);
}
}
}

View File

@ -51,7 +51,7 @@ void PrintOpAvailability::runOnFunction() {
os << opName << " extensions: [";
for (const auto &exts : extension.getExtensions()) {
os << " [";
interleaveComma(exts, os, [&](spirv::Extension ext) {
llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
os << spirv::stringifyExtension(ext);
});
os << "]";
@ -63,7 +63,7 @@ void PrintOpAvailability::runOnFunction() {
os << opName << " capabilities: [";
for (const auto &caps : capability.getCapabilities()) {
os << " [";
interleaveComma(caps, os, [&](spirv::Capability cap) {
llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
os << spirv::stringifyCapability(cap);
});
os << "]";

View File

@ -38,7 +38,7 @@ void TestMemRefStrideCalculation::runOnFunction() {
else
llvm::outs() << offset;
llvm::outs() << " strides: ";
interleaveComma(strides, llvm::outs(), [&](int64_t v) {
llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) {
if (v == MemRefType::getDynamicStrideOrOffset())
llvm::outs() << "?";
else

View File

@ -1480,22 +1480,23 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef opId,
std::string iteratorsStr;
llvm::raw_string_ostream ss(iteratorsStr);
unsigned pos = 0;
interleaveComma(state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
bool reduction = false;
for (auto &expr : state.expressions) {
visitPostorder(*expr, [&](const Expression &e) {
if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
if (pTensorExpr->reductionDimensions.count(pos) > 0)
reduction = true;
llvm::interleaveComma(
state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
bool reduction = false;
for (auto &expr : state.expressions) {
visitPostorder(*expr, [&](const Expression &e) {
if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
if (pTensorExpr->reductionDimensions.count(pos) > 0)
reduction = true;
}
});
if (reduction)
break;
}
ss << (reduction ? "getReductionIteratorTypeName()"
: "getParallelIteratorTypeName()");
pos++;
});
if (reduction)
break;
}
ss << (reduction ? "getReductionIteratorTypeName()"
: "getParallelIteratorTypeName()");
pos++;
});
ss.flush();
os << llvm::formatv(referenceReferenceIteratorsFmt, opId, iteratorsStr);
@ -1515,8 +1516,9 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
std::string dimsStr;
llvm::raw_string_ostream ss(dimsStr);
interleaveComma(state.dims, ss,
[&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
llvm::interleaveComma(
state.dims, ss,
[&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
ss.flush();
std::string mapsStr;
@ -1524,7 +1526,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
for (auto it : state.orderedTensorArgs)
orderedUses[it.second] = it.first;
interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
assert(u.indexingMap);
const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1})";
if (u.indexingMap.isEmpty()) {
@ -1535,7 +1537,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
std::string exprsStr;
llvm::raw_string_ostream exprsStringStream(exprsStr);
exprsStringStream << "{";
interleaveComma(u.indexingMap.getResults(), exprsStringStream);
llvm::interleaveComma(u.indexingMap.getResults(), exprsStringStream);
exprsStringStream << "}";
exprsStringStream.flush();
@ -1563,10 +1565,10 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
} else {
std::string subExprs;
llvm::raw_string_ostream subExprsStringStream(subExprs);
interleaveComma(pTensorExpr->expressions, subExprsStringStream,
[&](const std::unique_ptr<Expression> &e) {
printExpr(subExprsStringStream, *e);
});
llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream,
[&](const std::unique_ptr<Expression> &e) {
printExpr(subExprsStringStream, *e);
});
subExprsStringStream.flush();
const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});";
os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->opId, subExprs);
@ -1586,10 +1588,11 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
unsigned idx = 0;
std::string valueHandleStr;
llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
interleaveComma(state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
idx++;
});
llvm::interleaveComma(
state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
idx++;
});
std::string expressionsStr;
llvm::raw_string_ostream expressionStringStream(expressionsStr);
@ -1601,10 +1604,10 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
std::string yieldStr;
llvm::raw_string_ostream yieldStringStream(yieldStr);
interleaveComma(state.expressions, yieldStringStream,
[&](const std::unique_ptr<Expression> &e) {
printExpr(yieldStringStream, *e);
});
llvm::interleaveComma(state.expressions, yieldStringStream,
[&](const std::unique_ptr<Expression> &e) {
printExpr(yieldStringStream, *e);
});
valueHandleStringStream.flush();
expressionStringStream.flush();

View File

@ -183,7 +183,7 @@ private:
template <typename Range>
void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
os << '[';
mlir::interleaveComma(range, os);
llvm::interleaveComma(range, os);
os << ']';
}
@ -213,7 +213,7 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
printBracketedRange(traits, os);
os << ", " << (intr.getNumResults() == 0 ? 0 : 1) << ">, Arguments<(ins"
<< (operands.empty() ? "" : " ");
mlir::interleaveComma(operands, os);
llvm::interleaveComma(operands, os);
os << ")>;\n\n";
return false;

View File

@ -1107,14 +1107,16 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
body << " " << builderOpState
<< ".addAttribute(\"operand_segment_sizes\", "
"odsBuilder->getI32VectorAttr({";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
if (op.getOperand(i).isOptional())
body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
else if (op.getOperand(i).isVariadic())
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
else
body << "1";
});
llvm::interleaveComma(
llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
if (op.getOperand(i).isOptional())
body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
else if (op.getOperand(i).isVariadic())
body << "static_cast<int32_t>(" << getArgumentName(op, i)
<< ".size())";
else
body << "1";
});
body << "}));\n";
}
@ -1212,7 +1214,7 @@ void OpEmitter::genOpInterfaceMethods() {
continue;
std::string args;
llvm::raw_string_ostream os(args);
mlir::interleaveComma(method.getArguments(), os,
llvm::interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
@ -1766,7 +1768,7 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
IfDefScope scope("GET_OP_LIST", os);
interleave(
llvm::interleave(
// TODO: We are constructing the Operator wrapper instance just for
// getting it's qualified class name here. Reduce the overhead by having a
// lightweight version of Operator class just for that purpose.

View File

@ -795,7 +795,7 @@ void OperationFormat::genParserTypeResolution(Operator &op,
body << " if (parser.resolveOperands(";
if (op.getNumOperands() > 1) {
body << "llvm::concat<const OpAsmParser::OperandType>(";
interleaveComma(op.getOperands(), body, [&](auto &operand) {
llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
body << operand.name << "Operands";
});
body << ")";
@ -815,11 +815,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
// the case of a single range, so guard it here.
if (op.getNumOperands() > 1) {
body << "llvm::concat<const Type>(";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
body << "ArrayRef<Type>(";
emitTypeResolver(operandTypes[i], op.getOperand(i).name);
body << ")";
});
llvm::interleaveComma(
llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
body << "ArrayRef<Type>(";
emitTypeResolver(operandTypes[i], op.getOperand(i).name);
body << ")";
});
body << ")";
} else {
emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
@ -875,7 +876,7 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
else
body << "1";
};
interleaveComma(op.getOperands(), body, interleaveFn);
llvm::interleaveComma(op.getOperands(), body, interleaveFn);
body << "}));\n";
}
}
@ -897,7 +898,7 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
// Elide the variadic segment size attributes if necessary.
if (!fmt.allOperands && op.getTrait("OpTrait::AttrSizedOperandSegments"))
body << "\"operand_segment_sizes\", ";
interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
body << "\"" << attr->name << "\"";
});
body << "});\n";
@ -1016,13 +1017,13 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
const NamedSuccessor *var = successor->getVar();
if (var->isVariadic())
body << " interleaveComma(" << var->name << "(), p);\n";
body << " llvm::interleaveComma(" << var->name << "(), p);\n";
else
body << " p << " << var->name << "();\n";
} else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n";
} else if (isa<SuccessorsDirective>(element)) {
body << " interleaveComma(getOperation()->getSuccessors(), p);\n";
body << " llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
body << " p << ";
genTypeOperandPrinter(dir->getOperand(), body) << ";\n";

View File

@ -36,10 +36,10 @@ static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
os << method.getName() << '(';
if (addOperationArg)
os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
llvm::interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
os << ')';
}
@ -72,7 +72,7 @@ static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) {
os << " {\n return getImpl()->" << method.getName() << '(';
if (!method.isStatic())
os << "getOperation()" << (method.arg_empty() ? "" : ", ");
interleaveComma(
llvm::interleaveComma(
method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
os << ");\n }\n";
@ -135,7 +135,7 @@ static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
// Add the arguments to the call.
os << method.getName() << '(';
interleaveComma(
llvm::interleaveComma(
method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) { os << arg.name; });
os << ");\n }\n";
@ -255,10 +255,10 @@ static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
if (method.isStatic())
os << "static ";
emitCPPType(method.getReturnType(), os) << method.getName() << '(';
interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
emitCPPType(arg.type, os) << arg.name;
});
llvm::interleaveComma(method.getArguments(), os,
[&](const OpInterfaceMethod::Argument &arg) {
emitCPPType(arg.type, os) << arg.name;
});
os << ");\n```\n";
// Emit the description.

View File

@ -500,7 +500,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
return lhs->getOperationName() < rhs->getOperationName();
});
interleaveComma(sortedResultOps, os, [&](const Operator *op) {
llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
os << '"' << op->getOperationName() << '"';
});
os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";

View File

@ -1305,7 +1305,7 @@ static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
os << " case Capability::" << enumerant.getSymbol()
<< ": {static const Capability implies[" << impliedCapsDefs.size()
<< "] = {";
mlir::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
os << "Capability::" << EnumAttrCase(capDef).getSymbol();
});
os << "}; return ArrayRef<Capability>(implies, " << impliedCapsDefs.size()