[mlir] Add GlobalOp, GlobalLoadConstOp to ml_program.

The approach I took was to define a dialect 'extern' attribute that a GlobalOp can take as a value to signify external linkage. I think this approach should compose well and should also work with wherever the OpaqueElements work goes in the future (since that is just another kind of attribute). I special cased the GlobalOp parser/printer for this case because it is significantly easier on the eyes.

In the discussion, Jeff Niu had proposed an alternative syntax for GlobalOp that I ended up not taking. I did try to implement it but a) I don't think it made anything easier to read in the common case, and b) it made the parsing/printing logic a lot more complicated (I think I would need a completely custom parser/printer to do it well). Please have a look at the common cases where the global type and initial value type match: I don't think how I have it is too bad. The less common cases seem ok to me.

I chose to only implement the direct, constant load op since that is non side effecting and there was still discussion pending on that.

Differential Revision: https://reviews.llvm.org/D124318
This commit is contained in:
Stella Laurenzo 2022-04-22 19:59:34 -07:00
parent b21c03854c
commit 2bb252852c
13 changed files with 360 additions and 0 deletions

View File

@ -1,3 +1,10 @@
set(LLVM_TARGET_DEFINITIONS MLProgramOps.td)
add_mlir_dialect(MLProgramOps ml_program)
add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS MLProgramAttributes.td)
mlir_tablegen(MLProgramAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(MLProgramAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRMLProgramAttributesIncGen)
add_dependencies(mlir-headers MLIRMLProgramAttributesIncGen)
add_mlir_doc(MLProgramAttributes MLProgramAttributes Dialects/ -gen-attrdef-doc)

View File

@ -8,6 +8,7 @@
#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpDefinition.h"

View File

@ -0,0 +1,21 @@
//===- MLProgramAttributes.h - Attribute Classes ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_
#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_
#include "mlir/IR/Attributes.h"
//===----------------------------------------------------------------------===//
// Tablegen Attribute Declarations
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h.inc"
#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_

View File

@ -0,0 +1,44 @@
//===- MLProgramAttributed.td - Attr definitions -----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLPROGRAM_ATTRIBUTES
#define MLPROGRAM_ATTRIBUTES
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
// Base class for MLProgram dialect attributes.
class MLProgram_Attr<string name, list<Trait> traits = []>
: AttrDef<MLProgram_Dialect, name, traits> {
let mnemonic = ?;
}
//===----------------------------------------------------------------------===//
// ExternAttr
//===----------------------------------------------------------------------===//
def MLProgram_ExternAttr : MLProgram_Attr<"Extern"> {
let summary = "Value used for a global signalling external resolution";
let description = [{
When used as the value for a GlobalOp, this indicates that the actual
value should be resolved externally in an implementation defined manner.
The `sym_name` of the global is the key for locating the value.
Examples:
```mlir
extern : tensor<4xi32>
```
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type);
let mnemonic = "extern";
let assemblyFormat = "";
}
#endif // MLPROGRAM_ATTRIBUTES

View File

@ -27,6 +27,7 @@ def MLProgram_Dialect : Dialect {
it is recommended to inquire further prior to using this dialect.
}];
let useDefaultAttributePrinterParser = 1;
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}

View File

@ -96,6 +96,101 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
def MLProgram_GlobalOp : MLProgram_Op<"global", [
Symbol
]> {
let summary = "Module level declaration of a global variable";
let description = [{
Declares a named global variable (or constant).
A global contains a value of a specified type which can be accessed at
runtime via appropriate load/store operations. It can be mutable or
constant, optionally taking an initial value or declared as
extern (in which case, the initial value is found in external storage
by symbol name).
Generally, the type of the global and the type of the initial value
will be the same. However, for type hierarchies which can have a more
generalized bounding type that can be assigned from a narrow type, this
is allowed (but not verified).
Examples:
```mlir
// Constant global.
ml_program.global @foobar(dense<4> : tensor<4xi32>) : tensor<?xi32>
// Constant with external linkage.
ml_program.global mutable @foobar(#ml_program.extern<tensor<4xi32>>)
: tensor<?xi32>
// Mutable global with an undefined initial value.
ml_program.global mutable @foobar : tensor<?xi32>
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttr:$type,
UnitAttr:$is_mutable,
OptionalAttr<AnyAttr>:$value,
OptionalAttr<StrAttr>:$sym_visibility
);
let assemblyFormat = [{
custom<SymbolVisibility>($sym_visibility)
(`mutable` $is_mutable^)?
$sym_name ``
custom<TypedInitialValue>($type, $value)
attr-dict
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// GlobalLoadConstOp
//===----------------------------------------------------------------------===//
def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
NoSideEffect,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Direct load a constant value from a global";
let description = [{
Loads a constant (immutable) value from a global directly by symbol.
This op is only legal for globals that are not mutable and exists because
such a load can be considered to have no side effects.
Example:
```mlir
%0 = ml_program.global_load_const @foobar : tensor<?xi32>
```
}];
let arguments = (ins
FlatSymbolRefAttr:$global
);
let results = (outs
AnyType:$result
);
let assemblyFormat = [{
$global attr-dict `:` type($result)
}];
let extraClassDeclaration = [{
/// Gets the corresponding GlobalOp (or nullptr).
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
}];
}
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//

View File

@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRMLProgram
DEPENDS
MLIRMLProgramOpsIncGen
MLIRMLProgramAttributesIncGen
LINK_LIBS PUBLIC
MLIRDialect

View File

@ -7,15 +7,42 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::ml_program;
//===----------------------------------------------------------------------===//
/// Tablegen Definitions
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
namespace {
struct MLProgramOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (attr.isa<ExternAttr>()) {
os << "extern";
return AliasResult::OverridableAlias;
}
return AliasResult::NoAlias;
}
};
} // namespace
void ml_program::MLProgramDialect::initialize() {
#define GET_ATTRDEF_LIST
addAttributes<
#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
>();
addInterfaces<MLProgramOpAsmDialectInterface>();
}

View File

@ -13,6 +13,69 @@
using namespace mlir;
using namespace mlir::ml_program;
//===----------------------------------------------------------------------===//
// Custom asm helpers
//===----------------------------------------------------------------------===//
/// some.op custom<TypeOrAttr>($type, $attr)
///
/// Uninitialized:
/// some.op : tensor<3xi32>
/// Initialized to narrower type than op:
/// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
static ParseResult parseTypedInitialValue(OpAsmParser &parser,
TypeAttr &typeAttr, Attribute &attr) {
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseAttribute(attr)))
return failure();
if (failed(parser.parseRParen()))
return failure();
}
Type type;
if (failed(parser.parseColonType(type)))
return failure();
typeAttr = TypeAttr::get(type);
return success();
}
static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
TypeAttr type, Attribute attr) {
if (attr) {
p << "(";
p.printAttribute(attr);
p << ")";
}
p << " : ";
p.printAttribute(type);
}
/// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
/// ->
/// some.op public @foo
/// some.op private @foo
static ParseResult parseSymbolVisibility(OpAsmParser &parser,
StringAttr &symVisibilityAttr) {
StringRef symVisibility;
(void)parser.parseOptionalKeyword(&symVisibility,
{"public", "private", "nested"});
if (symVisibility.empty())
return parser.emitError(parser.getCurrentLocation())
<< "expected 'public', 'private', or 'nested'";
if (!symVisibility.empty())
symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
return success();
}
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
StringAttr symVisibilityAttr) {
if (!symVisibilityAttr)
p << "public";
else
p << symVisibilityAttr.getValue();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
@ -38,6 +101,43 @@ void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
LogicalResult GlobalOp::verify() {
if (!getIsMutable() && !getValue())
return emitOpError() << "immutable global must have an initial value";
return success();
}
//===----------------------------------------------------------------------===//
// GlobalLoadConstOp
//===----------------------------------------------------------------------===//
GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
getOperation()->getParentOp(), getGlobalAttr());
}
LogicalResult
GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
GlobalOp referrent = getGlobalOp(symbolTable);
if (!referrent)
return emitOpError() << "undefined global: " << getGlobal();
if (referrent.getIsMutable())
return emitOpError() << "cannot load as const from mutable global "
<< getGlobal();
if (referrent.getType() != getResult().getType())
return emitOpError() << "cannot load from global typed "
<< referrent.getType() << " as "
<< getResult().getType();
return success();
}
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,7 @@
// RUN: mlir-opt %s --allow-unregistered-dialect | mlir-opt --allow-unregistered-dialect | FileCheck %s
// CHECK: #ml_program.extern : i32
"unregistered.attributes"() {
value = #ml_program.extern : i32
} : () -> ()

View File

@ -31,3 +31,30 @@ ml_program.subgraph @output_type_match(%arg0 : i64) -> i32 {
// expected-error @+1 {{doesn't match function result}}
ml_program.output %arg0 : i64
}
// -----
// expected-error @+1 {{immutable global must have an initial value}}
ml_program.global private @const : i32
// -----
ml_program.func @undef_global() -> i32 {
// expected-error @+1 {{undefined global: nothere}}
%0 = ml_program.global_load_const @nothere : i32
ml_program.return %0 : i32
}
// -----
ml_program.global private mutable @var : i32
ml_program.func @mutable_const_load() -> i32 {
// expected-error @+1 {{op cannot load as const from mutable global var}}
%0 = ml_program.global_load_const @var : i32
ml_program.return %0 : i32
}
// -----
ml_program.global private @var(42 : i64) : i64
ml_program.func @const_load_type_mismatch() -> i32 {
// expected-error @+1 {{cannot load from global typed 'i64' as 'i32'}}
%0 = ml_program.global_load_const @var : i32
ml_program.return %0 : i32
}

View File

@ -18,3 +18,12 @@ ml_program.subgraph @compute_subgraph(%arg0 : i32) -> i32 {
%0 = "unregistered.dummy"(%arg0) : (i32) -> i32
ml_program.output %0 : i32
}
// CHECK: ml_program.global private @global_same_type(dense<4> : tensor<4xi32>) : tensor<4xi32>
ml_program.global private @global_same_type(dense<4> : tensor<4xi32>) : tensor<4xi32>
// CHECK: ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
// CHECK: ml_program.global private mutable @global_extern(#extern) : tensor<?xi32>
ml_program.global private mutable @global_extern(#ml_program.extern : tensor<4xi32>) : tensor<?xi32>

View File

@ -8560,6 +8560,7 @@ td_library(
name = "MLProgramOpsTdFiles",
srcs = [
"include/mlir/Dialect/MLProgram/IR/MLProgramBase.td",
"include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td",
"include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
],
includes = ["include"],
@ -8599,6 +8600,24 @@ gentbl_cc_library(
deps = [":MLProgramOpsTdFiles"],
)
gentbl_cc_library(
name = "MLProgramAttributesIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-attrdef-decls"],
"include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h.inc",
),
(
["-gen-attrdef-defs"],
"include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td",
deps = [":MLProgramOpsTdFiles"],
)
cc_library(
name = "MLProgramDialect",
srcs = glob([
@ -8612,6 +8631,7 @@ cc_library(
deps = [
":ControlFlowInterfaces",
":IR",
":MLProgramAttributesIncGen",
":MLProgramOpsIncGen",
":Pass",
":Support",