[mlir][linalg][bufferize][NFC] Move arith interface impl to new build target
This makes ComprehensiveBufferize entirely independent of the arith dialect. Differential Revision: https://reviews.llvm.org/D114219
This commit is contained in:
		
							parent
							
								
									7bd87a03fd
								
							
						
					
					
						commit
						d3bb4fec2a
					
				| 
						 | 
				
			
			@ -0,0 +1,27 @@
 | 
			
		|||
//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
 | 
			
		||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
 | 
			
		||||
class DialectRegistry;
 | 
			
		||||
 | 
			
		||||
namespace linalg {
 | 
			
		||||
namespace comprehensive_bufferize {
 | 
			
		||||
namespace arith_ext {
 | 
			
		||||
 | 
			
		||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
 | 
			
		||||
 | 
			
		||||
} // namespace arith_ext
 | 
			
		||||
} // namespace comprehensive_bufferize
 | 
			
		||||
} // namespace linalg
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,73 @@
 | 
			
		|||
//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
 | 
			
		||||
 | 
			
		||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 | 
			
		||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 | 
			
		||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
			
		||||
#include "mlir/IR/Dialect.h"
 | 
			
		||||
#include "mlir/IR/Operation.h"
 | 
			
		||||
#include "mlir/Transforms/BufferUtils.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace linalg {
 | 
			
		||||
namespace comprehensive_bufferize {
 | 
			
		||||
namespace arith_ext {
 | 
			
		||||
 | 
			
		||||
struct ConstantOpInterface
 | 
			
		||||
    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
 | 
			
		||||
                                                    arith::ConstantOp> {
 | 
			
		||||
  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
 | 
			
		||||
                                                OpResult opResult) const {
 | 
			
		||||
    return {};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  LogicalResult bufferize(Operation *op, OpBuilder &b,
 | 
			
		||||
                          BufferizationState &state) const {
 | 
			
		||||
    auto constantOp = cast<arith::ConstantOp>(op);
 | 
			
		||||
    if (!constantOp.getResult().getType().isa<TensorType>())
 | 
			
		||||
      return success();
 | 
			
		||||
    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
 | 
			
		||||
           "not a constant ranked tensor");
 | 
			
		||||
    auto moduleOp = constantOp->getParentOfType<ModuleOp>();
 | 
			
		||||
    if (!moduleOp) {
 | 
			
		||||
      return constantOp.emitError(
 | 
			
		||||
          "cannot bufferize constants not within builtin.module op");
 | 
			
		||||
    }
 | 
			
		||||
    GlobalCreator globalCreator(moduleOp);
 | 
			
		||||
 | 
			
		||||
    // Take a guard before anything else.
 | 
			
		||||
    OpBuilder::InsertionGuard g(b);
 | 
			
		||||
    b.setInsertionPoint(constantOp);
 | 
			
		||||
 | 
			
		||||
    auto globalMemref = globalCreator.getGlobalFor(constantOp);
 | 
			
		||||
    Value memref = b.create<memref::GetGlobalOp>(
 | 
			
		||||
        constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
 | 
			
		||||
    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
 | 
			
		||||
    state.mapBuffer(constantOp, memref);
 | 
			
		||||
 | 
			
		||||
    return success();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool isWritable(Operation *op, Value value) const {
 | 
			
		||||
    // Memory locations returned by memref::GetGlobalOp may not be written to.
 | 
			
		||||
    assert(value.isa<OpResult>());
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace arith_ext
 | 
			
		||||
} // namespace comprehensive_bufferize
 | 
			
		||||
} // namespace linalg
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
void mlir::linalg::comprehensive_bufferize::arith_ext::
 | 
			
		||||
    registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
 | 
			
		||||
  registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1,4 +1,5 @@
 | 
			
		|||
set(LLVM_OPTIONAL_SOURCES
 | 
			
		||||
  ArithInterfaceImpl.cpp
 | 
			
		||||
  BufferizableOpInterface.cpp
 | 
			
		||||
  ComprehensiveBufferize.cpp
 | 
			
		||||
  LinalgInterfaceImpl.cpp
 | 
			
		||||
| 
						 | 
				
			
			@ -17,6 +18,17 @@ add_mlir_dialect_library(MLIRBufferizableOpInterface
 | 
			
		|||
  MLIRMemRef
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
 | 
			
		||||
  ArithInterfaceImpl.cpp
 | 
			
		||||
 | 
			
		||||
  LINK_LIBS PUBLIC
 | 
			
		||||
  MLIRArithmetic
 | 
			
		||||
  MLIRBufferizableOpInterface
 | 
			
		||||
  MLIRIR
 | 
			
		||||
  MLIRMemRef
 | 
			
		||||
  MLIRStandardOpsTransforms
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
 | 
			
		||||
  LinalgInterfaceImpl.cpp
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -116,16 +116,17 @@
 | 
			
		|||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
 | 
			
		||||
#include "mlir/IR/AsmState.h"
 | 
			
		||||
#include "mlir/IR/BlockAndValueMapping.h"
 | 
			
		||||
#include "mlir/IR/Dominance.h"
 | 
			
		||||
#include "mlir/IR/Operation.h"
 | 
			
		||||
#include "mlir/IR/TypeUtilities.h"
 | 
			
		||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
 | 
			
		||||
#include "mlir/Pass/Pass.h"
 | 
			
		||||
#include "mlir/Pass/PassManager.h"
 | 
			
		||||
#include "mlir/Transforms/BufferUtils.h"
 | 
			
		||||
#include "llvm/ADT/DenseSet.h"
 | 
			
		||||
#include "llvm/ADT/ScopeExit.h"
 | 
			
		||||
#include "llvm/ADT/SetVector.h"
 | 
			
		||||
#include "llvm/ADT/TypeSwitch.h"
 | 
			
		||||
#include "llvm/Support/Debug.h"
 | 
			
		||||
#include "llvm/Support/FormatVariadic.h"
 | 
			
		||||
 | 
			
		||||
#define DEBUG_TYPE "comprehensive-module-bufferize"
 | 
			
		||||
| 
						 | 
				
			
			@ -1287,52 +1288,6 @@ BufferizationOptions::BufferizationOptions()
 | 
			
		|||
namespace mlir {
 | 
			
		||||
namespace linalg {
 | 
			
		||||
namespace comprehensive_bufferize {
 | 
			
		||||
namespace arith_ext {
 | 
			
		||||
 | 
			
		||||
struct ConstantOpInterface
 | 
			
		||||
    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
 | 
			
		||||
                                                    arith::ConstantOp> {
 | 
			
		||||
  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
 | 
			
		||||
                                                OpResult opResult) const {
 | 
			
		||||
    return {};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  LogicalResult bufferize(Operation *op, OpBuilder &b,
 | 
			
		||||
                          BufferizationState &state) const {
 | 
			
		||||
    auto constantOp = cast<arith::ConstantOp>(op);
 | 
			
		||||
    if (!isaTensor(constantOp.getResult().getType()))
 | 
			
		||||
      return success();
 | 
			
		||||
    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
 | 
			
		||||
           "not a constant ranked tensor");
 | 
			
		||||
    auto moduleOp = constantOp->getParentOfType<ModuleOp>();
 | 
			
		||||
    if (!moduleOp) {
 | 
			
		||||
      return constantOp.emitError(
 | 
			
		||||
          "cannot bufferize constants not within builtin.module op");
 | 
			
		||||
    }
 | 
			
		||||
    GlobalCreator globalCreator(moduleOp);
 | 
			
		||||
 | 
			
		||||
    // Take a guard before anything else.
 | 
			
		||||
    OpBuilder::InsertionGuard g(b);
 | 
			
		||||
    b.setInsertionPoint(constantOp);
 | 
			
		||||
 | 
			
		||||
    auto globalMemref = globalCreator.getGlobalFor(constantOp);
 | 
			
		||||
    Value memref = b.create<memref::GetGlobalOp>(
 | 
			
		||||
        constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
 | 
			
		||||
    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
 | 
			
		||||
    state.mapBuffer(constantOp, memref);
 | 
			
		||||
 | 
			
		||||
    return success();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool isWritable(Operation *op, Value value) const {
 | 
			
		||||
    // Memory locations returned by memref::GetGlobalOp may not be written to.
 | 
			
		||||
    assert(value.isa<OpResult>());
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace arith_ext
 | 
			
		||||
 | 
			
		||||
namespace scf_ext {
 | 
			
		||||
 | 
			
		||||
struct ExecuteRegionOpInterface
 | 
			
		||||
| 
						 | 
				
			
			@ -1813,7 +1768,6 @@ struct ReturnOpInterface
 | 
			
		|||
} // namespace std_ext
 | 
			
		||||
 | 
			
		||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
 | 
			
		||||
  registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
 | 
			
		||||
  registry.addOpInterface<scf::ExecuteRegionOp,
 | 
			
		||||
                          scf_ext::ExecuteRegionOpInterface>();
 | 
			
		||||
  registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
 | 
			
		|||
  MLIRAffine
 | 
			
		||||
  MLIRAffineUtils
 | 
			
		||||
  MLIRAnalysis
 | 
			
		||||
  MLIRArithBufferizableOpInterfaceImpl
 | 
			
		||||
  MLIRArithmetic
 | 
			
		||||
  MLIRBufferizableOpInterface
 | 
			
		||||
  MLIRComplex
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,6 +7,7 @@
 | 
			
		|||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "PassDetail.h"
 | 
			
		||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
 | 
			
		||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 | 
			
		||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 | 
			
		||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -39,6 +40,7 @@ struct LinalgComprehensiveModuleBufferize
 | 
			
		|||
                tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
 | 
			
		||||
                arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
 | 
			
		||||
    registerBufferizableOpInterfaceExternalModels(registry);
 | 
			
		||||
    arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
 | 
			
		||||
    linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
 | 
			
		||||
    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
 | 
			
		||||
    vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6306,6 +6306,26 @@ cc_library(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "ArithBufferizableOpInterfaceImpl",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
 | 
			
		||||
    ],
 | 
			
		||||
    hdrs = [
 | 
			
		||||
        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
 | 
			
		||||
    ],
 | 
			
		||||
    includes = ["include"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":ArithmeticDialect",
 | 
			
		||||
        ":BufferizableOpInterface",
 | 
			
		||||
        ":IR",
 | 
			
		||||
        ":MemRefDialect",
 | 
			
		||||
        ":Support",
 | 
			
		||||
        ":TransformUtils",
 | 
			
		||||
        "//llvm:Support",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "LinalgBufferizableOpInterfaceImpl",
 | 
			
		||||
    srcs = [
 | 
			
		||||
| 
						 | 
				
			
			@ -6563,6 +6583,7 @@ cc_library(
 | 
			
		|||
        ":Affine",
 | 
			
		||||
        ":AffineUtils",
 | 
			
		||||
        ":Analysis",
 | 
			
		||||
        ":ArithBufferizableOpInterfaceImpl",
 | 
			
		||||
        ":ArithmeticDialect",
 | 
			
		||||
        ":BufferizableOpInterface",
 | 
			
		||||
        ":ComplexDialect",
 | 
			
		||||
| 
						 | 
				
			
			@ -6604,7 +6625,6 @@ cc_library(
 | 
			
		|||
    includes = ["include"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":Affine",
 | 
			
		||||
        ":ArithmeticDialect",
 | 
			
		||||
        ":BufferizableOpInterface",
 | 
			
		||||
        ":DialectUtils",
 | 
			
		||||
        ":IR",
 | 
			
		||||
| 
						 | 
				
			
			@ -6614,7 +6634,6 @@ cc_library(
 | 
			
		|||
        ":SCFDialect",
 | 
			
		||||
        ":StandardOps",
 | 
			
		||||
        ":Support",
 | 
			
		||||
        ":TransformUtils",
 | 
			
		||||
        "//llvm:Support",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue