[mlir] Support MemRefType with multiple AffineMaps in getStridesAndOffset

Compose multiple AffineMaps into single map before strides extraction.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D99166
This commit is contained in:
Vladislav Vinogradov 2021-03-23 13:30:30 +03:00
parent ffa455d4d4
commit 70b6f16e07
4 changed files with 63 additions and 14 deletions

View File

@ -673,26 +673,23 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
SmallVectorImpl<AffineExpr> &strides,
AffineExpr &offset) {
auto affineMaps = t.getAffineMaps();
// For now strides are only computed on a single affine map with a single
// result (i.e. the closed subset of linearization maps that are compatible
// with striding semantics).
// TODO: support more forms on a per-need basis.
if (affineMaps.size() > 1)
return failure();
if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1)
return failure();
AffineMap m;
if (!affineMaps.empty()) {
m = affineMaps.back();
for (size_t i = affineMaps.size() - 1; i > 0; --i)
m = m.compose(affineMaps[i - 1]);
assert(!m.isIdentity() && "unexpected identity map");
}
auto zero = getAffineConstantExpr(0, t.getContext());
auto one = getAffineConstantExpr(1, t.getContext());
offset = zero;
strides.assign(t.getRank(), zero);
AffineMap m;
if (!affineMaps.empty()) {
m = affineMaps.front();
assert(!m.isIdentity() && "unexpected identity map");
}
// Canonical case for empty map.
if (!m) {
// 0-D corner case, offset is already 0.

View File

@ -60,7 +60,8 @@ func @f(%0: index) {
// CHECK: MemRefType offset: 123 strides:
%100 = memref.alloc(%0, %0)[%0, %0] : memref<?x?x16xf32, affine_map<(i, j, k)[M, N]->(i + j, j, k)>, affine_map<(i, j, k)[M, N]->(M * i + N * j + k + 1)>>
// CHECK: MemRefType memref<?x?x16xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 + d1, d1, d2)>, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2 + 1)>> cannot be converted to strided form
// CHECK: MemRefType offset: 1 strides: ?, ?, 1
%101 = memref.alloc() : memref<3x4x5xf32, affine_map<(i, j, k)->(i floordiv 4 + j + k)>>
// CHECK: MemRefType memref<3x4x5xf32, affine_map<(d0, d1, d2) -> (d0 floordiv 4 + d1 + d2)>> cannot be converted to strided form
%102 = memref.alloc() : memref<3x4x5xf32, affine_map<(i, j, k)->(i ceildiv 4 + j + k)>>

View File

@ -1,6 +1,7 @@
add_mlir_unittest(MLIRIRTests
AttributeTest.cpp
DialectTest.cpp
MemRefTypeTest.cpp
OperationSupportTest.cpp
ShapedTypeTest.cpp
)

View File

@ -0,0 +1,50 @@
//===- MemRefTypeTest.cpp - MemRefType unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::detail;
namespace {
TEST(MemRefTypeTest, GetStridesAndOffset) {
MLIRContext context;
SmallVector<int64_t> shape({2, 3, 4});
Type f32 = FloatType::getF32(&context);
AffineMap map1 = makeStridedLinearLayoutMap({12, 4, 1}, 5, &context);
MemRefType type1 = MemRefType::get(shape, f32, {map1});
SmallVector<int64_t> strides1;
int64_t offset1 = -1;
LogicalResult res1 = getStridesAndOffset(type1, strides1, offset1);
ASSERT_TRUE(res1.succeeded());
ASSERT_EQ(3, strides1.size());
EXPECT_EQ(12, strides1[0]);
EXPECT_EQ(4, strides1[1]);
EXPECT_EQ(1, strides1[2]);
ASSERT_EQ(5, offset1);
AffineMap map2 = AffineMap::getPermutationMap({1, 2, 0}, &context);
AffineMap map3 = makeStridedLinearLayoutMap({8, 2, 1}, 0, &context);
MemRefType type2 = MemRefType::get(shape, f32, {map2, map3});
SmallVector<int64_t> strides2;
int64_t offset2 = -1;
LogicalResult res2 = getStridesAndOffset(type2, strides2, offset2);
ASSERT_TRUE(res2.succeeded());
ASSERT_EQ(3, strides2.size());
EXPECT_EQ(1, strides2[0]);
EXPECT_EQ(8, strides2[1]);
EXPECT_EQ(2, strides2[2]);
ASSERT_EQ(0, offset2);
}
} // end namespace