diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 99218f491dde..852407292979 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,7 @@ class ArrayAttr;
 /// Assuming `sizes` is `[s0, .. sn]`, return the vector
 ///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
 ///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
 ///
 /// Return an empty vector if `sizes` is empty.
 SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 89ade79a3ac0..a0c8acea91dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -839,6 +839,25 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     ///
     bool areTrailingDimsContiguous(int64_t n);
 
+    /// Return the number of trailing dimensions that are contiguous.
+    ///
+    /// Examples:
+    ///   - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsible
+    ///     trailing dimensions is 3
+    ///   - memref<5x3x2xi8, strided<[12,2,1]>>, the number of collapsible
+    ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
+    ///   - memref<5x3x2xi8, strided<[12,4,1]>>, the number of collapsible
+    ///     trailing dimensions is 1 (dimension 1 is non-contiguous)
+    ///   - memref<5x3x2xi8, strided<[12,4,2]>>, the number of collapsible
+    ///     trailing dimensions is 0 (dimension 2 is non-contiguous)
+    ///   - memref<?x3x2xi8, strided<[6,2,1]>>, the number of collapsible
+    ///     trailing dimensions is 3
+    ///   - memref<?x3x2xi8, strided<[12,2,1]>>, the number of collapsible
+    ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
+    ///   - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsible
+    ///     trailing dimensions is 2 (dimension 1 has a dynamic size)
+    int64_t getNumContiguousTrailingDims();
+
     /// Return a version of this type with identity layout if it can be
     /// determined statically that the layout is the canonical contiguous
     /// strided layout. Otherwise pass the layout into `simplifyAffineMap`
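Note: the rule in the new doc comment is easy to model outside MLIR. The sketch below is illustrative only (plain C++, not part of the patch); `kDynamic` stands in for MLIR's `?`, which I am assuming to be the usual negative sentinel `ShapedType::kDynamic`. It performs the same backward scan as the new implementation and asserts the examples listed above:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed negative sentinel).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Count trailing dimensions whose strides match the running product of the
// inner dimension sizes. A mismatched stride ends the scan *before* its own
// dimension; a dynamic size ends it *after* (its stride already matched).
int64_t numContiguousTrailingDims(const std::vector<int64_t> &shape,
                                  const std::vector<int64_t> &strides) {
  const int64_t n = shape.size();
  int64_t dimProduct = 1;
  for (int64_t i = n - 1; i >= 0; --i) {
    if (shape[i] == 1) // unit dims never affect addresses; skip them
      continue;
    if (strides[i] != dimProduct)
      return n - i - 1;
    if (shape[i] == kDynamic)
      return n - i;
    dimProduct *= shape[i];
  }
  return n;
}

int main() {
  // The examples from the doc comment:
  assert(numContiguousTrailingDims({5, 3, 2}, {6, 2, 1}) == 3);
  assert(numContiguousTrailingDims({5, 3, 2}, {12, 2, 1}) == 2);
  assert(numContiguousTrailingDims({5, 3, 2}, {12, 4, 1}) == 1);
  assert(numContiguousTrailingDims({5, 3, 2}, {12, 4, 2}) == 0);
  assert(numContiguousTrailingDims({kDynamic, 3, 2}, {6, 2, 1}) == 3);
  assert(numContiguousTrailingDims({5, kDynamic, 2}, {kDynamic, 2, 1}) == 2);
  return 0;
}
```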
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8de77e2c3cb0..e1648ab99ff2 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,7 +69,8 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 //===----------------------------------------------------------------------===//
 
 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
+  assert((sizes.empty() ||
+          llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
          "sizes must be nonnegative");
   int64_t unit = 1;
   return ::computeSuffixProductImpl(sizes, unit);
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e3a00ac5a14b..6661efa8907b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -660,35 +660,45 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
-  if (!isLastDimUnitStride())
-    return false;
+  assert(n <= getRank() &&
+         "number of dimensions to check must not exceed rank");
+  return n <= getNumContiguousTrailingDims();
+}
 
-  auto memrefShape = getShape().take_back(n);
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
+int64_t MemRefType::getNumContiguousTrailingDims() {
+  const int64_t n = getRank();
 
+  // MemRefs with an identity layout are entirely contiguous.
   if (getLayout().isIdentity())
-    return true;
+    return n;
 
+  // Get the strides (if any); failing that, conservatively assume a
+  // non-contiguous layout.
   int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
+  SmallVector<int64_t> strides;
+  if (!succeeded(getStridesAndOffset(strides, offset)))
+    return 0;
 
-  if (strides.empty())
-    return true;
+  ArrayRef<int64_t> shape = getShape();
 
-  // Check whether strides match "flattened" dims.
-  SmallVector<int64_t> flattenedDims;
-  auto dimProduct = 1;
-  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
-    dimProduct *= dim;
-    flattenedDims.push_back(dimProduct);
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous up to dimension `k` if each stride
+  // `si` is the product of the dimensions `di+1, ..., dn-1`, for `i` in
+  // `[k, n-1]`. Strides of unit dimensions are ignored, as they are of no
+  // consequence to the address computation.
+  int64_t dimProduct = 1;
+  for (int64_t i = n - 1; i >= 0; --i) {
+    if (shape[i] == 1)
+      continue;
+    if (strides[i] != dimProduct)
+      return n - i - 1;
+    if (shape[i] == ShapedType::kDynamic)
+      return n - i;
+    dimProduct *= shape[i];
   }
-  strides = strides.drop_back(1);
-  return llvm::equal(strides, llvm::reverse(flattenedDims));
+
+  return n;
 }
 
 MemRefType MemRefType::canonicalizeStridedLayout() {
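Note: the relaxed assertion in `computeSuffixProduct` above works because the documented result `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]` never reads `s0`, so a dynamic (negative sentinel) leading size cannot affect any output element. A minimal standalone sketch of that contract; this is my illustration, not the patch's `computeSuffixProductImpl`:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed negative sentinel).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Suffix product per the IndexingUtils.h doc comment: element i is the
// product of sizes[i+1] .. sizes[n-1]. sizes[0] is never multiplied in,
// which is why only s1 .. sn need to be non-negative.
std::vector<int64_t> suffixProduct(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> result(sizes.size());
  int64_t product = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    result[i] = product;
    if (i > 0)
      product *= sizes[i]; // feeds elements 0 .. i-1 only
  }
  return result;
}

int main() {
  // A dynamic *leading* size is harmless; this shape mirrors the collapsed
  // group of memref<1x?x4x6xi32> in the tests further below.
  std::vector<int64_t> basis = suffixProduct({kDynamic, 4, 6});
  assert(basis[0] == 24 && basis[1] == 6 && basis[2] == 1);
  return 0;
}
```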
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf22..45873aa93153 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,9 +188,35 @@ func.func @transfer_read_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// The vector is a non-contiguous slice of the input memref.
 
 func.func @negative_transfer_read_dynamic_dim_to_flatten(
+    %mem : memref<4x?x?x2xi8>) -> vector<2x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<4x?x?x2xi8>, vector<2x2x2xi8>
+  return %res : vector<2x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+// When collapsing memref dimensions, we may include the rightmost dynamic
+// dimension (e.g., at position `k`), provided that the strides of dimensions
+// `k+1`, `k+2`, etc. ensure contiguity in memory. The stride at position `k`
+// itself does not factor into this. (Here "strides" means both explicit
+// strides and those implied by an identity map.)
+
+func.func @transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +229,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
   return %res : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]
+// CHECK-SAME: %[[IDX_2:arg1]]
+// CHECK-SAME: %[[MEM:arg2]]
+// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
 // CHECK-128B-NOT: memref.collapse_shape
 
 // -----
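Note: the `affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>` in the CHECK lines is plain strided arithmetic over the collapsed group `[1, 2, 3]` of `memref<1x?x4x6xi32>`: the in-group strides come from the static inner sizes (`4 * 6 = 24`, `6`, `1`), and the group's own dynamic extent never enters. A quick numeric check with hypothetical index values, as an illustration only:

```cpp
#include <cassert>
#include <cstdint>

// Linearize in-group indices (i1, i2, i3) of the collapsed group [1, 2, 3]
// of memref<1x?x4x6xi32>. Only sizes *inside* the group matter; the leading
// (dynamic) group dimension contributes position, not extent.
int64_t collapsedIndex(int64_t i1, int64_t i2, int64_t i3) {
  return i1 * (4 * 6) + i2 * 6 + i3 * 1;
}

int main() {
  // The rewritten read uses indices (idx_1, idx_2, 0); e.g. idx_1 = 2,
  // idx_2 = 1 lands at 2 * 24 + 1 * 6 = 54, as the affine map computes.
  assert(collapsedIndex(2, 1, 0) == 54);
  return 0;
}
```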
@@ -451,9 +491,31 @@ func.func @transfer_write_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// The vector is a non-contiguous slice of the input memref.
 
 func.func @negative_transfer_write_dynamic_to_flatten(
+    %mem : memref<4x?x?x2xi8>,
+    %vec : vector<2x2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0]
+    : vector<2x2x2xi8>, memref<4x?x?x2xi8>
+  return
+}
+
+// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+// See the comment in front of @transfer_read_dynamic_dim_to_flatten.
+
+func.func @transfer_write_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
     %mem : memref<1x?x4x6xi32>) {
@@ -466,11 +528,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
   return
 }
 
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]: index
+// CHECK-SAME: %[[IDX_2:arg1]]: index
+// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+
+// CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
 // CHECK-128B-NOT: memref.collapse_shape
 
 // -----
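Note: a rough model of why the two shapes above land on opposite sides of the rewrite. This is my simplification, not the actual pattern code, and it assumes identity layouts: after skipping leading unit dims, every vector dimension except the outermost non-unit one must exactly fill a static memref dimension, or the slice has gaps between consecutive rows. (The real pattern also checks trailing contiguity of the memref itself and in-bounds conditions.)

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed negative sentinel).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Rough model: can a trailing slice of shape `vecShape`, taken from the
// trailing dims `memShape` of an identity-layout memref, be transferred as
// one contiguous run? Leading unit dims are skipped; the outermost non-unit
// vector dim may be partial; every inner dim must fill exactly.
bool sliceIsContiguous(const std::vector<int64_t> &memShape,
                       const std::vector<int64_t> &vecShape) {
  size_t i = 0;
  while (i < vecShape.size() && vecShape[i] == 1)
    ++i;
  for (++i; i < vecShape.size(); ++i)
    if (memShape[i] == kDynamic || memShape[i] != vecShape[i])
      return false;
  return true;
}

int main() {
  // vector<2x2x2xi8> out of memref<4x?x?x2xi8> (trailing dims ?x?x2): the
  // inner ?x2 cannot be proven to match 2x2, so rows may be separated.
  assert(!sliceIsContiguous({kDynamic, kDynamic, 2}, {2, 2, 2}));

  // vector<1x2x6xi32> out of memref<1x?x4x6xi32> (trailing dims ?x4x6): the
  // innermost 6 fills its dimension; the partial 2 is outermost in the
  // group, so it only selects where the contiguous run starts.
  assert(sliceIsContiguous({kDynamic, 4, 6}, {1, 2, 6}));
  return 0;
}
```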
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 770064486457..d22afb3003e7 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_unittest(MLIRIRTests
   IRMapping.cpp
   InterfaceAttachmentTest.cpp
   LocationTest.cpp
+  MemrefLayoutTest.cpp
   OperationSupportTest.cpp
   PatternMatchTest.cpp
   ShapedTypeTest.cpp
diff --git a/mlir/unittests/IR/MemrefLayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
new file mode 100644
index 000000000000..f243a76ee660
--- /dev/null
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -0,0 +1,111 @@
+//===- MemrefLayoutTest.cpp - unit tests for memref layout ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+//
+// Test the correctness of `MemRefType::getNumContiguousTrailingDims`.
+//
+TEST(MemRefLayout, numContigDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // Special case for identity maps and no explicit `strided` attribute - the
+  // memref is entirely contiguous even if the strides cannot be determined
+  // statically.
+
+  // memref<?x?x?xf32>
+  auto m0 = MemRefType::get({_, _, _}, f32);
+  EXPECT_EQ(m0.getNumContiguousTrailingDims(), 3);
+
+  // Conservatively assume the memref is non-contiguous everywhere if the
+  // strides cannot be determined.
+
+  // memref<2x2x2xf32, (i,j,k)->(i,k,j)>
+  auto m1 = MemRefType::get(
+      {2, 2, 2}, f32,
+      AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 2, 1}, &ctx));
+  EXPECT_EQ(m1.getNumContiguousTrailingDims(), 0);
+
+  // A base case: a static memref with the usual strides.
+
+  // memref<2x2x2xf32, strided<[4, 2, 1]>>
+  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m3.getNumContiguousTrailingDims(), 3);
+
+  // A static memref with a discontinuity in the rightmost dimension.
+
+  // memref<2x2x2xf32, strided<[8, 4, 2]>>
+  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
+
+  // A static memref with a discontinuity in the "middle".
+
+  // memref<2x2x2xf32, strided<[8, 2, 1]>>
+  auto m5 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_EQ(m5.getNumContiguousTrailingDims(), 2);
+
+  // A dynamic memref where the dynamic dimension breaks contiguity.
+
+  // memref<2x?x2xf32, strided<[4, 2, 1]>>
+  auto m6 = MemRefType::get({2, _, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m6.getNumContiguousTrailingDims(), 2);
+
+  // An edge case of a dynamic memref where the dynamic dimension is the
+  // first one.
+
+  // memref<?x2x2xf32, strided<[4, 2, 1]>>
+  auto m7 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m7.getNumContiguousTrailingDims(), 3);
+
+  // A memref with a unit dimension. Unit dimensions do not affect contiguity,
+  // even if the corresponding stride is dynamic.
+
+  // memref<2x1x2xf32, strided<[2,?,1]>>
+  auto m8 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
+  EXPECT_EQ(m8.getNumContiguousTrailingDims(), 3);
+}
+
+//
+// Test the member function `MemRefType::areTrailingDimsContiguous`.
+//
+TEST(MemRefLayout, contigTrailingDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // A not-entirely-contiguous, not-entirely-discontiguous memref.
+  // Ensure `areTrailingDimsContiguous` returns `true` for the value
+  // returned by `getNumContiguousTrailingDims` and `false` for the next
+  // bigger number.
+
+  // memref<2x?x2xf32, strided<[?,2,1]>>
+  auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+  int64_t n = m.getNumContiguousTrailingDims();
+  EXPECT_TRUE(m.areTrailingDimsContiguous(n));
+  ASSERT_TRUE(n + 1 <= m.getRank());
+  EXPECT_FALSE(m.areTrailingDimsContiguous(n + 1));
+}
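Note: the `m8` case deserves a spelled-out justification. A unit dimension only ever takes index 0, so its stride is multiplied by zero in every address computation; whatever value the dynamic stride `?` resolves to, the elements of `memref<2x1x2xf32, strided<[2,?,1]>>` occupy offsets 0, 1, 2, 3. A tiny standalone check, with arbitrary placeholder values for the dynamic stride:

```cpp
#include <cassert>
#include <cstdint>

// Element offset of (i, j, k) in memref<2x1x2xf32, strided<[2,?,1]>>, with
// `dynStride` standing in for the unknown `?`. Since dimension 1 has extent
// 1, j is always 0 and dynStride is multiplied away.
int64_t offsetOf(int64_t i, int64_t j, int64_t k, int64_t dynStride) {
  return i * 2 + j * dynStride + k * 1;
}

int main() {
  // Whatever `?` turns out to be, the four elements enumerate 0, 1, 2, 3.
  for (int64_t dynStride : {7, 100, -5}) {
    int64_t expected = 0;
    for (int64_t i = 0; i < 2; ++i)
      for (int64_t k = 0; k < 2; ++k)
        assert(offsetOf(i, /*j=*/0, k, dynStride) == expected++);
  }
  return 0;
}
```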