Remove the need for passing a location to parseAttribute/parseType.

Now that a proper parser is passed to these methods, there is no need to explicitly pass a source location; it can be recovered from the parser when necessary. This avoids decoding an SMLoc, which can be expensive, in the cases where the location is never used.

This requires adding basic nesting support to the parser so that source locations produced by nested parsers can be remapped to the top-level parser for accurate diagnostics. The remapping is needed because the attribute and type parsers use different source buffers than the top-level parser, as the symbols they parse may be represented in string form.

PiperOrigin-RevId: 278014858
This commit is contained in:
River Riddle 2019-11-01 15:39:30 -07:00 committed by A. Unique TensorFlower
parent 445cc3f6dd
commit 2ba4d802e0
14 changed files with 120 additions and 56 deletions

View File

@ -173,7 +173,7 @@ public:
llvm::Module &getLLVMModule();
/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser, Location loc) const override;
Type parseType(DialectAsmParser &parser) const override;
/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;

View File

@ -37,7 +37,7 @@ public:
static StringRef getDialectNamespace() { return "linalg"; }
/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser, Location loc) const override;
Type parseType(DialectAsmParser &parser) const override;
/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;

View File

@ -35,7 +35,7 @@ public:
QuantizationDialect(MLIRContext *context);
/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser, Location loc) const override;
Type parseType(DialectAsmParser &parser) const override;
/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;

View File

@ -46,7 +46,7 @@ public:
static std::string getAttributeName(Decoration decoration);
/// Parses a type registered to this dialect.
Type parseType(DialectAsmParser &parser, Location loc) const override;
Type parseType(DialectAsmParser &parser) const override;
/// Prints a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;

View File

@ -391,7 +391,7 @@ private:
friend DiagnosticEngine;
/// The engine that this diagnostic is to report to.
DiagnosticEngine *owner;
DiagnosticEngine *owner = nullptr;
/// The raw diagnostic that is inflight to be reported.
llvm::Optional<Diagnostic> impl;

View File

@ -117,8 +117,7 @@ public:
/// Parse an attribute registered to this dialect. If 'type' is nonnull, it
/// refers to the expected type of the attribute.
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type,
Location loc) const;
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
/// Print an attribute registered to this dialect. Note: The type of the
/// attribute need not be printed by this method as it is always printed by
@ -128,7 +127,7 @@ public:
}
/// Parse a type registered to this dialect.
virtual Type parseType(DialectAsmParser &parser, Location loc) const;
virtual Type parseType(DialectAsmParser &parser) const;
/// Print a type registered to this dialect.
virtual void printType(Type, DialectAsmPrinter &) const {

View File

@ -129,6 +129,9 @@ public:
/// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0;
/// Re-encode the given source location as an MLIR location and return it.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
/// Returns the full specification of the symbol being parsed. This allows for
/// using a separate parser if necessary.
virtual StringRef getFullSymbolSpec() const = 0;

View File

@ -1250,7 +1250,7 @@ llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
/// Parse a type registered to this dialect.
Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const {
Type LLVMDialect::parseType(DialectAsmParser &parser) const {
StringRef tyData = parser.getFullSymbolSpec();
// LLVM is not thread-safe, so lock access to it.
@ -1259,7 +1259,8 @@ Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const {
llvm::SMDiagnostic errorMessage;
llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
if (!type)
return (emitError(loc, errorMessage.getMessage()), nullptr);
return (parser.emitError(parser.getNameLoc(), errorMessage.getMessage()),
nullptr);
return LLVMType::get(getContext(), type);
}

View File

@ -108,8 +108,8 @@ Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
return getImpl()->getBufferSize();
}
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser,
Location loc) const {
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
StringRef spec = parser.getFullSymbolSpec();
StringRef origSpec = spec;
MLIRContext *context = getContext();

View File

@ -616,8 +616,8 @@ bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) {
}
/// Parse a type registered to this dialect.
Type QuantizationDialect::parseType(DialectAsmParser &parser,
Location loc) const {
Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
TypeParser typeParser(parser.getFullSymbolSpec(), getContext(), loc);
Type parsedType = typeParser.parseType();
if (parsedType == nullptr) {

View File

@ -610,8 +610,9 @@ static Type parseStructType(SPIRVDialect const &dialect, StringRef spec,
// | pointer-type
// | runtime-array-type
// | struct-type
Type SPIRVDialect::parseType(DialectAsmParser &parser, Location loc) const {
Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
StringRef spec = parser.getFullSymbolSpec();
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
if (spec.startswith("array"))
return parseArrayType(*this, spec, loc);

View File

@ -102,23 +102,23 @@ LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
}
/// Parse an attribute registered to this dialect.
Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type,
Location loc) const {
emitError(loc) << "dialect '" << getNamespace()
<< "' provides no attribute parsing hook";
Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
parser.emitError(parser.getNameLoc())
<< "dialect '" << getNamespace()
<< "' provides no attribute parsing hook";
return Attribute();
}
/// Parse a type registered to this dialect.
Type Dialect::parseType(DialectAsmParser &parser, Location loc) const {
Type Dialect::parseType(DialectAsmParser &parser) const {
// If this dialect allows unknown types, then represent this with OpaqueType.
if (allowsUnknownTypes()) {
auto ns = Identifier::get(getNamespace(), getContext());
return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
}
emitError(loc) << "dialect '" << getNamespace()
<< "' provides no type parsing hook";
parser.emitError(parser.getNameLoc())
<< "dialect '" << getNamespace() << "' provides no type parsing hook";
return Type();
}

View File

@ -45,6 +45,9 @@ public:
/// at the designated point in the input.
void resetPointer(const char *newPointer) { curPtr = newPointer; }
/// Returns a pointer to the start of the buffer being lexed, i.e. the value
/// of `curBuffer.data()`. Callers can subtract this from a location pointer
/// inside the same buffer to obtain that location's offset.
const char *getBufferBegin() { return curBuffer.data(); }
private:
// Helpers.
Token formToken(Token::Kind kind, const char *tokStart) {

View File

@ -52,16 +52,27 @@ namespace {
class Parser;
//===----------------------------------------------------------------------===//
// AliasState
// SymbolState
//===----------------------------------------------------------------------===//
/// This class contains record of any parsed top-level aliases.
struct AliasState {
/// This class contains a record of any parsed top-level symbols, along with
/// the bookkeeping that nested symbol parsers need in order to report
/// locations relative to the top-level buffer. A ParserState holds a reference
/// to one of these, so the same instance can be handed to nested parsers.
struct SymbolState {
/// A map from attribute alias identifier to Attribute.
llvm::StringMap<Attribute> attributeAliasDefinitions;
/// A map from type alias identifier to Type.
llvm::StringMap<Type> typeAliasDefinitions;
/// A stack of locations into the main parser memory buffer, one for each of
/// the active nested parsers (the innermost is last). Given that some nested
/// parsers, i.e. custom dialect parsers, operate on a temporary memory
/// buffer, this provides an anchor point for emitting diagnostics.
SmallVector<llvm::SMLoc, 1> nestedParserLocs;
/// The top-level lexer that contains the original memory buffer provided by
/// the user. This is used by nested parsers to get a properly encoded source
/// location. It is null while no top-level parser state is registered.
Lexer *topLevelLexer = nullptr;
};
//===----------------------------------------------------------------------===//
@ -72,9 +83,18 @@ struct AliasState {
/// such as the current lexer position etc.
struct ParserState {
ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
AliasState &aliases)
SymbolState &symbols)
: context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
aliases(aliases) {}
symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
// Set the top level lexer for the symbol state if one doesn't exist.
if (!symbols.topLevelLexer)
symbols.topLevelLexer = &lex;
}
~ParserState() {
// Reset the top level lexer if it refers to the lexer in our state.
if (symbols.topLevelLexer == &lex)
symbols.topLevelLexer = nullptr;
}
ParserState(const ParserState &) = delete;
void operator=(const ParserState &) = delete;
@ -87,8 +107,11 @@ struct ParserState {
/// This is the next token that hasn't been consumed yet.
Token curToken;
/// Any parsed alias state.
AliasState &aliases;
/// The current state for symbol parsing.
SymbolState &symbols;
/// The depth of this parser in the nested parsing stack.
size_t parserDepth;
};
//===----------------------------------------------------------------------===//
@ -140,7 +163,32 @@ public:
/// Encode the specified source location information into an attribute for
/// attachment to the IR.
Location getEncodedSourceLocation(llvm::SMLoc loc) {
return state.lex.getEncodedSourceLocation(loc);
// If there are no active nested parsers, we can get the encoded source
// location directly.
if (state.parserDepth == 0)
return state.lex.getEncodedSourceLocation(loc);
// Otherwise, we need to re-encode it to point to the top level buffer.
return state.symbols.topLevelLexer->getEncodedSourceLocation(
remapLocationToTopLevelBuffer(loc));
}
/// Remaps the given SMLoc to the top level lexer of the parser. This is used
/// to adjust locations of potentially nested parsers to ensure that they can
/// be emitted properly as diagnostics.
///
/// When this parser is not nested (`parserDepth == 0`), `loc` is returned
/// unchanged. Otherwise the result is the anchor location recorded for this
/// parser in `nestedParserLocs`, advanced by the offset of `loc` from the
/// start of this parser's own lexer buffer.
llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) {
// If there are no active nested parsers, we can return location directly.
SymbolState &symbols = state.symbols;
if (state.parserDepth == 0)
return loc;
// A nested parser is only meaningful while a top-level lexer is registered.
assert(symbols.topLevelLexer && "expected valid top-level lexer");
// Otherwise, we need to remap the location to the main parser. This is
// simply offsetting the location onto the location of the last nested
// parser.
// NOTE(review): the subtraction assumes `loc` points into this parser's own
// buffer; a location from any other buffer would yield a bogus offset.
size_t offset = loc.getPointer() - state.lex.getBufferBegin();
// `parserDepth - 1` indexes the anchor that was pushed for this parser.
auto *rawLoc =
symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset;
return llvm::SMLoc::getFromPointer(rawLoc);
}
//===--------------------------------------------------------------------===//
@ -388,6 +436,11 @@ public:
/// Return the location of the original name token.
llvm::SMLoc getNameLoc() const override { return nameLoc; }
/// Re-encode the given source location as an MLIR location and return it.
/// This forwards to the wrapped parser's `getEncodedSourceLocation`, so any
/// remapping that parser applies for nested buffers is applied here as well.
Location getEncodedSourceLoc(llvm::SMLoc loc) override {
return parser.getEncodedSourceLocation(loc);
}
/// Returns the full specification of the symbol being parsed. This allows
/// for using a separate parser if necessary.
StringRef getFullSymbolSpec() const override { return fullSpec; }
@ -517,7 +570,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
return (p.emitError("expected string literal data in dialect symbol"),
nullptr);
symbolData = p.getToken().getStringValue();
loc = p.getToken().getLoc();
loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
p.consumeToken(Token::string);
// Consume the '>'.
@ -529,6 +582,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
auto dotHalves = identifier.split('.');
dialectName = dotHalves.first;
auto prettyName = dotHalves.second;
loc = llvm::SMLoc::getFromPointer(prettyName.data());
// If the dialect's symbol is followed immediately by a <, then lex the body
// of it into prettyName.
@ -541,8 +595,16 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
symbolData = prettyName.str();
}
// Record the name location of the type remapped to the top level buffer.
llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);
// Call into the provided symbol construction function.
return createSymbol(dialectName, symbolData, loc);
Symbol sym = createSymbol(dialectName, symbolData, loc);
// Pop the last parser location.
p.getState().symbols.nestedParserLocs.pop_back();
return sym;
}
/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
@ -550,14 +612,14 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
/// string is returned in 'numRead'.
template <typename T, typename ParserFn>
static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
AliasState &aliasState, ParserFn &&parserFn,
SymbolState &symbolState, ParserFn &&parserFn,
size_t *numRead = nullptr) {
SourceMgr sourceMgr;
auto memBuffer = MemoryBuffer::getMemBuffer(
inputStr, /*BufferName=*/"<mlir_parser_buffer>",
/*RequiresNullTerminator=*/false);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
ParserState state(sourceMgr, context, aliasState);
ParserState state(sourceMgr, context, symbolState);
Parser parser(state);
Token startTok = parser.getToken();
@ -573,8 +635,7 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
// Otherwise, ensure that all of the tokens were parsed.
} else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
parser.emitError(endTok.getLoc(),
"encountered unexpected tokens after parsing");
parser.emitError(endTok.getLoc(), "encountered unexpected token");
return T();
}
return symbol;
@ -585,13 +646,12 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
//===----------------------------------------------------------------------===//
InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);
// If we hit a parse error in response to a lexer error, then the lexer
// already reported the error.
if (getToken().is(Token::error))
diag.abandon();
return diag;
return InFlightDiagnostic();
return mlir::emitError(getEncodedSourceLocation(loc), message);
}
//===----------------------------------------------------------------------===//
@ -701,24 +761,22 @@ Type Parser::parseComplexType() {
///
Type Parser::parseExtendedType() {
return parseExtendedSymbol<Type>(
*this, Token::exclamation_identifier, state.aliases.typeAliasDefinitions,
*this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData,
llvm::SMLoc loc) -> Type {
Location encodedLoc = getEncodedSourceLocation(loc);
// If we found a registered dialect, then ask it to parse the type.
if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
return parseSymbol<Type>(
symbolData, state.context, state.aliases, [&](Parser &parser) {
symbolData, state.context, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
return dialect->parseType(customParser, encodedLoc);
return dialect->parseType(customParser);
});
}
// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
Identifier::get(dialectName, state.context), symbolData,
state.context, encodedLoc);
state.context, getEncodedSourceLocation(loc));
});
}
@ -1315,7 +1373,7 @@ Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
///
Attribute Parser::parseExtendedAttr(Type type) {
Attribute attr = parseExtendedSymbol<Attribute>(
*this, Token::hash_identifier, state.aliases.attributeAliasDefinitions,
*this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData,
llvm::SMLoc loc) -> Attribute {
// Parse an optional trailing colon type.
@ -1326,10 +1384,9 @@ Attribute Parser::parseExtendedAttr(Type type) {
// If we found a registered dialect, then ask it to parse the attribute.
if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
return parseSymbol<Attribute>(
symbolData, state.context, state.aliases, [&](Parser &parser) {
symbolData, state.context, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
return dialect->parseAttribute(customParser, attrType,
getEncodedSourceLocation(loc));
return dialect->parseAttribute(customParser, attrType);
});
}
@ -4242,7 +4299,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() {
StringRef aliasName = getTokenSpelling().drop_front();
// Check for redefinitions.
if (getState().aliases.attributeAliasDefinitions.count(aliasName) > 0)
if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0)
return emitError("redefinition of attribute alias id '" + aliasName + "'");
// Make sure this isn't invading the dialect attribute namespace.
@ -4261,7 +4318,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() {
if (!attr)
return failure();
getState().aliases.attributeAliasDefinitions[aliasName] = attr;
getState().symbols.attributeAliasDefinitions[aliasName] = attr;
return success();
}
@ -4274,7 +4331,7 @@ ParseResult ModuleParser::parseTypeAliasDef() {
StringRef aliasName = getTokenSpelling().drop_front();
// Check for redefinitions.
if (getState().aliases.typeAliasDefinitions.count(aliasName) > 0)
if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0)
return emitError("redefinition of type alias id '" + aliasName + "'");
// Make sure this isn't invading the dialect type namespace.
@ -4295,7 +4352,7 @@ ParseResult ModuleParser::parseTypeAliasDef() {
return failure();
// Register this alias with the parser state.
getState().aliases.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
return success();
}
@ -4374,7 +4431,7 @@ OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
AliasState aliasState;
SymbolState aliasState;
ParserState state(sourceMgr, context, aliasState);
if (ModuleParser(state).parseModule(*module))
return nullptr;
@ -4440,7 +4497,7 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr,
template <typename T, typename ParserFn>
static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
size_t &numRead, ParserFn &&parserFn) {
AliasState aliasState;
SymbolState aliasState;
return parseSymbol<T>(
inputStr, context, aliasState,
[&](Parser &parser) {