diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index d09e8150f4c5..eb39537c03d4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -173,7 +173,7 @@ public: llvm::Module &getLLVMModule(); /// Parse a type registered to this dialect. - Type parseType(DialectAsmParser &parser, Location loc) const override; + Type parseType(DialectAsmParser &parser) const override; /// Print a type registered to this dialect. void printType(Type type, DialectAsmPrinter &os) const override; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h index 888895335c78..81c309704b8d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -37,7 +37,7 @@ public: static StringRef getDialectNamespace() { return "linalg"; } /// Parse a type registered to this dialect. - Type parseType(DialectAsmParser &parser, Location loc) const override; + Type parseType(DialectAsmParser &parser) const override; /// Print a type registered to this dialect. void printType(Type type, DialectAsmPrinter &os) const override; diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h index f1ac3836321f..020d34918d48 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h @@ -35,7 +35,7 @@ public: QuantizationDialect(MLIRContext *context); /// Parse a type registered to this dialect. - Type parseType(DialectAsmParser &parser, Location loc) const override; + Type parseType(DialectAsmParser &parser) const override; /// Print a type registered to this dialect. 
void printType(Type type, DialectAsmPrinter &os) const override; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h index 6401eba0c957..2571e5d89284 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h @@ -46,7 +46,7 @@ public: static std::string getAttributeName(Decoration decoration); /// Parses a type registered to this dialect. - Type parseType(DialectAsmParser &parser, Location loc) const override; + Type parseType(DialectAsmParser &parser) const override; /// Prints a type registered to this dialect. void printType(Type type, DialectAsmPrinter &os) const override; diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index d53327371c2c..1d284f6ccd77 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -391,7 +391,7 @@ private: friend DiagnosticEngine; /// The engine that this diagnostic is to report to. - DiagnosticEngine *owner; + DiagnosticEngine *owner = nullptr; /// The raw diagnostic that is inflight to be reported. llvm::Optional impl; diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index bd84bee73a0b..4880fd0ca186 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -117,8 +117,7 @@ public: /// Parse an attribute registered to this dialect. If 'type' is nonnull, it /// refers to the expected type of the attribute. - virtual Attribute parseAttribute(DialectAsmParser &parser, Type type, - Location loc) const; + virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const; /// Print an attribute registered to this dialect. Note: The type of the /// attribute need not be printed by this method as it is always printed by @@ -128,7 +127,7 @@ public: } /// Parse a type registered to this dialect. 
- virtual Type parseType(DialectAsmParser &parser, Location loc) const; + virtual Type parseType(DialectAsmParser &parser) const; /// Print a type registered to this dialect. virtual void printType(Type, DialectAsmPrinter &) const { diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h index c662a4cc5c29..c538b8162dcd 100644 --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -129,6 +129,9 @@ public: /// Return the location of the original name token. virtual llvm::SMLoc getNameLoc() const = 0; + /// Re-encode the given source location as an MLIR location and return it. + virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0; + /// Returns the full specification of the symbol being parsed. This allows for /// using a separate parser if necessary. virtual StringRef getFullSymbolSpec() const = 0; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 39decf961df0..cc2de877436c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1250,7 +1250,7 @@ llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; } llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } /// Parse a type registered to this dialect. -Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const { +Type LLVMDialect::parseType(DialectAsmParser &parser) const { StringRef tyData = parser.getFullSymbolSpec(); // LLVM is not thread-safe, so lock access to it. 
@@ -1259,7 +1259,8 @@ Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const { llvm::SMDiagnostic errorMessage; llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module); if (!type) - return (emitError(loc, errorMessage.getMessage()), nullptr); + return (parser.emitError(parser.getNameLoc(), errorMessage.getMessage()), + nullptr); return LLVMType::get(getContext(), type); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp index 4a7bcd87d584..41e76a87b726 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -108,8 +108,8 @@ Optional mlir::linalg::BufferType::getBufferSize() { return getImpl()->getBufferSize(); } -Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser, - Location loc) const { +Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const { + Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); StringRef spec = parser.getFullSymbolSpec(); StringRef origSpec = spec; MLIRContext *context = getContext(); diff --git a/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp b/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp index 360c1b58f88c..26212f69c3ce 100644 --- a/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp @@ -616,8 +616,8 @@ bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) { } /// Parse a type registered to this dialect. 
-Type QuantizationDialect::parseType(DialectAsmParser &parser, - Location loc) const { +Type QuantizationDialect::parseType(DialectAsmParser &parser) const { + Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); TypeParser typeParser(parser.getFullSymbolSpec(), getContext(), loc); Type parsedType = typeParser.parseType(); if (parsedType == nullptr) { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index 26d1ff18d01a..abe47240b2fc 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -610,8 +610,9 @@ static Type parseStructType(SPIRVDialect const &dialect, StringRef spec, // | pointer-type // | runtime-array-type // | struct-type -Type SPIRVDialect::parseType(DialectAsmParser &parser, Location loc) const { +Type SPIRVDialect::parseType(DialectAsmParser &parser) const { StringRef spec = parser.getFullSymbolSpec(); + Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); if (spec.startswith("array")) return parseArrayType(*this, spec, loc); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 7882e4f1f19d..c6266b096687 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -102,23 +102,23 @@ LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, } /// Parse an attribute registered to this dialect. -Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type, - Location loc) const { - emitError(loc) << "dialect '" << getNamespace() - << "' provides no attribute parsing hook"; +Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { + parser.emitError(parser.getNameLoc()) + << "dialect '" << getNamespace() + << "' provides no attribute parsing hook"; return Attribute(); } /// Parse a type registered to this dialect. 
-Type Dialect::parseType(DialectAsmParser &parser, Location loc) const { +Type Dialect::parseType(DialectAsmParser &parser) const { // If this dialect allows unknown types, then represent this with OpaqueType. if (allowsUnknownTypes()) { auto ns = Identifier::get(getNamespace(), getContext()); return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext()); } - emitError(loc) << "dialect '" << getNamespace() - << "' provides no type parsing hook"; + parser.emitError(parser.getNameLoc()) + << "dialect '" << getNamespace() << "' provides no type parsing hook"; return Type(); } diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h index 896c26cc927b..b18077114578 100644 --- a/mlir/lib/Parser/Lexer.h +++ b/mlir/lib/Parser/Lexer.h @@ -45,6 +45,9 @@ public: /// at the designated point in the input. void resetPointer(const char *newPointer) { curPtr = newPointer; } + /// Returns the start of the buffer. + const char *getBufferBegin() { return curBuffer.data(); } + private: // Helpers. Token formToken(Token::Kind kind, const char *tokStart) { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index a6e02279adb9..368f262ade7f 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -52,16 +52,27 @@ namespace { class Parser; //===----------------------------------------------------------------------===// -// AliasState +// SymbolState //===----------------------------------------------------------------------===// -/// This class contains record of any parsed top-level aliases. -struct AliasState { +/// This class contains record of any parsed top-level symbols. +struct SymbolState { // A map from attribute alias identifier to Attribute. llvm::StringMap attributeAliasDefinitions; // A map from type alias identifier to Type. llvm::StringMap typeAliasDefinitions; + + /// A set of locations into the main parser memory buffer for each of the + /// active nested parsers. Given that some nested parsers, i.e. 
custom dialect + /// parsers, operate on a temporary memory buffer, this provides an anchor + /// point for emitting diagnostics. + SmallVector nestedParserLocs; + + /// The top-level lexer that contains the original memory buffer provided by + /// the user. This is used by nested parsers to get a properly encoded source + /// location. + Lexer *topLevelLexer = nullptr; }; //===----------------------------------------------------------------------===// @@ -72,9 +83,18 @@ struct AliasState { /// such as the current lexer position etc. struct ParserState { ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, - AliasState &aliases) + SymbolState &symbols) : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), - aliases(aliases) {} + symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { + // Set the top level lexer for the symbol state if one doesn't exist. + if (!symbols.topLevelLexer) + symbols.topLevelLexer = &lex; + } + ~ParserState() { + // Reset the top level lexer if it refers to the lexer in our state. + if (symbols.topLevelLexer == &lex) + symbols.topLevelLexer = nullptr; + } ParserState(const ParserState &) = delete; void operator=(const ParserState &) = delete; @@ -87,8 +107,11 @@ struct ParserState { /// This is the next token that hasn't been consumed yet. Token curToken; - /// Any parsed alias state. - AliasState &aliases; + /// The current state for symbol parsing. + SymbolState &symbols; + + /// The depth of this parser in the nested parsing stack. + size_t parserDepth; }; //===----------------------------------------------------------------------===// @@ -140,7 +163,32 @@ public: /// Encode the specified source location information into an attribute for /// attachment to the IR. Location getEncodedSourceLocation(llvm::SMLoc loc) { - return state.lex.getEncodedSourceLocation(loc); + // If there are no active nested parsers, we can get the encoded source + // location directly.
+ if (state.parserDepth == 0) + return state.lex.getEncodedSourceLocation(loc); + // Otherwise, we need to re-encode it to point to the top level buffer. + return state.symbols.topLevelLexer->getEncodedSourceLocation( + remapLocationToTopLevelBuffer(loc)); + } + + /// Remaps the given SMLoc to the top level lexer of the parser. This is used + /// to adjust locations of potentially nested parsers to ensure that they can + /// be emitted properly as diagnostics. + llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) { + // If there are no active nested parsers, we can return location directly. + SymbolState &symbols = state.symbols; + if (state.parserDepth == 0) + return loc; + assert(symbols.topLevelLexer && "expected valid top-level lexer"); + + // Otherwise, we need to remap the location to the main parser. This is + // simply offsetting the location onto the location of the last nested + // parser. + size_t offset = loc.getPointer() - state.lex.getBufferBegin(); + auto *rawLoc = + symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset; + return llvm::SMLoc::getFromPointer(rawLoc); } //===--------------------------------------------------------------------===// @@ -388,6 +436,11 @@ public: /// Return the location of the original name token. llvm::SMLoc getNameLoc() const override { return nameLoc; } + /// Re-encode the given source location as an MLIR location and return it. + Location getEncodedSourceLoc(llvm::SMLoc loc) override { + return parser.getEncodedSourceLocation(loc); + } + /// Returns the full specification of the symbol being parsed. This allows /// for using a separate parser if necessary.
StringRef getFullSymbolSpec() const override { return fullSpec; } @@ -517,7 +570,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, return (p.emitError("expected string literal data in dialect symbol"), nullptr); symbolData = p.getToken().getStringValue(); - loc = p.getToken().getLoc(); + loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); p.consumeToken(Token::string); // Consume the '>'. @@ -529,6 +582,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, auto dotHalves = identifier.split('.'); dialectName = dotHalves.first; auto prettyName = dotHalves.second; + loc = llvm::SMLoc::getFromPointer(prettyName.data()); // If the dialect's symbol is followed immediately by a <, then lex the body // of it into prettyName. @@ -541,8 +595,16 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, symbolData = prettyName.str(); } + // Record the name location of the type remapped to the top level buffer. + llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); + p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); + // Call into the provided symbol construction function. - return createSymbol(dialectName, symbolData, loc); + Symbol sym = createSymbol(dialectName, symbolData, loc); + + // Pop the last parser location. + p.getState().symbols.nestedParserLocs.pop_back(); + return sym; } /// Parses a symbol, of type 'T', and returns it if parsing was successful. If @@ -550,14 +612,14 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, /// string is returned in 'numRead'. 
template static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, - AliasState &aliasState, ParserFn &&parserFn, + SymbolState &symbolState, ParserFn &&parserFn, size_t *numRead = nullptr) { SourceMgr sourceMgr; auto memBuffer = MemoryBuffer::getMemBuffer( inputStr, /*BufferName=*/"", /*RequiresNullTerminator=*/false); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); - ParserState state(sourceMgr, context, aliasState); + ParserState state(sourceMgr, context, symbolState); Parser parser(state); Token startTok = parser.getToken(); @@ -573,8 +635,7 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, // Otherwise, ensure that all of the tokens were parsed. } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { - parser.emitError(endTok.getLoc(), - "encountered unexpected tokens after parsing"); + parser.emitError(endTok.getLoc(), "encountered unexpected token"); return T(); } return symbol; @@ -585,13 +646,12 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, //===----------------------------------------------------------------------===// InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { - auto diag = mlir::emitError(getEncodedSourceLocation(loc), message); - // If we hit a parse error in response to a lexer error, then the lexer // already reported the error. 
if (getToken().is(Token::error)) - diag.abandon(); - return diag; + return InFlightDiagnostic(); + + return mlir::emitError(getEncodedSourceLocation(loc), message); } //===----------------------------------------------------------------------===// @@ -701,24 +761,22 @@ Type Parser::parseComplexType() { /// Type Parser::parseExtendedType() { return parseExtendedSymbol( - *this, Token::exclamation_identifier, state.aliases.typeAliasDefinitions, + *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, llvm::SMLoc loc) -> Type { - Location encodedLoc = getEncodedSourceLocation(loc); - // If we found a registered dialect, then ask it to parse the type. if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { return parseSymbol( - symbolData, state.context, state.aliases, [&](Parser &parser) { + symbolData, state.context, state.symbols, [&](Parser &parser) { CustomDialectAsmParser customParser(symbolData, parser); - return dialect->parseType(customParser, encodedLoc); + return dialect->parseType(customParser); }); } // Otherwise, form a new opaque type. return OpaqueType::getChecked( Identifier::get(dialectName, state.context), symbolData, - state.context, encodedLoc); + state.context, getEncodedSourceLocation(loc)); }); } @@ -1315,7 +1373,7 @@ Parser::parseAttributeDict(SmallVectorImpl &attributes) { /// Attribute Parser::parseExtendedAttr(Type type) { Attribute attr = parseExtendedSymbol( - *this, Token::hash_identifier, state.aliases.attributeAliasDefinitions, + *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, llvm::SMLoc loc) -> Attribute { // Parse an optional trailing colon type. @@ -1326,10 +1384,9 @@ Attribute Parser::parseExtendedAttr(Type type) { // If we found a registered dialect, then ask it to parse the attribute. 
if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { return parseSymbol( - symbolData, state.context, state.aliases, [&](Parser &parser) { + symbolData, state.context, state.symbols, [&](Parser &parser) { CustomDialectAsmParser customParser(symbolData, parser); - return dialect->parseAttribute(customParser, attrType, - getEncodedSourceLocation(loc)); + return dialect->parseAttribute(customParser, attrType); }); } @@ -4242,7 +4299,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() { StringRef aliasName = getTokenSpelling().drop_front(); // Check for redefinitions. - if (getState().aliases.attributeAliasDefinitions.count(aliasName) > 0) + if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) return emitError("redefinition of attribute alias id '" + aliasName + "'"); // Make sure this isn't invading the dialect attribute namespace. @@ -4261,7 +4318,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() { if (!attr) return failure(); - getState().aliases.attributeAliasDefinitions[aliasName] = attr; + getState().symbols.attributeAliasDefinitions[aliasName] = attr; return success(); } @@ -4274,7 +4331,7 @@ ParseResult ModuleParser::parseTypeAliasDef() { StringRef aliasName = getTokenSpelling().drop_front(); // Check for redefinitions. - if (getState().aliases.typeAliasDefinitions.count(aliasName) > 0) + if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) return emitError("redefinition of type alias id '" + aliasName + "'"); // Make sure this isn't invading the dialect type namespace. @@ -4295,7 +4352,7 @@ ParseResult ModuleParser::parseTypeAliasDef() { return failure(); // Register this alias with the parser state. 
- getState().aliases.typeAliasDefinitions.try_emplace(aliasName, aliasedType); + getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); return success(); } @@ -4374,7 +4431,7 @@ OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); - AliasState aliasState; + SymbolState aliasState; ParserState state(sourceMgr, context, aliasState); if (ModuleParser(state).parseModule(*module)) return nullptr; @@ -4440,7 +4497,7 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr, template static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, size_t &numRead, ParserFn &&parserFn) { - AliasState aliasState; + SymbolState aliasState; return parseSymbol( inputStr, context, aliasState, [&](Parser &parser) {