[mlir] Convert NamedAttribute to be a class

NamedAttribute is currently represented as an std::pair, but this
creates an extremely clunky .first/.second API. This commit
converts it to a class, with better accessors (getName/getValue)
and also opens the door for more convenient API in the future.

Differential Revision: https://reviews.llvm.org/D113956
This commit is contained in:
River Riddle 2021-11-18 05:23:32 +00:00
parent 4a8734deb7
commit 0c7890c844
54 changed files with 302 additions and 220 deletions

View File

@ -2222,7 +2222,7 @@ getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands,
NamedAttribute targetOffsetAttr = NamedAttribute targetOffsetAttr =
*owner->getAttrDictionary().getNamed(offsetAttr); *owner->getAttrDictionary().getNamed(offsetAttr);
return getSubOperands( return getSubOperands(
pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(), pos, operands, targetOffsetAttr.getValue().cast<DenseIntElementsAttr>(),
mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr));
} }

View File

@ -72,9 +72,8 @@ def LLVM_OneResultOpBuilder :
[{ [{
if (resultType) $_state.addTypes(resultType); if (resultType) $_state.addTypes(resultType);
$_state.addOperands(operands); $_state.addOperands(operands);
for (auto namedAttr : attributes) { for (auto namedAttr : attributes)
$_state.addAttribute(namedAttr.first, namedAttr.second); $_state.addAttribute(namedAttr.getName(), namedAttr.getValue());
}
}]>; }]>;
def LLVM_ZeroResultOpBuilder : def LLVM_ZeroResultOpBuilder :
@ -82,9 +81,8 @@ def LLVM_ZeroResultOpBuilder :
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{ [{
$_state.addOperands(operands); $_state.addOperands(operands);
for (auto namedAttr : attributes) { for (auto namedAttr : attributes)
$_state.addAttribute(namedAttr.first, namedAttr.second); $_state.addAttribute(namedAttr.getName(), namedAttr.getValue());
}
}]>; }]>;
// Compatibility builder that takes an instance of wrapped llvm::VoidType // Compatibility builder that takes an instance of wrapped llvm::VoidType

View File

@ -136,13 +136,60 @@ inline ::llvm::hash_code hash_value(Attribute arg) {
// NamedAttribute // NamedAttribute
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// NamedAttribute is combination of a name, represented by a StringAttr, and a /// NamedAttribute represents a combination of a name and an Attribute value.
/// value, represented by an Attribute. The attribute pointer should always be class NamedAttribute {
/// non-null. public:
using NamedAttribute = std::pair<StringAttr, Attribute>; NamedAttribute(StringAttr name, Attribute value);
bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs); /// Return the name of the attribute.
bool operator<(const NamedAttribute &lhs, StringRef rhs); StringAttr getName() const;
/// Return the dialect of the name of this attribute, if the name is prefixed
/// by a dialect namespace. For example, `llvm.fast_math` would return the
/// LLVM dialect (if it is loaded). Returns nullptr if the dialect isn't
/// loaded, or if the name is not prefixed by a dialect namespace.
Dialect *getNameDialect() const;
/// Return the value of the attribute.
Attribute getValue() const { return value; }
/// Set the name of this attribute.
void setName(StringAttr newName);
/// Set the value of this attribute.
void setValue(Attribute newValue) {
assert(value && "expected valid attribute value");
value = newValue;
}
/// Compare this attribute to the provided attribute, ordering by name.
bool operator<(const NamedAttribute &rhs) const;
/// Compare this attribute to the provided string, ordering by name.
bool operator<(StringRef rhs) const;
bool operator==(const NamedAttribute &rhs) const {
return name == rhs.name && value == rhs.value;
}
bool operator!=(const NamedAttribute &rhs) const { return !(*this == rhs); }
private:
NamedAttribute(Attribute name, Attribute value) : name(name), value(value) {}
/// Allow access to internals to enable hashing.
friend ::llvm::hash_code hash_value(const NamedAttribute &arg);
friend DenseMapInfo<NamedAttribute>;
/// The name of the attribute. This is represented as a StringAttr, but
/// type-erased to Attribute in the field.
Attribute name;
/// The value of the attribute.
Attribute value;
};
inline ::llvm::hash_code hash_value(const NamedAttribute &arg) {
using AttrPairT = std::pair<Attribute, Attribute>;
return DenseMapInfo<AttrPairT>::getHashValue(AttrPairT(arg.name, arg.value));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AttributeTraitBase // AttributeTraitBase
@ -227,6 +274,23 @@ template <> struct PointerLikeTypeTraits<mlir::Attribute> {
mlir::AttributeStorage *>::NumLowBitsAvailable; mlir::AttributeStorage *>::NumLowBitsAvailable;
}; };
template <> struct DenseMapInfo<mlir::NamedAttribute> {
static mlir::NamedAttribute getEmptyKey() {
auto emptyAttr = llvm::DenseMapInfo<mlir::Attribute>::getEmptyKey();
return mlir::NamedAttribute(emptyAttr, emptyAttr);
}
static mlir::NamedAttribute getTombstoneKey() {
auto tombAttr = llvm::DenseMapInfo<mlir::Attribute>::getTombstoneKey();
return mlir::NamedAttribute(tombAttr, tombAttr);
}
static unsigned getHashValue(mlir::NamedAttribute val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::NamedAttribute lhs, mlir::NamedAttribute rhs) {
return lhs == rhs;
}
};
} // namespace llvm } // namespace llvm
#endif #endif

View File

@ -609,10 +609,10 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
// that they contain a dialect prefix in their name. Call the dialect, if // that they contain a dialect prefix in their name. Call the dialect, if
// registered, to verify the attributes themselves. // registered, to verify the attributes themselves.
for (auto attr : argAttrs) { for (auto attr : argAttrs) {
if (!attr.first.strref().contains('.')) if (!attr.getName().strref().contains('.'))
return funcOp.emitOpError( return funcOp.emitOpError(
"arguments may only have dialect attributes"); "arguments may only have dialect attributes");
if (Dialect *dialect = attr.first.getReferencedDialect()) { if (Dialect *dialect = attr.getNameDialect()) {
if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
/*argIndex=*/i, attr))) /*argIndex=*/i, attr)))
return failure(); return failure();
@ -643,9 +643,9 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
// that they contain a dialect prefix in their name. Call the dialect, if // that they contain a dialect prefix in their name. Call the dialect, if
// registered, to verify the attributes themselves. // registered, to verify the attributes themselves.
for (auto attr : resultAttrs) { for (auto attr : resultAttrs) {
if (!attr.first.strref().contains('.')) if (!attr.getName().strref().contains('.'))
return funcOp.emitOpError("results may only have dialect attributes"); return funcOp.emitOpError("results may only have dialect attributes");
if (Dialect *dialect = attr.first.getReferencedDialect()) { if (Dialect *dialect = attr.getNameDialect()) {
if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
/*resultIndex=*/i, /*resultIndex=*/i,
attr))) attr)))

View File

@ -373,7 +373,7 @@ public:
bool (*)(NamedAttribute)> { bool (*)(NamedAttribute)> {
static bool filter(NamedAttribute attr) { static bool filter(NamedAttribute attr) {
// Dialect attributes are prefixed by the dialect name, like operations. // Dialect attributes are prefixed by the dialect name, like operations.
return attr.first.strref().count('.'); return attr.getName().strref().count('.');
} }
explicit dialect_attr_iterator(ArrayRef<NamedAttribute>::iterator it, explicit dialect_attr_iterator(ArrayRef<NamedAttribute>::iterator it,
@ -407,7 +407,7 @@ public:
NamedAttrList attrs; NamedAttrList attrs;
attrs.append(std::begin(dialectAttrs), std::end(dialectAttrs)); attrs.append(std::begin(dialectAttrs), std::end(dialectAttrs));
for (auto attr : getAttrs()) for (auto attr : getAttrs())
if (!attr.first.strref().contains('.')) if (!attr.getName().strref().contains('.'))
attrs.push_back(attr); attrs.push_back(attr);
setAttrs(attrs.getDictionary(getContext())); setAttrs(attrs.getDictionary(getContext()));
} }

View File

@ -382,7 +382,7 @@ template <typename IteratorT, typename NameT>
std::pair<IteratorT, bool> findAttrUnsorted(IteratorT first, IteratorT last, std::pair<IteratorT, bool> findAttrUnsorted(IteratorT first, IteratorT last,
NameT name) { NameT name) {
for (auto it = first; it != last; ++it) for (auto it = first; it != last; ++it)
if (it->first == name) if (it->getName() == name)
return {it, true}; return {it, true};
return {last, false}; return {last, false};
} }
@ -399,7 +399,7 @@ std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
while (length > 0) { while (length > 0) {
ptrdiff_t half = length / 2; ptrdiff_t half = length / 2;
IteratorT mid = first + half; IteratorT mid = first + half;
int compare = mid->first.strref().compare(name); int compare = mid->getName().strref().compare(name);
if (compare < 0) { if (compare < 0) {
first = mid + 1; first = mid + 1;
length = length - half - 1; length = length - half - 1;

View File

@ -81,7 +81,7 @@ public:
amendOperation(Operation *op, NamedAttribute attribute, amendOperation(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const { LLVM::ModuleTranslation &moduleTranslation) const {
if (const LLVMTranslationDialectInterface *iface = if (const LLVMTranslationDialectInterface *iface =
getInterfaceFor(attribute.first.getReferencedDialect())) { getInterfaceFor(attribute.getNameDialect())) {
return iface->amendOperation(op, attribute, moduleTranslation); return iface->amendOperation(op, attribute, moduleTranslation);
} }
return success(); return success();

View File

@ -83,7 +83,7 @@ MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
intptr_t pos) { intptr_t pos) {
NamedAttribute attribute = NamedAttribute attribute =
unwrap(attr).cast<DictionaryAttr>().getValue()[pos]; unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
return {wrap(attribute.first), wrap(attribute.second)}; return {wrap(attribute.getName()), wrap(attribute.getValue())};
} }
MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,

View File

@ -432,7 +432,7 @@ intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
NamedAttribute attr = unwrap(op)->getAttrs()[pos]; NamedAttribute attr = unwrap(op)->getAttrs()[pos];
return MlirNamedAttribute{wrap(attr.first), wrap(attr.second)}; return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
} }
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,

View File

@ -55,9 +55,9 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// not specific to function modeling. // not specific to function modeling.
SmallVector<NamedAttribute, 4> attributes; SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : gpuFuncOp->getAttrs()) { for (const auto &attr : gpuFuncOp->getAttrs()) {
if (attr.first == SymbolTable::getSymbolAttrName() || if (attr.getName() == SymbolTable::getSymbolAttrName() ||
attr.first == function_like_impl::getTypeAttrName() || attr.getName() == function_like_impl::getTypeAttrName() ||
attr.first == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
continue; continue;
attributes.push_back(attr); attributes.push_back(attr);
} }

View File

@ -216,10 +216,10 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
rewriter.getFunctionType(signatureConverter.getConvertedTypes(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
llvm::None)); llvm::None));
for (const auto &namedAttr : funcOp->getAttrs()) { for (const auto &namedAttr : funcOp->getAttrs()) {
if (namedAttr.first == function_like_impl::getTypeAttrName() || if (namedAttr.getName() == function_like_impl::getTypeAttrName() ||
namedAttr.first == SymbolTable::getSymbolAttrName()) namedAttr.getName() == SymbolTable::getSymbolAttrName())
continue; continue;
newFuncOp->setAttr(namedAttr.first, namedAttr.second); newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
} }
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),

View File

@ -544,10 +544,10 @@ static LogicalResult processParallelLoop(
// Propagate custom user defined optional attributes, that can be used at // Propagate custom user defined optional attributes, that can be used at
// later stage, such as extension data for GPU kernel dispatch // later stage, such as extension data for GPU kernel dispatch
for (const auto &namedAttr : parallelOp->getAttrs()) { for (const auto &namedAttr : parallelOp->getAttrs()) {
if (namedAttr.first == gpu::getMappingAttrName() || if (namedAttr.getName() == gpu::getMappingAttrName() ||
namedAttr.first == ParallelOp::getOperandSegmentSizeAttr()) namedAttr.getName() == ParallelOp::getOperandSegmentSizeAttr())
continue; continue;
launchOp->setAttr(namedAttr.first, namedAttr.second); launchOp->setAttr(namedAttr.getName(), namedAttr.getValue());
} }
Block *body = parallelOp.getBody(); Block *body = parallelOp.getBody();

View File

@ -53,11 +53,11 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
bool filterArgAttrs, bool filterArgAttrs,
SmallVectorImpl<NamedAttribute> &result) { SmallVectorImpl<NamedAttribute> &result) {
for (const auto &attr : attrs) { for (const auto &attr : attrs) {
if (attr.first == SymbolTable::getSymbolAttrName() || if (attr.getName() == SymbolTable::getSymbolAttrName() ||
attr.first == function_like_impl::getTypeAttrName() || attr.getName() == function_like_impl::getTypeAttrName() ||
attr.first == "std.varargs" || attr.getName() == "std.varargs" ||
(filterArgAttrs && (filterArgAttrs &&
attr.first == function_like_impl::getArgDictAttrName())) attr.getName() == function_like_impl::getArgDictAttrName()))
continue; continue;
result.push_back(attr); result.push_back(attr);
} }
@ -255,7 +255,7 @@ protected:
rewriter.getArrayAttr(newArgAttrs))); rewriter.getArrayAttr(newArgAttrs)));
} }
for (auto pair : llvm::enumerate(attributes)) { for (auto pair : llvm::enumerate(attributes)) {
if (pair.value().first == "llvm.linkage") { if (pair.value().getName() == "llvm.linkage") {
attributes.erase(attributes.begin() + pair.index()); attributes.erase(attributes.begin() + pair.index());
break; break;
} }
@ -448,9 +448,9 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
symbolRef.getValue()); symbolRef.getValue());
for (const NamedAttribute &attr : op->getAttrs()) { for (const NamedAttribute &attr : op->getAttrs()) {
if (attr.first.strref() == "value") if (attr.getName().strref() == "value")
continue; continue;
newOp->setAttr(attr.first, attr.second); newOp->setAttr(attr.getName(), attr.getValue());
} }
rewriter.replaceOp(op, newOp->getResults()); rewriter.replaceOp(op, newOp->getResults());
return success(); return success();

View File

@ -90,10 +90,10 @@ void SimplifyAffineStructures::runOnFunction() {
SmallVector<Operation *> opsToSimplify; SmallVector<Operation *> opsToSimplify;
func.walk([&](Operation *op) { func.walk([&](Operation *op) {
for (auto attr : op->getAttrs()) { for (auto attr : op->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) if (auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>())
simplifyAndUpdateAttribute(op, attr.first, mapAttr); simplifyAndUpdateAttribute(op, attr.getName(), mapAttr);
else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) else if (auto setAttr = attr.getValue().dyn_cast<IntegerSetAttr>())
simplifyAndUpdateAttribute(op, attr.first, setAttr); simplifyAndUpdateAttribute(op, attr.getName(), setAttr);
} }
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op)) if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))

View File

@ -367,8 +367,8 @@ void DLTIDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op, LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) { NamedAttribute attr) {
if (attr.first == DLTIDialect::kDataLayoutAttrName) { if (attr.getName() == DLTIDialect::kDataLayoutAttrName) {
if (!attr.second.isa<DataLayoutSpecAttr>()) { if (!attr.getValue().isa<DataLayoutSpecAttr>()) {
return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName
<< "' is expected to be a #dlti.dl_spec attribute"; << "' is expected to be a #dlti.dl_spec attribute";
} }
@ -377,6 +377,6 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
return success(); return success();
} }
return op->emitError() << "attribute '" << attr.first.getValue() return op->emitError() << "attribute '" << attr.getName().getValue()
<< "' not supported by dialect"; << "' not supported by dialect";
} }

View File

@ -174,8 +174,8 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) { NamedAttribute attr) {
if (!attr.second.isa<UnitAttr>() || if (!attr.getValue().isa<UnitAttr>() ||
attr.first != getContainerModuleAttrName()) attr.getName() != getContainerModuleAttrName())
return success(); return success();
auto module = dyn_cast<ModuleOp>(op); auto module = dyn_cast<ModuleOp>(op);

View File

@ -51,9 +51,9 @@ static constexpr const char kNonTemporalAttrName[] = "nontemporal";
static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute, 8> filteredAttrs( SmallVector<NamedAttribute, 8> filteredAttrs(
llvm::make_filter_range(attrs, [&](NamedAttribute attr) { llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
if (attr.first == "fastmathFlags") { if (attr.getName() == "fastmathFlags") {
auto defAttr = FMFAttr::get(attr.second.getContext(), {}); auto defAttr = FMFAttr::get(attr.getValue().getContext(), {});
return defAttr != attr.second; return defAttr != attr.getValue();
} }
return true; return true;
})); }));
@ -201,7 +201,8 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
Optional<NamedAttribute> alignmentAttr = Optional<NamedAttribute> alignmentAttr =
result.attributes.getNamed("alignment"); result.attributes.getNamed("alignment");
if (alignmentAttr.hasValue()) { if (alignmentAttr.hasValue()) {
auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>(); auto alignmentInt =
alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>();
if (!alignmentInt) if (!alignmentInt)
return parser.emitError(parser.getNameLoc(), return parser.emitError(parser.getNameLoc(),
"expected integer alignment"); "expected integer alignment");
@ -2317,15 +2318,15 @@ LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) { NamedAttribute attr) {
// If the `llvm.loop` attribute is present, enforce the following structure, // If the `llvm.loop` attribute is present, enforce the following structure,
// which the module translation can assume. // which the module translation can assume.
if (attr.first.strref() == LLVMDialect::getLoopAttrName()) { if (attr.getName() == LLVMDialect::getLoopAttrName()) {
auto loopAttr = attr.second.dyn_cast<DictionaryAttr>(); auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>();
if (!loopAttr) if (!loopAttr)
return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
<< "' to be a dictionary attribute"; << "' to be a dictionary attribute";
Optional<NamedAttribute> parallelAccessGroup = Optional<NamedAttribute> parallelAccessGroup =
loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
if (parallelAccessGroup.hasValue()) { if (parallelAccessGroup.hasValue()) {
auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>(); auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>();
if (!accessGroups) if (!accessGroups)
return op->emitOpError() return op->emitOpError()
<< "expected '" << LLVMDialect::getParallelAccessAttrName() << "expected '" << LLVMDialect::getParallelAccessAttrName()
@ -2353,7 +2354,8 @@ LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
Optional<NamedAttribute> loopOptions = Optional<NamedAttribute> loopOptions =
loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
if (loopOptions.hasValue() && !loopOptions->second.isa<LoopOptionsAttr>()) if (loopOptions.hasValue() &&
!loopOptions->getValue().isa<LoopOptionsAttr>())
return op->emitOpError() return op->emitOpError()
<< "expected '" << LLVMDialect::getLoopOptionsAttrName() << "expected '" << LLVMDialect::getLoopOptionsAttrName()
<< "' to be a `loopopts` attribute"; << "' to be a `loopopts` attribute";
@ -2363,9 +2365,9 @@ LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
// syntax. Try parsing it and report errors in case of failure. Users of this // syntax. Try parsing it and report errors in case of failure. Users of this
// attribute may assume it is well-formed and can pass it to the (asserting) // attribute may assume it is well-formed and can pass it to the (asserting)
// llvm::DataLayout constructor. // llvm::DataLayout constructor.
if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName()) if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
return success(); return success();
if (auto stringAttr = attr.second.dyn_cast<StringAttr>()) if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>())
return verifyDataLayoutString( return verifyDataLayoutString(
stringAttr.getValue(), stringAttr.getValue(),
[op](const Twine &message) { op->emitOpError() << message.str(); }); [op](const Twine &message) { op->emitOpError() << message.str(); });
@ -2381,13 +2383,13 @@ LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
unsigned argIdx, unsigned argIdx,
NamedAttribute argAttr) { NamedAttribute argAttr) {
// Check that llvm.noalias is a unit attribute. // Check that llvm.noalias is a unit attribute.
if (argAttr.first == LLVMDialect::getNoAliasAttrName() && if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() &&
!argAttr.second.isa<UnitAttr>()) !argAttr.getValue().isa<UnitAttr>())
return op->emitError() return op->emitError()
<< "expected llvm.noalias argument attribute to be a unit attribute"; << "expected llvm.noalias argument attribute to be a unit attribute";
// Check that llvm.align is an integer attribute. // Check that llvm.align is an integer attribute.
if (argAttr.first == LLVMDialect::getAlignAttrName() && if (argAttr.getName() == LLVMDialect::getAlignAttrName() &&
!argAttr.second.isa<IntegerAttr>()) !argAttr.getValue().isa<IntegerAttr>())
return op->emitError() return op->emitError()
<< "llvm.align argument attribute of non integer type"; << "llvm.align argument attribute of non integer type";
return success(); return success();

View File

@ -57,7 +57,7 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
return failure(); return failure();
for (auto &attr : result.attributes) { for (auto &attr : result.attributes) {
if (attr.first != "return_value_and_is_valid") if (attr.getName() != "return_value_and_is_valid")
continue; continue;
auto structType = resultType.dyn_cast<LLVM::LLVMStructType>(); auto structType = resultType.dyn_cast<LLVM::LLVMStructType>();
if (structType && !structType.getBody().empty()) if (structType && !structType.getBody().empty())
@ -249,7 +249,7 @@ void NVVMDialect::initialize() {
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) { NamedAttribute attr) {
// Kernel function attribute should be attached to functions. // Kernel function attribute should be attached to functions.
if (attr.first == NVVMDialect::getKernelFuncAttrName()) { if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) {
if (!isa<LLVM::LLVMFuncOp>(op)) { if (!isa<LLVM::LLVMFuncOp>(op)) {
return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
<< "' attribute attached to unexpected op"; << "' attribute attached to unexpected op";

View File

@ -96,7 +96,7 @@ void ROCDLDialect::initialize() {
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op, LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) { NamedAttribute attr) {
// Kernel function attribute should be attached to functions. // Kernel function attribute should be attached to functions.
if (attr.first == ROCDLDialect::getKernelFuncAttrName()) { if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
if (!isa<LLVM::LLVMFuncOp>(op)) { if (!isa<LLVM::LLVMFuncOp>(op)) {
return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName() return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
<< "' attribute attached to unexpected op"; << "' attribute attached to unexpected op";

View File

@ -583,7 +583,7 @@ static void print(OpAsmPrinter &p, GenericOp op) {
genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
SmallVector<NamedAttribute, 8> genericAttrs; SmallVector<NamedAttribute, 8> genericAttrs;
for (auto attr : op->getAttrs()) for (auto attr : op->getAttrs())
if (genericAttrNamesSet.count(attr.first.strref()) > 0) if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
genericAttrs.push_back(attr); genericAttrs.push_back(attr);
if (!genericAttrs.empty()) { if (!genericAttrs.empty()) {
auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
@ -598,7 +598,7 @@ static void print(OpAsmPrinter &p, GenericOp op) {
bool hasExtraAttrs = false; bool hasExtraAttrs = false;
for (NamedAttribute n : op->getAttrs()) { for (NamedAttribute n : op->getAttrs()) {
if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref()))) if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
break; break;
} }
if (hasExtraAttrs) { if (hasExtraAttrs) {
@ -753,8 +753,8 @@ struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
// Copy over unknown attributes. They might be load bearing for some flow. // Copy over unknown attributes. They might be load bearing for some flow.
ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames(); ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
for (NamedAttribute kv : genericOp->getAttrs()) { for (NamedAttribute kv : genericOp->getAttrs()) {
if (!llvm::is_contained(odsAttrs, kv.first.getValue())) { if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
newOp->setAttr(kv.first, kv.second); newOp->setAttr(kv.getName(), kv.getValue());
} }
} }

View File

@ -152,30 +152,30 @@ LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) { NamedAttribute attr) {
using comprehensive_bufferize::BufferizableOpInterface; using comprehensive_bufferize::BufferizableOpInterface;
if (attr.first == BufferizableOpInterface::kInplaceableAttrName) { if (attr.getName() == BufferizableOpInterface::kInplaceableAttrName) {
if (!attr.second.isa<BoolAttr>()) { if (!attr.getValue().isa<BoolAttr>()) {
return op->emitError() return op->emitError()
<< "'" << BufferizableOpInterface::kInplaceableAttrName << "'" << BufferizableOpInterface::kInplaceableAttrName
<< "' is expected to be a boolean attribute"; << "' is expected to be a boolean attribute";
} }
if (!op->hasTrait<OpTrait::FunctionLike>()) if (!op->hasTrait<OpTrait::FunctionLike>())
return op->emitError() << "expected " << attr.first return op->emitError() << "expected " << attr.getName()
<< " to be used on function-like operations"; << " to be used on function-like operations";
return success(); return success();
} }
if (attr.first == BufferizableOpInterface::kBufferLayoutAttrName) { if (attr.getName() == BufferizableOpInterface::kBufferLayoutAttrName) {
if (!attr.second.isa<AffineMapAttr>()) { if (!attr.getValue().isa<AffineMapAttr>()) {
return op->emitError() return op->emitError()
<< "'" << BufferizableOpInterface::kBufferLayoutAttrName << "'" << BufferizableOpInterface::kBufferLayoutAttrName
<< "' is expected to be a affine map attribute"; << "' is expected to be a affine map attribute";
} }
if (!op->hasTrait<OpTrait::FunctionLike>()) if (!op->hasTrait<OpTrait::FunctionLike>())
return op->emitError() << "expected " << attr.first return op->emitError() << "expected " << attr.getName()
<< " to be used on function-like operations"; << " to be used on function-like operations";
return success(); return success();
} }
if (attr.first == LinalgDialect::kMemoizedIndexingMapsAttrName) if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
return success(); return success();
return op->emitError() << "attribute '" << attr.first return op->emitError() << "attribute '" << attr.getName()
<< "' not supported by the linalg dialect"; << "' not supported by the linalg dialect";
} }

View File

@ -1211,8 +1211,8 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attribute) { NamedAttribute attribute) {
StringRef symbol = attribute.first.strref(); StringRef symbol = attribute.getName().strref();
Attribute attr = attribute.second; Attribute attr = attribute.getValue();
// TODO: figure out a way to generate the description from the // TODO: figure out a way to generate the description from the
// StructAttr definition. // StructAttr definition.
@ -1237,8 +1237,8 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
/// `valueType` is valid. /// `valueType` is valid.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
NamedAttribute attribute) { NamedAttribute attribute) {
StringRef symbol = attribute.first.strref(); StringRef symbol = attribute.getName().strref();
Attribute attr = attribute.second; Attribute attr = attribute.getValue();
if (symbol != spirv::getInterfaceVarABIAttrName()) if (symbol != spirv::getInterfaceVarABIAttrName())
return emitError(loc, "found unsupported '") return emitError(loc, "found unsupported '")

View File

@ -76,7 +76,7 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
auto range = auto range =
llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) { llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
return attr.first != SymbolTable::getSymbolAttrName(); return attr.getName() != SymbolTable::getSymbolAttrName();
}); });
return llvm::hash_combine( return llvm::hash_combine(

View File

@ -44,9 +44,8 @@ public:
// Save all named attributes except "type" attribute. // Save all named attributes except "type" attribute.
for (const auto &attr : op->getAttrs()) { for (const auto &attr : op->getAttrs()) {
if (attr.first == "type") { if (attr.getName() == "type")
continue; continue;
}
globalVarAttrs.push_back(attr); globalVarAttrs.push_back(attr);
} }

View File

@ -580,9 +580,9 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
// Copy over all attributes other than the function name and type. // Copy over all attributes other than the function name and type.
for (const auto &namedAttr : funcOp->getAttrs()) { for (const auto &namedAttr : funcOp->getAttrs()) {
if (namedAttr.first != function_like_impl::getTypeAttrName() && if (namedAttr.getName() != function_like_impl::getTypeAttrName() &&
namedAttr.first != SymbolTable::getSymbolAttrName()) namedAttr.getName() != SymbolTable::getSymbolAttrName())
newFuncOp->setAttr(namedAttr.first, namedAttr.second); newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
} }
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),

View File

@ -188,12 +188,12 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attribute) { NamedAttribute attribute) {
// Verify shape.lib attribute. // Verify shape.lib attribute.
if (attribute.first == "shape.lib") { if (attribute.getName() == "shape.lib") {
if (!op->hasTrait<OpTrait::SymbolTable>()) if (!op->hasTrait<OpTrait::SymbolTable>())
return op->emitError( return op->emitError(
"shape.lib attribute may only be on op implementing SymbolTable"); "shape.lib attribute may only be on op implementing SymbolTable");
if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) { if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
if (!symbol) if (!symbol)
return op->emitError("shape function library ") return op->emitError("shape function library ")
@ -204,7 +204,7 @@ LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
<< symbolRef << " required to be shape function library"; << symbolRef << " required to be shape function library";
} }
if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) { if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
// Verify all entries are function libraries and mappings in libraries // Verify all entries are function libraries and mappings in libraries
// refer to unique ops. // refer to unique ops.
DenseSet<StringAttr> key; DenseSet<StringAttr> key;
@ -219,10 +219,10 @@ LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
return op->emitError() return op->emitError()
<< it << " does not refer to FunctionLibraryOp"; << it << " does not refer to FunctionLibraryOp";
for (auto mapping : shapeFnLib.getMapping()) { for (auto mapping : shapeFnLib.getMapping()) {
if (!key.insert(mapping.first).second) { if (!key.insert(mapping.getName()).second) {
return op->emitError("only one op to shape mapping allowed, found " return op->emitError("only one op to shape mapping allowed, found "
"multiple for `") "multiple for `")
<< mapping.first << "`"; << mapping.getName() << "`";
} }
} }
} }

View File

@ -54,8 +54,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned ptr = 0; unsigned ptr = 0;
unsigned ind = 0; unsigned ind = 0;
for (const NamedAttribute &attr : dict) { for (const NamedAttribute &attr : dict) {
if (attr.first == "dimLevelType") { if (attr.getName() == "dimLevelType") {
auto arrayAttr = attr.second.dyn_cast<ArrayAttr>(); auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) { if (!arrayAttr) {
parser.emitError(parser.getNameLoc(), parser.emitError(parser.getNameLoc(),
"expected an array for dimension level types"); "expected an array for dimension level types");
@ -82,24 +82,24 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
return {}; return {};
} }
} }
} else if (attr.first == "dimOrdering") { } else if (attr.getName() == "dimOrdering") {
auto affineAttr = attr.second.dyn_cast<AffineMapAttr>(); auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
if (!affineAttr) { if (!affineAttr) {
parser.emitError(parser.getNameLoc(), parser.emitError(parser.getNameLoc(),
"expected an affine map for dimension ordering"); "expected an affine map for dimension ordering");
return {}; return {};
} }
map = affineAttr.getValue(); map = affineAttr.getValue();
} else if (attr.first == "pointerBitWidth") { } else if (attr.getName() == "pointerBitWidth") {
auto intAttr = attr.second.dyn_cast<IntegerAttr>(); auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) { if (!intAttr) {
parser.emitError(parser.getNameLoc(), parser.emitError(parser.getNameLoc(),
"expected an integral pointer bitwidth"); "expected an integral pointer bitwidth");
return {}; return {};
} }
ptr = intAttr.getInt(); ptr = intAttr.getInt();
} else if (attr.first == "indexBitWidth") { } else if (attr.getName() == "indexBitWidth") {
auto intAttr = attr.second.dyn_cast<IntegerAttr>(); auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) { if (!intAttr) {
parser.emitError(parser.getNameLoc(), parser.emitError(parser.getNameLoc(),
"expected an integral index bitwidth"); "expected an integral index bitwidth");
@ -108,7 +108,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
ind = intAttr.getInt(); ind = intAttr.getInt();
} else { } else {
parser.emitError(parser.getNameLoc(), "unexpected key: ") parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.first.str(); << attr.getName().strref();
return {}; return {};
} }
} }

View File

@ -486,7 +486,7 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
traitAttrsSet.insert(attrNames.begin(), attrNames.end()); traitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs; SmallVector<NamedAttribute, 8> attrs;
for (auto attr : op->getAttrs()) for (auto attr : op->getAttrs())
if (traitAttrsSet.count(attr.first.strref()) > 0) if (traitAttrsSet.count(attr.getName().strref()) > 0)
attrs.push_back(attr); attrs.push_back(attr);
auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);

View File

@ -411,7 +411,7 @@ private:
// Consider the attributes of the operation for aliases. // Consider the attributes of the operation for aliases.
for (const NamedAttribute &attr : op->getAttrs()) for (const NamedAttribute &attr : op->getAttrs())
printAttribute(attr.second); printAttribute(attr.getValue());
} }
/// Print the given block. If 'printBlockArgs' is false, the arguments of the /// Print the given block. If 'printBlockArgs' is false, the arguments of the
@ -483,14 +483,14 @@ private:
return; return;
if (elidedAttrs.empty()) { if (elidedAttrs.empty()) {
for (const NamedAttribute &attr : attrs) for (const NamedAttribute &attr : attrs)
printAttribute(attr.second); printAttribute(attr.getValue());
return; return;
} }
llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
elidedAttrs.end()); elidedAttrs.end());
for (const NamedAttribute &attr : attrs) for (const NamedAttribute &attr : attrs)
if (!elidedAttrsSet.contains(attr.first.strref())) if (!elidedAttrsSet.contains(attr.getName().strref()))
printAttribute(attr.second); printAttribute(attr.getValue());
} }
void printOptionalAttrDictWithKeyword( void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs, ArrayRef<NamedAttribute> attrs,
@ -2031,24 +2031,22 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
elidedAttrs.end()); elidedAttrs.end());
auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
return !elidedAttrsSet.contains(attr.first.strref()); return !elidedAttrsSet.contains(attr.getName().strref());
}); });
if (!filteredAttrs.empty()) if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs); printFilteredAttributesFn(filteredAttrs);
} }
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
assert(attr.first.size() != 0 && "expected valid named attribute");
// Print the name without quotes if possible. // Print the name without quotes if possible.
::printKeywordOrString(attr.first.strref(), os); ::printKeywordOrString(attr.getName().strref(), os);
// Pretty printing elides the attribute value for unit attributes. // Pretty printing elides the attribute value for unit attributes.
if (attr.second.isa<UnitAttr>()) if (attr.getValue().isa<UnitAttr>())
return; return;
os << " = "; os << " = ";
printAttribute(attr.second); printAttribute(attr.getValue());
} }
void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {

View File

@ -23,9 +23,27 @@ MLIRContext *Attribute::getContext() const { return getDialect().getContext(); }
// NamedAttribute // NamedAttribute
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { NamedAttribute::NamedAttribute(StringAttr name, Attribute value)
return lhs.first.compare(rhs.first) < 0; : name(name), value(value) {
assert(name && value && "expected valid attribute name and value");
assert(name.size() != 0 && "expected valid attribute name");
} }
bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
return lhs.first.getValue().compare(rhs) < 0; StringAttr NamedAttribute::getName() const { return name.cast<StringAttr>(); }
Dialect *NamedAttribute::getNameDialect() const {
return getName().getReferencedDialect();
}
void NamedAttribute::setName(StringAttr newName) {
assert(name && "expected valid attribute name");
name = newName;
}
bool NamedAttribute::operator<(const NamedAttribute &rhs) const {
return getName().compare(rhs.getName()) < 0;
}
bool NamedAttribute::operator<(StringRef rhs) const {
return getName().getValue().compare(rhs) < 0;
} }

View File

@ -119,11 +119,12 @@ findDuplicateElement(ArrayRef<NamedAttribute> value) {
return none; return none;
if (value.size() == 2) if (value.size() == 2)
return value[0].first == value[1].first ? value[0] : none; return value[0].getName() == value[1].getName() ? value[0] : none;
auto it = std::adjacent_find( auto it = std::adjacent_find(value.begin(), value.end(),
value.begin(), value.end(), [](NamedAttribute l, NamedAttribute r) {
[](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); return l.getName() == r.getName();
});
return it != value.end() ? *it : none; return it != value.end() ? *it : none;
} }
@ -154,9 +155,6 @@ DictionaryAttr DictionaryAttr::get(MLIRContext *context,
ArrayRef<NamedAttribute> value) { ArrayRef<NamedAttribute> value) {
if (value.empty()) if (value.empty())
return DictionaryAttr::getEmpty(context); return DictionaryAttr::getEmpty(context);
assert(llvm::all_of(value,
[](const NamedAttribute &attr) { return attr.second; }) &&
"value cannot have null entries");
// We need to sort the element list to canonicalize it. // We need to sort the element list to canonicalize it.
SmallVector<NamedAttribute, 8> storage; SmallVector<NamedAttribute, 8> storage;
@ -173,10 +171,8 @@ DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
if (value.empty()) if (value.empty())
return DictionaryAttr::getEmpty(context); return DictionaryAttr::getEmpty(context);
// Ensure that the attribute elements are unique and sorted. // Ensure that the attribute elements are unique and sorted.
assert(llvm::is_sorted(value, assert(llvm::is_sorted(
[](NamedAttribute l, NamedAttribute r) { value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
return l.first.strref() < r.first.strref();
}) &&
"expected attribute values to be sorted"); "expected attribute values to be sorted");
assert(!findDuplicateElement(value) && assert(!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique"); "DictionaryAttr element names must be unique");
@ -186,11 +182,11 @@ DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
/// Return the specified attribute if present, null otherwise. /// Return the specified attribute if present, null otherwise.
Attribute DictionaryAttr::get(StringRef name) const { Attribute DictionaryAttr::get(StringRef name) const {
auto it = impl::findAttrSorted(begin(), end(), name); auto it = impl::findAttrSorted(begin(), end(), name);
return it.second ? it.first->second : Attribute(); return it.second ? it.first->getValue() : Attribute();
} }
Attribute DictionaryAttr::get(StringAttr name) const { Attribute DictionaryAttr::get(StringAttr name) const {
auto it = impl::findAttrSorted(begin(), end(), name); auto it = impl::findAttrSorted(begin(), end(), name);
return it.second ? it.first->second : Attribute(); return it.second ? it.first->getValue() : Attribute();
} }
/// Return the specified named attribute if present, None otherwise. /// Return the specified named attribute if present, None otherwise.
@ -226,16 +222,16 @@ DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
void DictionaryAttr::walkImmediateSubElements( void DictionaryAttr::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn, function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const { function_ref<void(Type)> walkTypesFn) const {
for (Attribute attr : llvm::make_second_range(getValue())) for (const NamedAttribute &attr : getValue())
walkAttrsFn(attr); walkAttrsFn(attr.getValue());
} }
SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute( SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
ArrayRef<std::pair<size_t, Attribute>> replacements) const { ArrayRef<std::pair<size_t, Attribute>> replacements) const {
std::vector<NamedAttribute> vec = getValue().vec(); std::vector<NamedAttribute> vec = getValue().vec();
for (auto &it : replacements) { for (auto &it : replacements)
vec[it.first].second = it.second; vec[it.first].setValue(it.second);
}
// The above only modifies the mapped value, but not the key, and therefore // The above only modifies the mapped value, but not the key, and therefore
// not the order of the elements. It remains sorted // not the order of the elements. It remains sorted
return getWithSorted(getContext(), vec); return getWithSorted(getContext(), vec);

View File

@ -153,12 +153,17 @@ static LogicalResult verify(FuncOp op) {
/// from this function to dest. /// from this function to dest.
void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
// Add the attributes of this function to dest. // Add the attributes of this function to dest.
llvm::MapVector<StringAttr, Attribute> newAttrs; llvm::MapVector<StringAttr, Attribute> newAttrMap;
for (const auto &attr : dest->getAttrs()) for (const auto &attr : dest->getAttrs())
newAttrs.insert(attr); newAttrMap.insert({attr.getName(), attr.getValue()});
for (const auto &attr : (*this)->getAttrs()) for (const auto &attr : (*this)->getAttrs())
newAttrs.insert(attr); newAttrMap.insert({attr.getName(), attr.getValue()});
dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
auto newAttrs = llvm::to_vector(llvm::map_range(
newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
return NamedAttribute(attrPair.first, attrPair.second);
}));
dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
// Clone the body. // Clone the body.
getBody().cloneInto(&dest.getBody(), mapper); getBody().cloneInto(&dest.getBody(), mapper);
@ -235,10 +240,9 @@ DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
// Take the first and only (if present) attribute that implements the // Take the first and only (if present) attribute that implements the
// interface. This needs a linear search, but is called only once per data // interface. This needs a linear search, but is called only once per data
// layout object construction that is used for repeated queries. // layout object construction that is used for repeated queries.
for (Attribute attr : llvm::make_second_range(getOperation()->getAttrs())) { for (NamedAttribute attr : getOperation()->getAttrs())
if (auto spec = attr.dyn_cast<DataLayoutSpecInterface>()) if (auto spec = attr.getValue().dyn_cast<DataLayoutSpecInterface>())
return spec; return spec;
}
return {}; return {};
} }
@ -246,30 +250,30 @@ static LogicalResult verify(ModuleOp op) {
// Check that none of the attributes are non-dialect attributes, except for // Check that none of the attributes are non-dialect attributes, except for
// the symbol related attributes. // the symbol related attributes.
for (auto attr : op->getAttrs()) { for (auto attr : op->getAttrs()) {
if (!attr.first.strref().contains('.') && if (!attr.getName().strref().contains('.') &&
!llvm::is_contained( !llvm::is_contained(
ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(), ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
mlir::SymbolTable::getVisibilityAttrName()}, mlir::SymbolTable::getVisibilityAttrName()},
attr.first.strref())) attr.getName().strref()))
return op.emitOpError() << "can only contain attributes with " return op.emitOpError() << "can only contain attributes with "
"dialect-prefixed names, found: '" "dialect-prefixed names, found: '"
<< attr.first.getValue() << "'"; << attr.getName().getValue() << "'";
} }
// Check that there is at most one data layout spec attribute. // Check that there is at most one data layout spec attribute.
StringRef layoutSpecAttrName; StringRef layoutSpecAttrName;
DataLayoutSpecInterface layoutSpec; DataLayoutSpecInterface layoutSpec;
for (const NamedAttribute &na : op->getAttrs()) { for (const NamedAttribute &na : op->getAttrs()) {
if (auto spec = na.second.dyn_cast<DataLayoutSpecInterface>()) { if (auto spec = na.getValue().dyn_cast<DataLayoutSpecInterface>()) {
if (layoutSpec) { if (layoutSpec) {
InFlightDiagnostic diag = InFlightDiagnostic diag =
op.emitOpError() << "expects at most one data layout attribute"; op.emitOpError() << "expects at most one data layout attribute";
diag.attachNote() << "'" << layoutSpecAttrName diag.attachNote() << "'" << layoutSpecAttrName
<< "' is a data layout attribute"; << "' is a data layout attribute";
diag.attachNote() << "'" << na.first.getValue() diag.attachNote() << "'" << na.getName().getValue()
<< "' is a data layout attribute"; << "' is a data layout attribute";
} }
layoutSpecAttrName = na.first.strref(); layoutSpecAttrName = na.getName().strref();
layoutSpec = spec; layoutSpec = spec;
} }
} }

View File

@ -72,11 +72,8 @@ void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) {
} }
void NamedAttrList::push_back(NamedAttribute newAttribute) { void NamedAttrList::push_back(NamedAttribute newAttribute) {
assert(newAttribute.second && "unexpected null attribute"); if (isSorted())
if (isSorted()) { dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
dictionarySorted.setInt(attrs.empty() ||
attrs.back().first.compare(newAttribute.first) < 0);
}
dictionarySorted.setPointer(nullptr); dictionarySorted.setPointer(nullptr);
attrs.push_back(newAttribute); attrs.push_back(newAttribute);
} }
@ -84,11 +81,11 @@ void NamedAttrList::push_back(NamedAttribute newAttribute) {
/// Return the specified attribute if present, null otherwise. /// Return the specified attribute if present, null otherwise.
Attribute NamedAttrList::get(StringRef name) const { Attribute NamedAttrList::get(StringRef name) const {
auto it = findAttr(*this, name); auto it = findAttr(*this, name);
return it.second ? it.first->second : Attribute(); return it.second ? it.first->getValue() : Attribute();
} }
Attribute NamedAttrList::get(StringAttr name) const { Attribute NamedAttrList::get(StringAttr name) const {
auto it = findAttr(*this, name); auto it = findAttr(*this, name);
return it.second ? it.first->second : Attribute(); return it.second ? it.first->getValue() : Attribute();
} }
/// Return the specified named attribute if present, None otherwise. /// Return the specified named attribute if present, None otherwise.
@ -112,12 +109,14 @@ Attribute NamedAttrList::set(StringAttr name, Attribute value) {
if (it.second) { if (it.second) {
// Update the existing attribute by swapping out the old value for the new // Update the existing attribute by swapping out the old value for the new
// value. Return the old value. // value. Return the old value.
if (it.first->second != value) { Attribute oldValue = it.first->getValue();
std::swap(it.first->second, value); if (it.first->getValue() != value) {
it.first->setValue(value);
// If the attributes have changed, the dictionary is invalidated. // If the attributes have changed, the dictionary is invalidated.
dictionarySorted.setPointer(nullptr); dictionarySorted.setPointer(nullptr);
} }
return value; return oldValue;
} }
// Perform a string lookup to insert the new attribute into its sorted // Perform a string lookup to insert the new attribute into its sorted
// position. // position.
@ -137,7 +136,7 @@ Attribute NamedAttrList::set(StringRef name, Attribute value) {
Attribute Attribute
NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) { NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
// Erasing does not affect the sorted property. // Erasing does not affect the sorted property.
Attribute attr = it->second; Attribute attr = it->getValue();
attrs.erase(it); attrs.erase(it);
dictionarySorted.setPointer(nullptr); dictionarySorted.setPointer(nullptr);
return attr; return attr;
@ -485,11 +484,12 @@ void MutableOperandRange::updateLength(unsigned newLength) {
// Update any of the provided segment attributes. // Update any of the provided segment attributes.
for (OperandSegment &segment : operandSegments) { for (OperandSegment &segment : operandSegments) {
auto attr = segment.second.second.cast<DenseIntElementsAttr>(); auto attr = segment.second.getValue().cast<DenseIntElementsAttr>();
SmallVector<int32_t, 8> segments(attr.getValues<int32_t>()); SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
segments[segment.first] += diff; segments[segment.first] += diff;
segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments); segment.second.setValue(
owner->setAttr(segment.second.first, segment.second.second); DenseIntElementsAttr::get(attr.getType(), segments));
owner->setAttr(segment.second.getName(), segment.second.getValue());
} }
} }
@ -500,21 +500,21 @@ MutableOperandRangeRange::MutableOperandRangeRange(
const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
: MutableOperandRangeRange( : MutableOperandRangeRange(
OwnerT(operands, operandSegmentAttr), 0, OwnerT(operands, operandSegmentAttr), 0,
operandSegmentAttr.second.cast<DenseElementsAttr>().size()) {} operandSegmentAttr.getValue().cast<DenseElementsAttr>().size()) {}
MutableOperandRange MutableOperandRangeRange::join() const { MutableOperandRange MutableOperandRangeRange::join() const {
return getBase().first; return getBase().first;
} }
MutableOperandRangeRange::operator OperandRangeRange() const { MutableOperandRangeRange::operator OperandRangeRange() const {
return OperandRangeRange(getBase().first, return OperandRangeRange(
getBase().second.second.cast<DenseElementsAttr>()); getBase().first, getBase().second.getValue().cast<DenseElementsAttr>());
} }
MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
ptrdiff_t index) { ptrdiff_t index) {
auto sizeData = auto sizeData =
object.second.second.cast<DenseElementsAttr>().getValues<uint32_t>(); object.second.getValue().cast<DenseElementsAttr>().getValues<uint32_t>();
uint32_t startIndex = uint32_t startIndex =
std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
return object.first.slice( return object.first.slice(

View File

@ -170,7 +170,7 @@ LogicalResult OperationVerifier::verifyOperation(
/// Verify that all of the attributes are okay. /// Verify that all of the attributes are okay.
for (auto attr : op.getAttrs()) { for (auto attr : op.getAttrs()) {
// Check for any optional dialect specific attributes. // Check for any optional dialect specific attributes.
if (auto *dialect = attr.first.getReferencedDialect()) if (auto *dialect = attr.getNameDialect())
if (failed(dialect->verifyOperationAttribute(&op, attr))) if (failed(dialect->verifyOperationAttribute(&op, attr)))
return failure(); return failure();
} }

View File

@ -1123,7 +1123,7 @@ public:
Optional<NamedAttribute> duplicate = opState.attributes.findDuplicate(); Optional<NamedAttribute> duplicate = opState.attributes.findDuplicate();
if (duplicate) if (duplicate)
return emitError(getNameLoc(), "attribute '") return emitError(getNameLoc(), "attribute '")
<< duplicate->first.getValue() << duplicate->getName().getValue()
<< "' occurs more than once in the attribute list"; << "' occurs more than once in the attribute list";
return success(); return success();
} }

View File

@ -812,7 +812,7 @@ CppEmitter::emitOperandsAndAttributes(Operation &op,
// Insert comma in between operands and non-filtered attributes if needed. // Insert comma in between operands and non-filtered attributes if needed.
if (op.getNumOperands() > 0) { if (op.getNumOperands() > 0) {
for (NamedAttribute attr : op.getAttrs()) { for (NamedAttribute attr : op.getAttrs()) {
if (!llvm::is_contained(exclude, attr.first.strref())) { if (!llvm::is_contained(exclude, attr.getName().strref())) {
os << ", "; os << ", ";
break; break;
} }
@ -820,10 +820,10 @@ CppEmitter::emitOperandsAndAttributes(Operation &op,
} }
// Emit attributes. // Emit attributes.
auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
if (llvm::is_contained(exclude, attr.first.strref())) if (llvm::is_contained(exclude, attr.getName().strref()))
return success(); return success();
os << "/* " << attr.first.getValue() << " */"; os << "/* " << attr.getName().getValue() << " */";
if (failed(emitAttribute(op.getLoc(), attr.second))) if (failed(emitAttribute(op.getLoc(), attr.getValue())))
return failure(); return failure();
return success(); return success();
}; };

View File

@ -224,9 +224,9 @@ static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst,
SmallVector<llvm::Metadata *> parallelAccess; SmallVector<llvm::Metadata *> parallelAccess;
parallelAccess.push_back( parallelAccess.push_back(
llvm::MDString::get(ctx, "llvm.loop.parallel_accesses")); llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
for (SymbolRefAttr accessGroupRef : for (SymbolRefAttr accessGroupRef : parallelAccessGroup->getValue()
parallelAccessGroup->second.cast<ArrayAttr>() .cast<ArrayAttr>()
.getAsRange<SymbolRefAttr>()) .getAsRange<SymbolRefAttr>())
parallelAccess.push_back( parallelAccess.push_back(
moduleTranslation.getAccessGroup(opInst, accessGroupRef)); moduleTranslation.getAccessGroup(opInst, accessGroupRef));
loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess)); loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess));

View File

@ -57,7 +57,7 @@ public:
LogicalResult LogicalResult
amendOperation(Operation *op, NamedAttribute attribute, amendOperation(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final { LLVM::ModuleTranslation &moduleTranslation) const final {
if (attribute.first == NVVM::NVVMDialect::getKernelFuncAttrName()) { if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op); auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func) if (!func)
return failure(); return failure();

View File

@ -64,7 +64,7 @@ public:
LogicalResult LogicalResult
amendOperation(Operation *op, NamedAttribute attribute, amendOperation(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final { LLVM::ModuleTranslation &moduleTranslation) const final {
if (attribute.first == ROCDL::ROCDLDialect::getKernelFuncAttrName()) { if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op); auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func) if (!func)
return failure(); return failure();

View File

@ -521,7 +521,7 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
defaultValue); defaultValue);
if (decorations.count(resultID)) { if (decorations.count(resultID)) {
for (auto attr : decorations[resultID].getAttrs()) for (auto attr : decorations[resultID].getAttrs())
op->setAttr(attr.first, attr.second); op->setAttr(attr.getName(), attr.getValue());
} }
specConstMap[resultID] = op; specConstMap[resultID] = op;
return op; return op;
@ -591,9 +591,8 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
// Decorations. // Decorations.
if (decorations.count(variableID)) { if (decorations.count(variableID)) {
for (auto attr : decorations[variableID].getAttrs()) { for (auto attr : decorations[variableID].getAttrs())
varOp->setAttr(attr.first, attr.second); varOp->setAttr(attr.getName(), attr.getValue());
}
} }
globalVariableMap[variableID] = varOp; globalVariableMap[variableID] = varOp;
return success(); return success();

View File

@ -295,8 +295,9 @@ LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
(void)encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, (void)encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable,
operands); operands);
for (auto attr : op->getAttrs()) { for (auto attr : op->getAttrs()) {
if (llvm::any_of(elidedAttrs, if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
[&](StringRef elided) { return attr.first == elided; })) { return attr.getName() == elided;
})) {
continue; continue;
} }
if (failed(processDecoration(op.getLoc(), resultID, attr))) { if (failed(processDecoration(op.getLoc(), resultID, attr))) {
@ -364,8 +365,9 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
// Encode decorations. // Encode decorations.
for (auto attr : varOp->getAttrs()) { for (auto attr : varOp->getAttrs()) {
if (llvm::any_of(elidedAttrs, if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
[&](StringRef elided) { return attr.first == elided; })) { return attr.getName() == elided;
})) {
continue; continue;
} }
if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {

View File

@ -205,7 +205,7 @@ void Serializer::processMemoryModel() {
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr) { NamedAttribute attr) {
auto attrName = attr.first.strref(); auto attrName = attr.getName().strref();
auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
auto decoration = spirv::symbolizeDecoration(decorationName); auto decoration = spirv::symbolizeDecoration(decorationName);
if (!decoration) { if (!decoration) {
@ -219,13 +219,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::Binding: case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet: case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location: case spirv::Decoration::Location:
if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) { if (auto intAttr = attr.getValue().dyn_cast<IntegerAttr>()) {
args.push_back(intAttr.getValue().getZExtValue()); args.push_back(intAttr.getValue().getZExtValue());
break; break;
} }
return emitError(loc, "expected integer attribute for ") << attrName; return emitError(loc, "expected integer attribute for ") << attrName;
case spirv::Decoration::BuiltIn: case spirv::Decoration::BuiltIn:
if (auto strAttr = attr.second.dyn_cast<StringAttr>()) { if (auto strAttr = attr.getValue().dyn_cast<StringAttr>()) {
auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
if (enumVal) { if (enumVal) {
args.push_back(static_cast<uint32_t>(enumVal.getValue())); args.push_back(static_cast<uint32_t>(enumVal.getValue()));
@ -243,7 +243,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::Restrict: case spirv::Decoration::Restrict:
case spirv::Decoration::RelaxedPrecision: case spirv::Decoration::RelaxedPrecision:
// For unit attributes, the args list has no values so we do nothing // For unit attributes, the args list has no values so we do nothing
if (auto unitAttr = attr.second.dyn_cast<UnitAttr>()) if (auto unitAttr = attr.getValue().dyn_cast<UnitAttr>())
break; break;
return emitError(loc, "expected unit attribute for ") << attrName; return emitError(loc, "expected unit attribute for ") << attrName;
default: default:

View File

@ -90,7 +90,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
// Perform index rewrites for the dereferencing op and then replace the op // Perform index rewrites for the dereferencing op and then replace the op
NamedAttribute oldMapAttrPair = NamedAttribute oldMapAttrPair =
affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef); affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue(); AffineMap oldMap = oldMapAttrPair.getValue().cast<AffineMapAttr>().getValue();
unsigned oldMapNumInputs = oldMap.getNumInputs(); unsigned oldMapNumInputs = oldMap.getNumInputs();
SmallVector<Value, 4> oldMapOperands( SmallVector<Value, 4> oldMapOperands(
op->operand_begin() + memRefOperandPos + 1, op->operand_begin() + memRefOperandPos + 1,
@ -194,8 +194,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
// Add attribute for 'newMap', other Attributes do not change. // Add attribute for 'newMap', other Attributes do not change.
auto newMapAttr = AffineMapAttr::get(newMap); auto newMapAttr = AffineMapAttr::get(newMap);
for (auto namedAttr : op->getAttrs()) { for (auto namedAttr : op->getAttrs()) {
if (namedAttr.first == oldMapAttrPair.first) if (namedAttr.getName() == oldMapAttrPair.getName())
state.attributes.push_back({namedAttr.first, newMapAttr}); state.attributes.push_back({namedAttr.getName(), newMapAttr});
else else
state.attributes.push_back(namedAttr); state.attributes.push_back(namedAttr);
} }

View File

@ -221,8 +221,8 @@ private:
if (printAttrs) { if (printAttrs) {
os << "\n"; os << "\n";
for (const NamedAttribute &attr : op->getAttrs()) { for (const NamedAttribute &attr : op->getAttrs()) {
os << '\n' << attr.first.getValue() << ": "; os << '\n' << attr.getName().getValue() << ": ";
emitMlirAttr(os, attr.second); emitMlirAttr(os, attr.getValue());
} }
} }
}); });

View File

@ -288,7 +288,7 @@ void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
LogicalResult TestDialect::verifyOperationAttribute(Operation *op, LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
NamedAttribute namedAttr) { NamedAttribute namedAttr) {
if (namedAttr.first == "test.invalid_attr") if (namedAttr.getName() == "test.invalid_attr")
return op->emitError() << "invalid to use 'test.invalid_attr'"; return op->emitError() << "invalid to use 'test.invalid_attr'";
return success(); return success();
} }
@ -297,7 +297,7 @@ LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
unsigned regionIndex, unsigned regionIndex,
unsigned argIndex, unsigned argIndex,
NamedAttribute namedAttr) { NamedAttribute namedAttr) {
if (namedAttr.first == "test.invalid_attr") if (namedAttr.getName() == "test.invalid_attr")
return op->emitError() << "invalid to use 'test.invalid_attr'"; return op->emitError() << "invalid to use 'test.invalid_attr'";
return success(); return success();
} }
@ -306,7 +306,7 @@ LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
unsigned resultIndex, unsigned resultIndex,
NamedAttribute namedAttr) { NamedAttribute namedAttr) {
if (namedAttr.first == "test.invalid_attr") if (namedAttr.getName() == "test.invalid_attr")
return op->emitError() << "invalid to use 'test.invalid_attr'"; return op->emitError() << "invalid to use 'test.invalid_attr'";
return success(); return success();
} }
@ -942,7 +942,7 @@ static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
// If the attribute dictionary contains no 'names' attribute, infer it from // If the attribute dictionary contains no 'names' attribute, infer it from
// the SSA name (if specified). // the SSA name (if specified).
bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
return attr.first == "names"; return attr.getName() == "names";
}); });
// If there was no name specified, check to see if there was a useful name // If there was no name specified, check to see if there was a useful name

View File

@ -243,7 +243,7 @@ void TestDerivedAttributeDriver::runOnFunction() {
if (!dAttr) if (!dAttr)
return; return;
for (auto d : dAttr) for (auto d : dAttr)
dOp.emitRemark() << d.first.getValue() << " = " << d.second; dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
}); });
} }

View File

@ -23,7 +23,7 @@ struct TestElementsAttrInterface
void runOnOperation() override { void runOnOperation() override {
getOperation().walk([&](Operation *op) { getOperation().walk([&](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) { for (NamedAttribute attr : op->getAttrs()) {
auto elementsAttr = attr.second.dyn_cast<ElementsAttr>(); auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
if (!elementsAttr) if (!elementsAttr)
continue; continue;
testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t"); testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");

View File

@ -37,8 +37,8 @@ struct TestPrintNestingPass
if (!op->getAttrs().empty()) { if (!op->getAttrs().empty()) {
printIndent() << op->getAttrs().size() << " attributes:\n"; printIndent() << op->getAttrs().size() << " attributes:\n";
for (NamedAttribute attr : op->getAttrs()) for (NamedAttribute attr : op->getAttrs())
printIndent() << " - '" << attr.first.getValue() << "' : '" printIndent() << " - '" << attr.getName().getValue() << "' : '"
<< attr.second << "'\n"; << attr.getValue() << "'\n";
} }
// Recurse into each of the regions attached to the operation. // Recurse into each of the regions attached to the operation.

View File

@ -51,7 +51,7 @@ def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
// CHECK-LABEL: OpD definitions // CHECK-LABEL: OpD definitions
// CHECK: void OpD::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) // CHECK: void OpD::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
// CHECK: odsState.addTypes({attr.second.cast<::mlir::TypeAttr>().getValue()}); // CHECK: odsState.addTypes({attr.getValue().cast<::mlir::TypeAttr>().getValue()});
def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, F32Attr:$attr); let arguments = (ins I32:$x, F32Attr:$attr);
@ -60,7 +60,7 @@ def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
// CHECK-LABEL: OpE definitions // CHECK-LABEL: OpE definitions
// CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) // CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
// CHECK: odsState.addTypes({attr.second.getType()}); // CHECK: odsState.addTypes({attr.getValue().getType()});
def OpF : NS_Op<"one_variadic_result_op", []> { def OpF : NS_Op<"one_variadic_result_op", []> {
let results = (outs Variadic<I32>:$x); let results = (outs Variadic<I32>:$x);

View File

@ -1449,11 +1449,11 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
<< "AttrName(" << builderOpState << "AttrName(" << builderOpState
<< ".name);\n" << ".name);\n"
" for (auto attr : attributes) {\n" " for (auto attr : attributes) {\n"
" if (attr.first != attrName) continue;\n"; " if (attr.getName() != attrName) continue;\n";
if (namedAttr.attr.isTypeAttr()) { if (namedAttr.attr.isTypeAttr()) {
resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()"; resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()";
} else { } else {
resultType = "attr.second.getType()"; resultType = "attr.getValue().getType()";
} }
// Operands // Operands

View File

@ -673,7 +673,8 @@ static void emitDecorationSerialization(const Operator &op, StringRef tabs,
// All non-argument attributes translated into OpDecorate instruction // All non-argument attributes translated into OpDecorate instruction
os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar); os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar);
os << tabs os << tabs
<< formatv(" if (llvm::is_contained({0}, attr.first)) {{", elidedAttrs); << formatv(" if (llvm::is_contained({0}, attr.getName())) {{",
elidedAttrs);
os << tabs << " continue;\n"; os << tabs << " continue;\n";
os << tabs << " }\n"; os << tabs << " }\n";
os << tabs os << tabs

View File

@ -237,11 +237,11 @@ TEST(NamedAttrListTest, TestAppendAssign) {
{ {
auto it = attrs.begin(); auto it = attrs.begin();
EXPECT_EQ(it->first, b.getStringAttr("foo")); EXPECT_EQ(it->getName(), b.getStringAttr("foo"));
EXPECT_EQ(it->second, b.getStringAttr("bar")); EXPECT_EQ(it->getValue(), b.getStringAttr("bar"));
++it; ++it;
EXPECT_EQ(it->first, b.getStringAttr("baz")); EXPECT_EQ(it->getName(), b.getStringAttr("baz"));
EXPECT_EQ(it->second, b.getStringAttr("boo")); EXPECT_EQ(it->getValue(), b.getStringAttr("boo"));
} }
attrs.append("foo", b.getStringAttr("zoo")); attrs.append("foo", b.getStringAttr("zoo"));
@ -261,11 +261,11 @@ TEST(NamedAttrListTest, TestAppendAssign) {
{ {
auto it = attrs.begin(); auto it = attrs.begin();
EXPECT_EQ(it->first, b.getStringAttr("foo")); EXPECT_EQ(it->getName(), b.getStringAttr("foo"));
EXPECT_EQ(it->second, b.getStringAttr("f")); EXPECT_EQ(it->getValue(), b.getStringAttr("f"));
++it; ++it;
EXPECT_EQ(it->first, b.getStringAttr("zoo")); EXPECT_EQ(it->getName(), b.getStringAttr("zoo"));
EXPECT_EQ(it->second, b.getStringAttr("z")); EXPECT_EQ(it->getValue(), b.getStringAttr("z"));
} }
attrs.assign({}); attrs.assign({});

View File

@ -62,7 +62,8 @@ protected:
EXPECT_EQ(op->getAttrs().size(), attrs.size()); EXPECT_EQ(op->getAttrs().size(), attrs.size());
for (unsigned idx : llvm::seq<unsigned>(0U, attrs.size())) for (unsigned idx : llvm::seq<unsigned>(0U, attrs.size()))
EXPECT_EQ(op->getAttr(attrs[idx].first.strref()), attrs[idx].second); EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
attrs[idx].getValue());
concreteOp.erase(); concreteOp.erase();
} }

View File

@ -62,7 +62,7 @@ TEST(StructsGenTest, ClassofExtraFalse) {
// Add an extra NamedAttribute. // Add an extra NamedAttribute.
auto wrongId = mlir::StringAttr::get(&context, "wrong"); auto wrongId = mlir::StringAttr::get(&context, "wrong");
auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].getValue());
newValues.push_back(wrongAttr); newValues.push_back(wrongAttr);
// Make a new DictionaryAttr and validate. // Make a new DictionaryAttr and validate.
@ -84,7 +84,7 @@ TEST(StructsGenTest, ClassofBadNameFalse) {
// Add a copy of the first attribute with the wrong name. // Add a copy of the first attribute with the wrong name.
auto wrongId = mlir::StringAttr::get(&context, "wrong"); auto wrongId = mlir::StringAttr::get(&context, "wrong");
auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].getValue());
newValues.push_back(wrongAttr); newValues.push_back(wrongAttr);
auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
@ -108,7 +108,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
auto elementsType = mlir::RankedTensorType::get({3}, i64Type); auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
auto elementsAttr = auto elementsAttr =
mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3}); mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});
mlir::StringAttr id = expectedValues.back().first; mlir::StringAttr id = expectedValues.back().getName();
auto wrongAttr = mlir::NamedAttribute(id, elementsAttr); auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
newValues.push_back(wrongAttr); newValues.push_back(wrongAttr);