[mlir][tosa] Added shape propagation for TOSA pool operations.
Pool operations perform the same shape propagation. Included the shape propagation and tests for these avg_pool2d and max_pool2d. Differential Revision: https://reviews.llvm.org/D105665
This commit is contained in:
parent
6611fbc62a
commit
f2832c2295
|
@ -56,7 +56,10 @@ def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Operator: avg_pool2d
|
// Operator: avg_pool2d
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [NoSideEffect]> {
|
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
|
||||||
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||||
|
["inferReturnTypeComponents"]>,
|
||||||
|
NoSideEffect]> {
|
||||||
let summary = "Performs max pooling on the input.";
|
let summary = "Performs max pooling on the input.";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -233,7 +236,10 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Operator: max_pool2d
|
// Operator: max_pool2d
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [NoSideEffect]> {
|
def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
|
||||||
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||||
|
["inferReturnTypeComponents"]>,
|
||||||
|
NoSideEffect]> {
|
||||||
let summary = "Performs max pooling on the input.";
|
let summary = "Performs max pooling on the input.";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -845,6 +845,62 @@ NARY_SHAPE_INFER(tosa::TanhOp)
|
||||||
NARY_SHAPE_INFER(tosa::SigmoidOp)
|
NARY_SHAPE_INFER(tosa::SigmoidOp)
|
||||||
#undef PRED_SHAPE_INFER
|
#undef PRED_SHAPE_INFER
|
||||||
|
|
||||||
|
static LogicalResult poolingInferReturnTypes(
|
||||||
|
ValueRange operands, DictionaryAttr attributes,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||||
|
RankedTensorType inputTy = operands[0].getType().dyn_cast<RankedTensorType>();
|
||||||
|
llvm::SmallVector<int64_t> outputShape;
|
||||||
|
outputShape.resize(4, -1);
|
||||||
|
|
||||||
|
// We only know the rank if the input type is unranked.
|
||||||
|
if (!inputTy) {
|
||||||
|
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Batch and number of channels are identical for pooling layer.
|
||||||
|
outputShape[0] = inputTy.getDimSize(0);
|
||||||
|
outputShape[3] = inputTy.getDimSize(3);
|
||||||
|
|
||||||
|
int32_t height = inputTy.getDimSize(1);
|
||||||
|
int32_t width = inputTy.getDimSize(2);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> kernel;
|
||||||
|
llvm::SmallVector<int64_t> stride;
|
||||||
|
llvm::SmallVector<int64_t> pad;
|
||||||
|
|
||||||
|
getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
|
||||||
|
getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
|
||||||
|
getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);
|
||||||
|
|
||||||
|
if (height != -1) {
|
||||||
|
int32_t padded = height + pad[0] + pad[1] - kernel[0];
|
||||||
|
outputShape[1] = padded / stride[0] + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (width != -1) {
|
||||||
|
int32_t padded = width + pad[2] + pad[3] - kernel[1];
|
||||||
|
outputShape[2] = padded / stride[1] + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
|
||||||
|
MLIRContext *context, ::llvm::Optional<Location> location,
|
||||||
|
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||||
|
return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
|
||||||
|
MLIRContext *context, ::llvm::Optional<Location> location,
|
||||||
|
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||||
|
return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TOSA Operator Definitions.
|
// TOSA Operator Definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -660,3 +660,51 @@ func @scatter_minimum_static(%arg0 : tensor<?x4x?xi32>, %arg1 : tensor<3x?xi32>,
|
||||||
%0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x4x?xi32>, tensor<3x?xi32>, tensor<?x?x5xi32>) -> (tensor<?x?x?xi32>)
|
%0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<?x4x?xi32>, tensor<3x?xi32>, tensor<?x?x5xi32>) -> (tensor<?x?x?xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_pool_static
|
||||||
|
func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) {
|
||||||
|
// CHECK: -> tensor<3x2x4x7xf32>
|
||||||
|
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
|
||||||
|
// CHECK: -> tensor<3x2x4x7xf32>
|
||||||
|
%1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_pool_dynamic_input
|
||||||
|
func @test_pool_dynamic_input(%arg0: tensor<?x?x?x?xf32>) {
|
||||||
|
// CHECK: -> tensor<?x?x?x?xf32>
|
||||||
|
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
|
||||||
|
// CHECK: -> tensor<?x?x?x?xf32>
|
||||||
|
%1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_pool_padded
|
||||||
|
func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) {
|
||||||
|
// CHECK: -> tensor<3x5x11x7xf32>
|
||||||
|
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [1, 2, 3, 4], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
|
||||||
|
// CHECK: -> tensor<3x5x11x7xf32>
|
||||||
|
%1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [1, 2, 3, 4], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_pool_stride
|
||||||
|
func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) {
|
||||||
|
// CHECK: -> tensor<3x4x4x7xf32>
|
||||||
|
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [2, 3]} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
|
||||||
|
// CHECK: -> tensor<3x4x4x7xf32>
|
||||||
|
%1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [2, 3]} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue