Memref subview operation has been initially designed to work on memrefs with strided layouts only and has never supported anything else. Port it to use the recently added StridedLayoutAttr instead of extracting the strided from implicitly from affine maps. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D133938
61 lines
2.3 KiB
C++
61 lines
2.3 KiB
C++
//===- InferShapeTest.cpp - unit tests for shape inference ----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::memref;
|
|
|
|
// Source memref has identity layout.
|
|
TEST(InferShapeTest, inferRankReducedShapeIdentity) {
|
|
MLIRContext ctx;
|
|
OpBuilder b(&ctx);
|
|
auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType());
|
|
auto reducedType = SubViewOp::inferRankReducedResultType(
|
|
/*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
|
|
auto expectedType = MemRefType::get(
|
|
{2}, b.getIndexType(),
|
|
StridedLayoutAttr::get(&ctx, /*offset=*/13, /*strides=*/{1}));
|
|
EXPECT_EQ(reducedType, expectedType);
|
|
}
|
|
|
|
// Source memref has non-identity layout.
|
|
TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
|
|
MLIRContext ctx;
|
|
OpBuilder b(&ctx);
|
|
AffineExpr dim0, dim1;
|
|
bindDims(&ctx, dim0, dim1);
|
|
auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
|
|
AffineMap::get(2, 0, 1000 * dim0 + dim1));
|
|
auto reducedType = SubViewOp::inferRankReducedResultType(
|
|
/*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
|
|
auto expectedType = MemRefType::get(
|
|
{2}, b.getIndexType(),
|
|
StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{1}));
|
|
EXPECT_EQ(reducedType, expectedType);
|
|
}
|
|
|
|
TEST(InferShapeTest, inferRankReducedShapeToScalar) {
|
|
MLIRContext ctx;
|
|
OpBuilder b(&ctx);
|
|
AffineExpr dim0, dim1;
|
|
bindDims(&ctx, dim0, dim1);
|
|
auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
|
|
AffineMap::get(2, 0, 1000 * dim0 + dim1));
|
|
auto reducedType = SubViewOp::inferRankReducedResultType(
|
|
/*resultShape=*/{}, sourceMemref, {2, 3}, {1, 1}, {1, 1});
|
|
auto expectedType = MemRefType::get(
|
|
{}, b.getIndexType(),
|
|
StridedLayoutAttr::get(&ctx, /*offset=*/2003, /*strides=*/{}));
|
|
EXPECT_EQ(reducedType, expectedType);
|
|
}
|