//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
 | 
						|
#include "mlir/IR/BuiltinTypes.h"
 | 
						|
#include "mlir/IR/TypeUtilities.h"
 | 
						|
#include "llvm/Support/FormatVariadic.h"
 | 
						|
 | 
						|
using namespace mlir;
 | 
						|
 | 
						|
/// Two-shape convenience form: copies `shape1` and `shape2` into a list of
/// shapes and defers to the list overload of staticallyKnownBroadcastable.
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> operandShapes;
  for (ArrayRef<int64_t> shape : {shape1, shape2})
    operandShapes.emplace_back(shape.begin(), shape.end());
  return staticallyKnownBroadcastable(operandShapes);
}
 | 
						|
 | 
						|
/// Returns true when every shape in `shapes` is guaranteed, at compile time, to
/// broadcast with all the others. Shapes are right-aligned; a shape shorter
/// than the longest one behaves as if padded on the left with 1s. In each
/// column, all extents must be 1 except for at most one distinct value, and
/// that value may be dynamic only if it is the sole non-1 extent in its column.
bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");

  size_t maxRank = 0;
  for (const SmallVector<int64_t, 6> &shape : shapes)
    maxRank = std::max(maxRank, shape.size());

  // Walk columns from the trailing dimension towards the leading one.
  for (size_t offset = 0; offset < maxRank; ++offset) {
    // The first extent other than 1 found in this column, if any.
    Optional<int64_t> columnExtent;
    for (ArrayRef<int64_t> shape : shapes) {
      int64_t dim =
          offset < shape.size() ? shape[shape.size() - 1 - offset] : 1;

      // An extent of 1 broadcasts against anything.
      if (dim == 1)
        continue;

      if (!columnExtent) {
        columnExtent = dim;
        continue;
      }

      // A second non-1 extent is only known-compatible when it and the first
      // one are the same static value; a dynamic extent paired with any other
      // non-1 extent might mismatch at runtime.
      if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(*columnExtent) ||
          dim != *columnExtent)
        return false;
    }
  }
  return true;
}
 | 
						|
 | 
						|
/// Computes the broadcast of `shape1` and `shape2` into `resultShape`.
///
/// Shapes are compared element-wise starting with the trailing dimensions and
/// working backward. Two static dimensions are compatible when
///   1. they are equal, or
///   2. one of them is 1,
/// and the result takes the larger of the two. Leading dimensions of the
/// higher-rank shape are copied through unchanged. Dynamic dimensions are
/// resolved following TensorFlow behavior (see the loop body).
///
/// Returns true on success. On failure, `resultShape` is cleared and false is
/// returned. The result is assembled in a local buffer and only written to
/// `resultShape` at the end, so `resultShape` may safely alias `shape1` or
/// `shape2`.
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // Start from the higher-rank shape; its extra leading dims pass through.
  ArrayRef<int64_t> longer = shape1.size() > shape2.size() ? shape1 : shape2;
  SmallVector<int64_t, 4> broadcasted(longer.begin(), longer.end());

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = broadcasted.rbegin();

  // Check each overlapping (trailing) dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    int64_t dim1 = *i1;
    int64_t dim2 = *i2;
    if (ShapedType::isDynamic(dim1) || ShapedType::isDynamic(dim2)) {
      // One or both dimensions is unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
      // - Otherwise the output dimension is unknown.
      if (dim1 > 1)
        *iR = dim1;
      else if (dim2 > 1)
        *iR = dim2;
      else if (dim1 == 1)
        *iR = dim2;
      else if (dim2 == 1)
        *iR = dim1;
      else
        *iR = ShapedType::kDynamicSize;
    } else if (dim1 == dim2 || dim2 == 1) {
      *iR = dim1;
    } else if (dim1 == 1) {
      *iR = dim2;
    } else {
      // This dimension of the two operand types is incompatible.
      resultShape.clear();
      return false;
    }
  }

  // Commit only after all reads of shape1/shape2 are done (alias-safe).
  resultShape.clear();
  resultShape.append(broadcasted.begin(), broadcasted.end());
  return true;
}
 | 
						|
 | 
						|
/// Returns the shape of the given type. Scalars will be considered as having a
 | 
						|
/// shape with zero dimensions.
 | 
						|
static ArrayRef<int64_t> getShape(Type type) {
 | 
						|
  if (auto sType = type.dyn_cast<ShapedType>())
 | 
						|
    return sType.getShape();
 | 
						|
  return {};
 | 
						|
}
 | 
						|
 | 
						|
/// Returns the result broadcast composition type from the two given types by
 | 
						|
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
 | 
						|
/// either of the input types has dynamic shape. Returns null type if the two
 | 
						|
/// given types are not broadcast-compatible.
 | 
						|
///
 | 
						|
/// elementType, if specified, will be used as the element type of the
 | 
						|
/// broadcasted result type. Otherwise it is required that the element type of
 | 
						|
/// type1 and type2 is the same and this element type will be used as the
 | 
						|
/// resultant element type.
 | 
						|
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
 | 
						|
                                       Type elementType) {
 | 
						|
  // If the elementType is not specified, then the use the common element type
 | 
						|
  // of the inputs or fail if there is no common element type.
 | 
						|
  if (!elementType) {
 | 
						|
    elementType = getElementTypeOrSelf(type1);
 | 
						|
    if (elementType != getElementTypeOrSelf(type2))
 | 
						|
      return {};
 | 
						|
  }
 | 
						|
 | 
						|
  // If one of the types is unranked tensor, then the other type shouldn't be
 | 
						|
  // vector and the result should have unranked tensor type.
 | 
						|
  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
 | 
						|
    if (type1.isa<VectorType>() || type2.isa<VectorType>())
 | 
						|
      return {};
 | 
						|
    return UnrankedTensorType::get(elementType);
 | 
						|
  }
 | 
						|
 | 
						|
  // Returns the type kind if the given type is a vector or ranked tensor type.
 | 
						|
  // Returns llvm::None otherwise.
 | 
						|
  auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> {
 | 
						|
    if (type.isa<VectorType, RankedTensorType>())
 | 
						|
      return type.getTypeID();
 | 
						|
    return llvm::None;
 | 
						|
  };
 | 
						|
 | 
						|
  // Make sure the composite type, if has, is consistent.
 | 
						|
  Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
 | 
						|
  Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
 | 
						|
  Optional<TypeID> resultCompositeKind;
 | 
						|
 | 
						|
  if (compositeKind1 && compositeKind2) {
 | 
						|
    // Disallow mixing vector and tensor.
 | 
						|
    if (compositeKind1 != compositeKind2)
 | 
						|
      return {};
 | 
						|
    resultCompositeKind = compositeKind1;
 | 
						|
  } else if (compositeKind1) {
 | 
						|
    resultCompositeKind = compositeKind1;
 | 
						|
  } else if (compositeKind2) {
 | 
						|
    resultCompositeKind = compositeKind2;
 | 
						|
  }
 | 
						|
 | 
						|
  // Get the shape of each type.
 | 
						|
  SmallVector<int64_t, 4> resultShape;
 | 
						|
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
 | 
						|
    return {};
 | 
						|
 | 
						|
  // Compose the final broadcasted type
 | 
						|
  if (resultCompositeKind == VectorType::getTypeID())
 | 
						|
    return VectorType::get(resultShape, elementType);
 | 
						|
  if (resultCompositeKind == RankedTensorType::getTypeID())
 | 
						|
    return RankedTensorType::get(resultShape, elementType);
 | 
						|
  return elementType;
 | 
						|
}
 | 
						|
 | 
						|
/// Returns a tuple corresponding to whether range has tensor or vector type.
 | 
						|
template <typename iterator_range>
 | 
						|
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
 | 
						|
  return std::make_tuple(
 | 
						|
      llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
 | 
						|
      llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
 | 
						|
}
 | 
						|
 | 
						|
/// Returns true if an existing result shape `existing` is compatible with the
/// shape `inferred` from broadcasting the operands. The two shapes must have
/// the same rank. At every position, one of the following must hold:
///   - the dimensions are equal,
///   - either dimension is dynamic, or
///   - the inferred dimension is 1.
/// Any other pair of differing static dimensions is rejected, for example an
/// existing 1 against an inferred 3.
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
    // Use ShapedType::isDynamic rather than a hard-coded -1, so this stays in
    // sync with how the rest of the file detects dynamic dimensions.
    return inferredDim == existingDim || ShapedType::isDynamic(inferredDim) ||
           ShapedType::isDynamic(existingDim) || inferredDim == 1;
  };
  if (inferred.size() != existing.size())
    return false;
  for (auto p : llvm::zip(inferred, existing))
    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
      return false;
  return true;
}
 | 
						|
 | 
						|
/// Renders `shape` as a single-quoted string with dimensions separated by "x"
/// and dynamic dimensions printed as '?', e.g. '2x?x4'.
static std::string getShapeString(ArrayRef<int64_t> shape) {
  // TODO: should replace with printing shape more uniformly across here and
  // when in type.
  std::string buffer;
  llvm::raw_string_ostream os(buffer);
  os << '\'';
  for (size_t idx = 0, e = shape.size(); idx != e; ++idx) {
    if (idx != 0)
      os << "x";
    int64_t dim = shape[idx];
    if (ShapedType::isDynamic(dim))
      os << '?';
    else
      os << dim;
  }
  os << '\'';
  return os.str();
}
 | 
						|
 | 
						|
/// Verifies that `op`'s operands and results are broadcast compatible:
///  - tensor types and vector types are not mixed across operands and results;
///  - all ranked-tensor operands have mutually broadcastable shapes;
///  - each ranked-tensor result's trailing dimensions (as many as the operand
///    broadcast rank) are compatible with that broadcast shape.
/// If no operand is a ranked tensor, only the first check is performed.
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure broadcasting only tensor or only vector types.
  bool operandHasTensor, operandHasVector;
  std::tie(operandHasTensor, operandHasVector) =
      hasTensorOrVectorType(op->getOperandTypes());
  bool resultHasTensor, resultHasVector;
  std::tie(resultHasTensor, resultHasVector) =
      hasTensorOrVectorType(op->getResultTypes());
  bool anyTensor = operandHasTensor || resultHasTensor;
  bool anyVector = operandHasVector || resultHasVector;
  if (anyTensor && anyVector)
    return op->emitError("cannot broadcast vector with tensor");

  auto isRankedTensor = [](Type t) { return t.isa<RankedTensorType>(); };
  auto rankedOperands =
      make_filter_range(op->getOperandTypes(), isRankedTensor);

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Seed the accumulated broadcast shape with the first ranked operand, then
  // fold every ranked operand into it. Folding the first one into itself
  // leaves the shape unchanged.
  ArrayRef<int64_t> seedShape = getShape(*rankedOperands.begin());
  SmallVector<int64_t, 4> resultShape(seedShape.begin(), seedShape.end());
  for (Type operandType : rankedOperands) {
    // Pass a copy so the output buffer never aliases the input shape.
    SmallVector<int64_t, 4> accumulated(resultShape.begin(),
                                        resultShape.end());
    if (!util::getBroadcastedShape(accumulated, getShape(operandType),
                                   resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  // Unranked results need no further verification; check only ranked ones.
  for (Type resultType :
       make_filter_range(op->getResultTypes(), isRankedTensor)) {
    ArrayRef<int64_t> resultTypeShape = getShape(resultType);
    ArrayRef<int64_t> actualSuffix =
        resultTypeShape.take_back(resultShape.size());
    if (isCompatibleInferredReturnShape(resultShape, actualSuffix))
      continue;
    return op->emitOpError()
           << "result type " << getShapeString(resultTypeShape)
           << " not broadcast compatible with broadcasted operands's shapes "
           << getShapeString(resultShape);
  }
  return success();
}
 |