[mlir][linalg][bufferize][NFC] Move arith interface impl to new build target
This makes ComprehensiveBufferize entirely independent of the arith dialect. Differential Revision: https://reviews.llvm.org/D114219
This commit is contained in:
parent
7bd87a03fd
commit
d3bb4fec2a
|
|
@ -0,0 +1,27 @@
|
||||||
|
//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
|
||||||
|
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class DialectRegistry;
|
||||||
|
|
||||||
|
namespace linalg {
|
||||||
|
namespace comprehensive_bufferize {
|
||||||
|
namespace arith_ext {
|
||||||
|
|
||||||
|
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||||
|
|
||||||
|
} // namespace arith_ext
|
||||||
|
} // namespace comprehensive_bufferize
|
||||||
|
} // namespace linalg
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/Transforms/BufferUtils.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace linalg {
|
||||||
|
namespace comprehensive_bufferize {
|
||||||
|
namespace arith_ext {
|
||||||
|
|
||||||
|
struct ConstantOpInterface
|
||||||
|
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
|
||||||
|
arith::ConstantOp> {
|
||||||
|
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
|
||||||
|
OpResult opResult) const {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation *op, OpBuilder &b,
|
||||||
|
BufferizationState &state) const {
|
||||||
|
auto constantOp = cast<arith::ConstantOp>(op);
|
||||||
|
if (!constantOp.getResult().getType().isa<TensorType>())
|
||||||
|
return success();
|
||||||
|
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
|
||||||
|
"not a constant ranked tensor");
|
||||||
|
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
|
||||||
|
if (!moduleOp) {
|
||||||
|
return constantOp.emitError(
|
||||||
|
"cannot bufferize constants not within builtin.module op");
|
||||||
|
}
|
||||||
|
GlobalCreator globalCreator(moduleOp);
|
||||||
|
|
||||||
|
// Take a guard before anything else.
|
||||||
|
OpBuilder::InsertionGuard g(b);
|
||||||
|
b.setInsertionPoint(constantOp);
|
||||||
|
|
||||||
|
auto globalMemref = globalCreator.getGlobalFor(constantOp);
|
||||||
|
Value memref = b.create<memref::GetGlobalOp>(
|
||||||
|
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
|
||||||
|
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
|
||||||
|
state.mapBuffer(constantOp, memref);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isWritable(Operation *op, Value value) const {
|
||||||
|
// Memory locations returned by memref::GetGlobalOp may not be written to.
|
||||||
|
assert(value.isa<OpResult>());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace arith_ext
|
||||||
|
} // namespace comprehensive_bufferize
|
||||||
|
} // namespace linalg
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
void mlir::linalg::comprehensive_bufferize::arith_ext::
|
||||||
|
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||||
|
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
set(LLVM_OPTIONAL_SOURCES
|
set(LLVM_OPTIONAL_SOURCES
|
||||||
|
ArithInterfaceImpl.cpp
|
||||||
BufferizableOpInterface.cpp
|
BufferizableOpInterface.cpp
|
||||||
ComprehensiveBufferize.cpp
|
ComprehensiveBufferize.cpp
|
||||||
LinalgInterfaceImpl.cpp
|
LinalgInterfaceImpl.cpp
|
||||||
|
|
@ -17,6 +18,17 @@ add_mlir_dialect_library(MLIRBufferizableOpInterface
|
||||||
MLIRMemRef
|
MLIRMemRef
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
|
||||||
|
ArithInterfaceImpl.cpp
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRArithmetic
|
||||||
|
MLIRBufferizableOpInterface
|
||||||
|
MLIRIR
|
||||||
|
MLIRMemRef
|
||||||
|
MLIRStandardOpsTransforms
|
||||||
|
)
|
||||||
|
|
||||||
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
|
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
|
||||||
LinalgInterfaceImpl.cpp
|
LinalgInterfaceImpl.cpp
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -116,16 +116,17 @@
|
||||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||||
#include "mlir/IR/AsmState.h"
|
#include "mlir/IR/AsmState.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "mlir/IR/Dominance.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/BufferUtils.h"
|
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/ADT/ScopeExit.h"
|
#include "llvm/ADT/ScopeExit.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "comprehensive-module-bufferize"
|
#define DEBUG_TYPE "comprehensive-module-bufferize"
|
||||||
|
|
@ -1287,52 +1288,6 @@ BufferizationOptions::BufferizationOptions()
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace linalg {
|
namespace linalg {
|
||||||
namespace comprehensive_bufferize {
|
namespace comprehensive_bufferize {
|
||||||
namespace arith_ext {
|
|
||||||
|
|
||||||
struct ConstantOpInterface
|
|
||||||
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
|
|
||||||
arith::ConstantOp> {
|
|
||||||
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
|
|
||||||
OpResult opResult) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, OpBuilder &b,
|
|
||||||
BufferizationState &state) const {
|
|
||||||
auto constantOp = cast<arith::ConstantOp>(op);
|
|
||||||
if (!isaTensor(constantOp.getResult().getType()))
|
|
||||||
return success();
|
|
||||||
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
|
|
||||||
"not a constant ranked tensor");
|
|
||||||
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
|
|
||||||
if (!moduleOp) {
|
|
||||||
return constantOp.emitError(
|
|
||||||
"cannot bufferize constants not within builtin.module op");
|
|
||||||
}
|
|
||||||
GlobalCreator globalCreator(moduleOp);
|
|
||||||
|
|
||||||
// Take a guard before anything else.
|
|
||||||
OpBuilder::InsertionGuard g(b);
|
|
||||||
b.setInsertionPoint(constantOp);
|
|
||||||
|
|
||||||
auto globalMemref = globalCreator.getGlobalFor(constantOp);
|
|
||||||
Value memref = b.create<memref::GetGlobalOp>(
|
|
||||||
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
|
|
||||||
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
|
|
||||||
state.mapBuffer(constantOp, memref);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isWritable(Operation *op, Value value) const {
|
|
||||||
// Memory locations returned by memref::GetGlobalOp may not be written to.
|
|
||||||
assert(value.isa<OpResult>());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace arith_ext
|
|
||||||
|
|
||||||
namespace scf_ext {
|
namespace scf_ext {
|
||||||
|
|
||||||
struct ExecuteRegionOpInterface
|
struct ExecuteRegionOpInterface
|
||||||
|
|
@ -1813,7 +1768,6 @@ struct ReturnOpInterface
|
||||||
} // namespace std_ext
|
} // namespace std_ext
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||||
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
|
|
||||||
registry.addOpInterface<scf::ExecuteRegionOp,
|
registry.addOpInterface<scf::ExecuteRegionOp,
|
||||||
scf_ext::ExecuteRegionOpInterface>();
|
scf_ext::ExecuteRegionOpInterface>();
|
||||||
registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
|
registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
||||||
MLIRAffine
|
MLIRAffine
|
||||||
MLIRAffineUtils
|
MLIRAffineUtils
|
||||||
MLIRAnalysis
|
MLIRAnalysis
|
||||||
|
MLIRArithBufferizableOpInterfaceImpl
|
||||||
MLIRArithmetic
|
MLIRArithmetic
|
||||||
MLIRBufferizableOpInterface
|
MLIRBufferizableOpInterface
|
||||||
MLIRComplex
|
MLIRComplex
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||||
|
|
@ -39,6 +40,7 @@ struct LinalgComprehensiveModuleBufferize
|
||||||
tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
|
tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
|
||||||
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
|
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
|
||||||
registerBufferizableOpInterfaceExternalModels(registry);
|
registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
|
|
||||||
|
|
@ -6306,6 +6306,26 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "ArithBufferizableOpInterfaceImpl",
|
||||||
|
srcs = [
|
||||||
|
"lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
|
||||||
|
],
|
||||||
|
includes = ["include"],
|
||||||
|
deps = [
|
||||||
|
":ArithmeticDialect",
|
||||||
|
":BufferizableOpInterface",
|
||||||
|
":IR",
|
||||||
|
":MemRefDialect",
|
||||||
|
":Support",
|
||||||
|
":TransformUtils",
|
||||||
|
"//llvm:Support",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "LinalgBufferizableOpInterfaceImpl",
|
name = "LinalgBufferizableOpInterfaceImpl",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
@ -6563,6 +6583,7 @@ cc_library(
|
||||||
":Affine",
|
":Affine",
|
||||||
":AffineUtils",
|
":AffineUtils",
|
||||||
":Analysis",
|
":Analysis",
|
||||||
|
":ArithBufferizableOpInterfaceImpl",
|
||||||
":ArithmeticDialect",
|
":ArithmeticDialect",
|
||||||
":BufferizableOpInterface",
|
":BufferizableOpInterface",
|
||||||
":ComplexDialect",
|
":ComplexDialect",
|
||||||
|
|
@ -6604,7 +6625,6 @@ cc_library(
|
||||||
includes = ["include"],
|
includes = ["include"],
|
||||||
deps = [
|
deps = [
|
||||||
":Affine",
|
":Affine",
|
||||||
":ArithmeticDialect",
|
|
||||||
":BufferizableOpInterface",
|
":BufferizableOpInterface",
|
||||||
":DialectUtils",
|
":DialectUtils",
|
||||||
":IR",
|
":IR",
|
||||||
|
|
@ -6614,7 +6634,6 @@ cc_library(
|
||||||
":SCFDialect",
|
":SCFDialect",
|
||||||
":StandardOps",
|
":StandardOps",
|
||||||
":Support",
|
":Support",
|
||||||
":TransformUtils",
|
|
||||||
"//llvm:Support",
|
"//llvm:Support",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue