[mlir][linalg][bufferize][NFC] Move helper function to op interface
This is in preparation of changing the op traversal during bufferization. Differential Revision: https://reviews.llvm.org/D114040
This commit is contained in:
		
							parent
							
								
									8d0994ed21
								
							
						
					
					
						commit
						26c0dd83ab
					
				| 
						 | 
				
			
			@ -297,6 +297,11 @@ struct BufferizationState {
 | 
			
		|||
/// bufferization is necessary.
 | 
			
		||||
Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
 | 
			
		||||
 | 
			
		||||
/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
 | 
			
		||||
/// function returns immediately. Otherwise, it calls the `bufferize` interface
 | 
			
		||||
/// method of `BufferizableOpInterface`.
 | 
			
		||||
LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
 | 
			
		||||
 | 
			
		||||
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
 | 
			
		||||
/// executed after the analysis, but before bufferization. They can be used
 | 
			
		||||
/// implement custom dialect-specific optimizations.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,9 +24,6 @@ static constexpr int64_t kBufferAlignments = 128;
 | 
			
		|||
/// Return default allocation callbacks.
 | 
			
		||||
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
 | 
			
		||||
 | 
			
		||||
/// Bufferize one particular op.
 | 
			
		||||
LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
 | 
			
		||||
 | 
			
		||||
/// Register external models implemented for the `BufferizableOpInterface`.
 | 
			
		||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,6 +7,7 @@
 | 
			
		|||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 | 
			
		||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
			
		||||
#include "mlir/IR/AsmState.h"
 | 
			
		||||
#include "mlir/IR/BlockAndValueMapping.h"
 | 
			
		||||
#include "mlir/IR/BuiltinOps.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -390,6 +391,31 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
 | 
			
		|||
  return operandBuffer;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
LogicalResult
 | 
			
		||||
mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
 | 
			
		||||
                                                   BufferizationState &state) {
 | 
			
		||||
  OpBuilder b(op->getContext());
 | 
			
		||||
 | 
			
		||||
  // Skip BufferCast and TensorLoad ops.
 | 
			
		||||
  if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
 | 
			
		||||
    return success();
 | 
			
		||||
 | 
			
		||||
  // Check if op has tensor results or operands.
 | 
			
		||||
  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
 | 
			
		||||
  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
 | 
			
		||||
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
 | 
			
		||||
  if (!hasTensorResult && !hasTensorOperand)
 | 
			
		||||
    return success();
 | 
			
		||||
 | 
			
		||||
  // Bufferize using `BufferizableOpInterface`.
 | 
			
		||||
  b.setInsertionPoint(op);
 | 
			
		||||
  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
 | 
			
		||||
    return bufferizableOp.bufferize(b, state);
 | 
			
		||||
 | 
			
		||||
  // Other op with tensors. No bufferization method specified.
 | 
			
		||||
  return op->emitError() << "unsupported op with tensors";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Bufferization-specific BlockAndValueMapping support with debugging.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRBufferizableOpInterface
 | 
			
		|||
 | 
			
		||||
  LINK_LIBS PUBLIC
 | 
			
		||||
  MLIRIR
 | 
			
		||||
  MLIRMemRef
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -927,30 +927,6 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
 | 
			
		|||
// Bufferization entry-point for functions.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
LogicalResult
 | 
			
		||||
mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
 | 
			
		||||
                                                   BufferizationState &state) {
 | 
			
		||||
  OpBuilder b(op->getContext());
 | 
			
		||||
 | 
			
		||||
  // Skip BufferCast and TensorLoad ops.
 | 
			
		||||
  if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
 | 
			
		||||
    return success();
 | 
			
		||||
 | 
			
		||||
  // Check if op has tensor results or operands.
 | 
			
		||||
  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
 | 
			
		||||
  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
 | 
			
		||||
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
 | 
			
		||||
  if (!hasTensorResult && !hasTensorOperand)
 | 
			
		||||
    return success();
 | 
			
		||||
 | 
			
		||||
  // Bufferize using `BufferizableOpInterface`.
 | 
			
		||||
  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
 | 
			
		||||
    return bufferizableOp.bufferize(b, state);
 | 
			
		||||
 | 
			
		||||
  // Other op with tensors. No bufferization method specified.
 | 
			
		||||
  return op->emitError() << "unsupported op with tensors";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp,
 | 
			
		||||
                                              BufferizationState &state) {
 | 
			
		||||
  LLVM_DEBUG(llvm::dbgs() << "\n\n");
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6299,6 +6299,7 @@ cc_library(
 | 
			
		|||
    deps = [
 | 
			
		||||
        ":BufferizableOpInterfaceIncGen",
 | 
			
		||||
        ":IR",
 | 
			
		||||
        ":MemRefDialect",
 | 
			
		||||
        ":Support",
 | 
			
		||||
        "//llvm:Support",
 | 
			
		||||
    ],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue