[mlir][linalg][bufferize][NFC] Move tensor interface impl to new build target
This makes ComprehensiveBufferize entirely independent of the tensor dialect.

Differential Revision: https://reviews.llvm.org/D114217

parent 8ef460fc51
commit bb273a35a0
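Downstream consequence: the tensor-dialect bufferization models are no longer pulled in by MLIRComprehensiveBufferize. A pass that bufferizes tensor ops now has to link the new MLIRTensorBufferizableOpInterfaceImpl library (Bazel: TensorBufferizableOpInterfaceImpl) and register the external models itself. Below is a minimal sketch of that registration, mirroring the LinalgComprehensiveModuleBufferize change further down; the pass name and body are hypothetical, only the registration pattern is taken from this commit:

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Pass/Pass.h"

namespace {
// Hypothetical downstream pass; only getDependentDialects matters here.
struct MyModuleBufferizePass
    : public mlir::PassWrapper<MyModuleBufferizePass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    using namespace mlir::linalg::comprehensive_bufferize;
    // Core models (scf/std/vector/...) still come with ComprehensiveBufferize.
    registerBufferizableOpInterfaceExternalModels(registry);
    // The tensor-dialect models now have to be registered explicitly.
    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
  }
  void runOnOperation() override { /* run comprehensive bufferization */ }
};
} // namespace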
@@ -322,6 +322,26 @@ struct PostAnalysisStep {
                      SmallVector<Operation *> &newOps) = 0;
 };
 
+/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
+/// with the same shape as `shapedType` and specified `layout` and
+/// `addressSpace`.
+MemRefType getContiguousMemRefType(ShapedType shapedType,
+                                   MemRefLayoutAttrInterface layout = {},
+                                   Attribute memorySpace = {});
+
+/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
+/// with the same shape as `shapedType` and specified `layout` and
+/// `addressSpace` or an UnrankedMemRefType otherwise.
+Type getContiguousOrUnrankedMemRefType(Type type,
+                                       MemRefLayoutAttrInterface layout = {},
+                                       Attribute memorySpace = {});
+
+/// Return a MemRefType to which the `tensorType` can be bufferized in a
+/// composable fashion. The layout must be the most dynamic possible and
+/// canonicalize away once bufferization is finished.
+MemRefType getDynamicMemRefType(RankedTensorType tensorType,
+                                unsigned addressSpace = 0);
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir
@@ -1,3 +1,11 @@
+//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
 
@@ -0,0 +1,27 @@
+//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace tensor_ext {
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tensor_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
@@ -12,6 +12,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "llvm/Support/Debug.h"
 
@@ -528,3 +529,31 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::
   op->erase();
   obsoleteOps.clear();
 }
+
+MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
+    ShapedType shapedType, MemRefLayoutAttrInterface layout,
+    Attribute memorySpace) {
+  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
+                         layout, memorySpace);
+}
+
+Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
+    Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
+  if (type.isa<RankedTensorType, MemRefType>())
+    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
+                                   memorySpace);
+  assert(!layout && "expected empty layout with UnrankedMemRefType");
+  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
+}
+
+MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
+    RankedTensorType tensorType, unsigned addressSpace) {
+  // TODO: address space decisions to connect with the actual alloc.
+  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
+  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+                                      ShapedType::kDynamicStrideOrOffset);
+  AffineMap stridedLayout = makeStridedLinearLayoutMap(
+      dynamicStrides, dynamicOffset, tensorType.getContext());
+  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                         stridedLayout, addressSpace);
+}
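To make these helpers concrete, here is an illustrative sketch (not part of the commit) of the types they produce. It assumes an MLIRContext and MLIR of this commit's vintage; the function and variable names are hypothetical:

// getContiguousMemRefType(tensor<4x?xf32>) -> memref<4x?xf32> (empty layout).
// getDynamicMemRefType(tensor<4x?xf32>)    -> memref<4x?xf32,
//     affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>,
// i.e. a fully dynamic strided layout (dynamic offset and strides) that any
// strided view can be cast into and that canonicalizes away after
// bufferization.
void memRefTypeExamples(mlir::MLIRContext *ctx) {
  using namespace mlir;
  using namespace mlir::linalg::comprehensive_bufferize;
  auto tensorTy = RankedTensorType::get({4, ShapedType::kDynamicSize},
                                        FloatType::getF32(ctx));
  MemRefType contiguous = getContiguousMemRefType(tensorTy); // memref<4x?xf32>
  MemRefType dynamic = getDynamicMemRefType(tensorTy); // fully dynamic strides
  (void)contiguous;
  (void)dynamic;
}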
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   BufferizableOpInterface.cpp
   ComprehensiveBufferize.cpp
   LinalgInterfaceImpl.cpp
+  TensorInterfaceImpl.cpp
 )
 
 add_mlir_dialect_library(MLIRBufferizableOpInterface
@@ -25,6 +26,16 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
   MLIRTensor
 )
 
+add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
+  TensorInterfaceImpl.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRBufferizableOpInterface
+  MLIRIR
+  MLIRMemRef
+  MLIRTensor
+)
+
 add_mlir_dialect_library(MLIRComprehensiveBufferize
   ComprehensiveBufferize.cpp
 
@@ -37,6 +48,5 @@ add_mlir_dialect_library(MLIRComprehensiveBufferize
   MLIRSCF
   MLIRStandard
   MLIRStandardOpsTransforms
-  MLIRTensor
   MLIRVector
 )
@@ -587,45 +587,6 @@ getEquivalentEnclosingFuncBBArg(Value v,
 // Bufferization-specific MemRefType support.
 //===----------------------------------------------------------------------===//
 
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace`.
-static MemRefType getContiguousMemRefType(ShapedType shapedType,
-                                          MemRefLayoutAttrInterface layout = {},
-                                          Attribute memorySpace = {}) {
-  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
-                         layout, memorySpace);
-}
-
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace` or an UnrankedMemRefType otherwise.
-static Type
-getContiguousOrUnrankedMemRefType(Type type,
-                                  MemRefLayoutAttrInterface layout = {},
-                                  Attribute memorySpace = {}) {
-  if (type.isa<RankedTensorType, MemRefType>())
-    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
-                                   memorySpace);
-  assert(!layout && "expected empty layout with UnrankedMemRefType");
-  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
-}
-
-/// Return a MemRefType to which the `tensorType` can be bufferized in a
-/// composable fashion. The layout must be the most dynamic possible and
-/// canonicalize away once bufferization is finished.
-static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
-                                       unsigned addressSpace = 0) {
-  // TODO: address space decisions to connect with the actual alloc.
-  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
-  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
-                                      ShapedType::kDynamicStrideOrOffset);
-  AffineMap stridedLayout = makeStridedLinearLayoutMap(
-      dynamicStrides, dynamicOffset, tensorType.getContext());
-  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
-                         stridedLayout, addressSpace);
-}
-
 /// Return the FunctionType with `argumentTypes` and `resultTypes` where each
 /// tensor is replaced by the corresponding buffer type.
 /// In order for all the callers to agree, this *must* bufferize to the most
@@ -1965,420 +1926,6 @@ struct ReturnOpInterface
 
 } // namespace std_ext
 
-namespace tensor_ext {
-
-struct CastOpInterface
-    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
-                                                    tensor::CastOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    return {&op->getOpOperand(0)};
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return op->getResult(0);
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto castOp = cast<tensor::CastOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(castOp);
-
-    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
-    if (!resultBuffer)
-      return failure();
-    Type sourceType = resultBuffer.getType();
-    auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
-    auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
-    assert(rankedMemRefType || unrankedMemRefType);
-    Attribute memorySpace = rankedMemRefType
-                                ? rankedMemRefType.getMemorySpace()
-                                : unrankedMemRefType.getMemorySpace();
-    TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
-    MemRefLayoutAttrInterface layout =
-        rankedMemRefType && tensorType.isa<RankedTensorType>()
-            ? rankedMemRefType.getLayout()
-            : MemRefLayoutAttrInterface();
-    Type memRefType = getContiguousOrUnrankedMemRefType(
-        castOp.getResult().getType(), layout, memorySpace);
-    Value res =
-        b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
-    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
-    state.mapBuffer(castOp.getResult(), res);
-    return success();
-  }
-};
-
-struct DimOpInterface
-    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
-                                                    tensor::DimOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return OpResult();
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto dimOp = cast<tensor::DimOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(dimOp);
-
-    if (dimOp.source().getType().isa<RankedTensorType>()) {
-      Value v = state.lookupBuffer(dimOp.source());
-      dimOp.result().replaceAllUsesWith(
-          b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
-    }
-    return success();
-  }
-};
-
-struct ExtractSliceOpInterface
-    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
-                                                    tensor::ExtractSliceOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    return {&op->getOpOperand(0) /*source*/};
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return &opOperand == &op->getOpOperand(0) /*source*/
-               ? op->getResult(0)
-               : OpResult();
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::None;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-    LDBG("bufferize: " << *extractSliceOp << '\n');
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(extractSliceOp);
-
-    Location loc = extractSliceOp.getLoc();
-    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
-    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
-    auto dstTensorType =
-        extractSliceOp.result().getType().cast<RankedTensorType>();
-
-    // If not inplaceable, alloc.
-    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
-    Value alloc;
-    if (!inplace)
-      alloc = createNewAllocDeallocPairForShapedValue(
-          b, loc, extractSliceOp.result(), state);
-
-    // Bufferize to subview.
-    auto subviewMemRefType =
-        memref::SubViewOp::inferRankReducedResultType(
-            dstTensorType.getRank(), srcMemrefType,
-            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
-            extractSliceOp.getMixedStrides())
-            .cast<MemRefType>();
-    Value subView = b.create<memref::SubViewOp>(
-        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
-        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
-    // Insert new alias.
-    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
-
-    /// If not inplaceable, copy.
-    if (!inplace) {
-      // Do not copy if the copied data is never read.
-      if (isValueRead(extractSliceOp.result()))
-        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
-                                     alloc);
-      subView = alloc;
-    }
-
-    state.mapBuffer(extractSliceOp.result(), subView);
-    return success();
-  }
-};
-
-struct ExtractOpInterface
-    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
-                                                    tensor::ExtractOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return OpResult();
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto extractOp = cast<tensor::ExtractOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(extractOp);
-
-    Location loc = extractOp.getLoc();
-    Value srcMemref = state.lookupBuffer(extractOp.tensor());
-    Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
-    extractOp.replaceAllUsesWith(l);
-    return success();
-  }
-};
-
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-///
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and
-/// exposed as interfaces on the proper types.
-static bool
-areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
-                             ExtractSliceOp st, InsertSliceOp sti) {
-  if (!st || !sti)
-    return false;
-  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if the source of a `insertSliceOp` bufferizes to an
-/// equivalent ExtractSliceOp that bufferizes inplace.
-static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
-    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
-  LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
-                                                              << '\n');
-  bool foundOp = false;
-  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
-    auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
-    if (extractSliceOp &&
-        areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
-                                     insertSliceOp) &&
-        aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
-      LDBG("\tfound: " << extractSliceOp.getOperation() << '\n');
-      foundOp = true;
-    }
-  });
-
-  if (!foundOp)
-    LDBG("\tnot equivalent\n");
-
-  return foundOp;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
-                                      Value value, InsertSliceOp insertOp) {
-  auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
-        return true;
-    return false;
-  };
-
-  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
-                      condition);
-}
-
-struct InsertSliceOpInterface
-    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
-                                                    tensor::InsertSliceOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/;
-  }
-
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    return {&op->getOpOperand(1) /*dest*/};
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/
-               ? op->getResult(0)
-               : OpResult();
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
-  }
-
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const BufferizationAliasInfo &aliasInfo) const {
-    Operation *readingOp = uRead->getOwner();
-    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-    // uRead is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-
-      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
-                                    insertSliceOp))
-        // Case 1: The main insight is that InsertSliceOp reads only part of
-        // the destination tensor. The overwritten area is not read. If
-        // uConflictingWrite writes into exactly the memory location that is
-        // being read by uRead, this is not a conflict.
-        //
-        // In the above example:
-        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-        //
-        // The read of %t does not conflict with the write of the FillOp
-        // (same aliases!) because the area that the FillOp operates on is
-        // exactly the one that is *not* read via %t.
-        return true;
-
-      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
-        // Case 2: The read of the source tensor and the write to the dest
-        // tensor via an InsertSliceOp is not a conflict if the read is
-        // reading exactly that part of an equivalent tensor that the
-        // InsertSliceOp is writing.
-        //
-        // In the above example:
-        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        return true;
-    }
-
-    // If uConflictingWrite is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-      // %3 = vector.transfer_read %1, %cst
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of vector.transfer_read
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      // lastWrite         = %1
-      //
-      // This is not a conflict because the InsertSliceOp overwrites the
-      // memory segment of %1 with the exact same data. (Effectively, there
-      // is no memory write here.)
-      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
-                                                  insertSliceOp.source()) &&
-          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
-                                    insertSliceOp))
-        return true;
-
-    return false;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    // insert_slice ops arise from tiling and bufferizing them out-of-place is
-    // generally a deal breaker. When used with loops, this ends up cloning the
-    // whole tensor on every single iteration and is a symptom of a
-    // catastrophically bad scheduling decision.
-    // TODO: be very loud about it or even consider failing the pass.
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    LDBG("bufferize: " << *insertSliceOp << '\n');
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(insertSliceOp);
-    Location loc = insertSliceOp.getLoc();
-
-    // When bufferizing out-of-place, `getResultBuffer` allocates.
-    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
-    if (!dstMemref)
-      return failure();
-
-    // A copy of the source buffer is needed if either:
-    //   - The producer of `source` is not inplace. This is the case where a
-    //     slice is computed out of place into the inplace full tensor.
-    //   - The result is not inplace. This is the case where the whole tensor is
-    //     cloned and the clone needs to be updated.
-    // TODO: Is this necessary?
-    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
-                        state.aliasInfo, insertSliceOp) ||
-                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
-    if (needCopy) {
-      LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
-                                                    << " -> copy\n");
-      // Take a subview of the dst.
-      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
-      auto subviewMemRefType =
-          memref::SubViewOp::inferRankReducedResultType(
-              insertSliceOp.getSourceType().getRank(), dstMemrefType,
-              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-              insertSliceOp.getMixedStrides())
-              .cast<MemRefType>();
-      Value subView = b.create<memref::SubViewOp>(
-          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
-          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
-      // Insert new alias.
-      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
-      // Copy tensor.
-      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
-      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
-                                   subView);
-    }
-
-    state.mapBuffer(insertSliceOp.result(), dstMemref);
-    return success();
-  }
-};
-
-} // namespace tensor_ext
-
 namespace vector_ext {
 
 struct TransferReadOpInterface
@@ -2484,13 +2031,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
   registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
   registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
   registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
-  registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
-  registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
-  registry.addOpInterface<tensor::ExtractSliceOp,
-                          tensor_ext::ExtractSliceOpInterface>();
-  registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
-  registry.addOpInterface<tensor::InsertSliceOp,
-                          tensor_ext::InsertSliceOpInterface>();
   registry.addOpInterface<vector::TransferReadOp,
                           vector_ext::TransferReadOpInterface>();
   registry.addOpInterface<vector::TransferWriteOp,
@@ -0,0 +1,437 @@
+//===- TensorInterfaceImpl.cpp - Tensor Impl. of BufferizableOpInterface --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace tensor_ext {
+
+using tensor::ExtractSliceOp;
+using tensor::InsertSliceOp;
+
+struct CastOpInterface
+    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
+                                                    tensor::CastOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(0)};
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return op->getResult(0);
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto castOp = cast<tensor::CastOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(castOp);
+
+    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
+    if (!resultBuffer)
+      return failure();
+    Type sourceType = resultBuffer.getType();
+    auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
+    auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
+    assert(rankedMemRefType || unrankedMemRefType);
+    Attribute memorySpace = rankedMemRefType
+                                ? rankedMemRefType.getMemorySpace()
+                                : unrankedMemRefType.getMemorySpace();
+    TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
+    MemRefLayoutAttrInterface layout =
+        rankedMemRefType && tensorType.isa<RankedTensorType>()
+            ? rankedMemRefType.getLayout()
+            : MemRefLayoutAttrInterface();
+    Type memRefType = getContiguousOrUnrankedMemRefType(
+        castOp.getResult().getType(), layout, memorySpace);
+    Value res =
+        b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
+    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+    state.mapBuffer(castOp.getResult(), res);
+    return success();
+  }
+};
+
+struct DimOpInterface
+    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
+                                                    tensor::DimOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto dimOp = cast<tensor::DimOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(dimOp);
+
+    if (dimOp.source().getType().isa<RankedTensorType>()) {
+      Value v = state.lookupBuffer(dimOp.source());
+      dimOp.result().replaceAllUsesWith(
+          b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
+    }
+    return success();
+  }
+};
+
+struct ExtractSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
+                                                    tensor::ExtractSliceOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(0) /*source*/};
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(0) /*source*/
+               ? op->getResult(0)
+               : OpResult();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::None;
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(extractSliceOp);
+
+    Location loc = extractSliceOp.getLoc();
+    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
+    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+    auto dstTensorType =
+        extractSliceOp.result().getType().cast<RankedTensorType>();
+
+    // If not inplaceable, alloc.
+    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
+    Value alloc;
+    if (!inplace)
+      alloc = state.allocationFns.createAllocDeallocFn(
+          b, loc, extractSliceOp.result(), state);
+
+    // Bufferize to subview.
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            dstTensorType.getRank(), srcMemrefType,
+            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+            extractSliceOp.getMixedStrides())
+            .cast<MemRefType>();
+    Value subView = b.create<memref::SubViewOp>(
+        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
+        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+    // Insert new alias.
+    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
+
+    /// If not inplaceable, copy.
+    if (!inplace) {
+      // Do not copy if the copied data is never read.
+      if (isValueRead(extractSliceOp.result()))
+        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
+                                     alloc);
+      subView = alloc;
+    }
+
+    state.mapBuffer(extractSliceOp.result(), subView);
+    return success();
+  }
+};
+
+struct ExtractOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
+                                                    tensor::ExtractOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto extractOp = cast<tensor::ExtractOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(extractOp);
+
+    Location loc = extractOp.getLoc();
+    Value srcMemref = state.lookupBuffer(extractOp.tensor());
+    Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
+    extractOp.replaceAllUsesWith(l);
+    return success();
+  }
+};
+
+/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
+/// equivalent operand / result and same offset/sizes/strides specification).
+///
+/// This is one particular type of relationship between ops on tensors that
+/// reduce to an equivalence on buffers. This should be generalized and
+/// exposed as interfaces on the proper types.
+static bool
+areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
+                             ExtractSliceOp st, InsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
+/// Return true if the source of a `insertSliceOp` bufferizes to an
+/// equivalent ExtractSliceOp that bufferizes inplace.
+static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
+  bool foundOp = false;
+  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
+    auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
+    if (extractSliceOp &&
+        areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
+                                     insertSliceOp) &&
+        aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
+      foundOp = true;
+    }
+  });
+  return foundOp;
+}
+
+/// Return true if `value` is originating from an ExtractSliceOp that matches
+/// the given InsertSliceOp.
+static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+                                      Value value, InsertSliceOp insertOp) {
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
+struct InsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
+                                                    tensor::InsertSliceOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(1) /*dest*/};
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/
+               ? op->getResult(0)
+               : OpResult();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const BufferizationAliasInfo &aliasInfo) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+    // uRead is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that InsertSliceOp reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via an InsertSliceOp is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // InsertSliceOp is writing.
+        //
+        // In the above example:
+        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // If uConflictingWrite is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of vector.transfer_read
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      // lastWrite         = %1
+      //
+      // This is not a conflict because the InsertSliceOp overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
+                                                  insertSliceOp.source()) &&
+          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
+                                    insertSliceOp))
+        return true;
+
+    return false;
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    // insert_slice ops arise from tiling and bufferizing them out-of-place is
+    // generally a deal breaker. When used with loops, this ends up cloning the
+    // whole tensor on every single iteration and is a symptom of a
+    // catastrophically bad scheduling decision.
+    // TODO: be very loud about it or even consider failing the pass.
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(insertSliceOp);
+    Location loc = insertSliceOp.getLoc();
+
+    // When bufferizing out-of-place, `getResultBuffer` allocates.
+    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
+    if (!dstMemref)
+      return failure();
+
+    // A copy of the source buffer is needed if either:
+    //   - The producer of `source` is not inplace. This is the case where a
+    //     slice is computed out of place into the inplace full tensor.
+    //   - The result is not inplace. This is the case where the whole tensor is
+    //     cloned and the clone needs to be updated.
+    // TODO: Is this necessary?
+    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+                        state.aliasInfo, insertSliceOp) ||
+                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
+    if (needCopy) {
+      // Take a subview of the dst.
+      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+      auto subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              insertSliceOp.getSourceType().getRank(), dstMemrefType,
+              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+              insertSliceOp.getMixedStrides())
+              .cast<MemRefType>();
+      Value subView = b.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+      // Insert new alias.
+      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+      // Copy tensor.
+      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
+      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
+                                   subView);
+    }
+
+    state.mapBuffer(insertSliceOp.result(), dstMemref);
+    return success();
+  }
+};
+
+} // namespace tensor_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::comprehensive_bufferize::tensor_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
+  registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
+  registry.addOpInterface<tensor::ExtractSliceOp,
+                          tensor_ext::ExtractSliceOpInterface>();
+  registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
+  registry.addOpInterface<tensor::InsertSliceOp,
+                          tensor_ext::InsertSliceOpInterface>();
+}
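For orientation, here is a hedged sketch of the rewrite the moved ExtractSliceOpInterface performs in the in-place case (illustrative IR; value names, shapes, and the strided layout #map are hypothetical):

  %0 = tensor.extract_slice %t[%o0, %o1] [%s0, %s1] [1, 1]
      : tensor<?x?xf32> to tensor<?x?xf32>

becomes a view of the source buffer, with no allocation or copy:

  %0_buf = memref.subview %t_buf[%o0, %o1] [%s0, %s1] [1, 1]
      : memref<?x?xf32> to memref<?x?xf32, #map>

In the out-of-place case the code above additionally calls state.allocationFns.createAllocDeallocFn and, only if the result is actually read, memCpyFn from the subview into the new allocation.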
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRStandardOpsTransforms
   MLIRStandardToLLVM
   MLIRTensor
+  MLIRTensorBufferizableOpInterfaceImpl
   MLIRTransforms
   MLIRTransformUtils
   MLIRVector
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -38,6 +39,7 @@ struct LinalgComprehensiveModuleBufferize
                     arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
     registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
   }
 };
 } // end namespace
@@ -6326,6 +6326,25 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TensorBufferizableOpInterfaceImpl",
+    srcs = [
+        "lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":BufferizableOpInterface",
+        ":IR",
+        ":MemRefDialect",
+        ":Support",
+        ":TensorDialect",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "LinalgDocTdFiles",
     srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -6545,6 +6564,7 @@ cc_library(
         ":StandardOps",
         ":StandardOpsTransforms",
        ":Support",
+        ":TensorBufferizableOpInterfaceImpl",
         ":TensorDialect",
         ":TransformUtils",
         ":VectorOps",
@@ -6575,7 +6595,6 @@ cc_library(
         ":SCFDialect",
         ":StandardOps",
         ":Support",
-        ":TensorDialect",
         ":TransformUtils",
         ":VectorOps",
         "//llvm:Support",