From 10054ba4acbc5378d2e2aa869a5bccd88aa4b59e Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Tue, 8 Oct 2024 11:51:01 -0400 Subject: [PATCH] [mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (#111541) Co-authored-by: Jakub Kuderski --- .../Vector/Transforms/VectorRewritePatterns.h | 5 ++ ...sertExtractStridedSliceRewritePatterns.cpp | 58 +++++++++++++++++++ ...uous-extract-strided-slice-to-extract.mlir | 35 +++++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 23 ++++++++ 4 files changed, 121 insertions(+) create mode 100644 mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index a59f06f3c1ef..ec1de7fa66aa 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -235,6 +235,11 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns( std::function controlFn = nullptr, PatternBenefit benefit = 1); +/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the +/// slice is contiguous, into extract and shape_cast. +void populateVectorContiguousExtractStridedSliceToExtractPatterns( + RewritePatternSet &patterns, PatternBenefit benefit = 1); + /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops /// based on the destination vector shape. Bitcasts from a lower bitwidth /// element type to a higher bitwidth one are extracted from the lower bitwidth diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index ec2ef3fc7501..c2da9347aadc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -329,12 +329,70 @@ public: } }; +/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the +/// slice is contiguous, into extract and shape_cast. +class ContiguousExtractStridedSliceToExtract final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractStridedSliceOp op, + PatternRewriter &rewriter) const override { + if (op.hasNonUnitStrides()) { + return failure(); + } + Value source = op.getOperand(); + auto sourceType = cast(source.getType()); + if (sourceType.isScalable()) { + return failure(); + } + + // Compute the number of offsets to pass to ExtractOp::build. That is the + // difference between the source rank and the desired slice rank. We walk + // the dimensions from innermost out, and stop when the next slice dimension + // is not full-size. + SmallVector sizes = getI64SubArray(op.getSizes()); + int numOffsets; + for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) { + if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) { + break; + } + } + + // If not even the inner-most dimension is full-size, this op can't be + // rewritten as an ExtractOp. + if (numOffsets == sourceType.getRank()) { + return failure(); + } + + // Avoid generating slices that have unit outer dimensions. The shape_cast + // op that we create below would take bad generic fallback patterns + // (ShapeCastOpRewritePattern). + while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) { + ++numOffsets; + } + + SmallVector offsets = getI64SubArray(op.getOffsets()); + auto extractOffsets = ArrayRef(offsets).take_front(numOffsets); + Value extract = rewriter.create(op->getLoc(), source, + extractOffsets); + rewriter.replaceOpWithNewOp(op, op.getType(), extract); + return success(); + } +}; + void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); } +void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), + benefit); +} + void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns( RewritePatternSet &patterns, std::function controlFn, diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir new file mode 100644 index 000000000000..9147e7bf0258 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s + +// CHECK-LABEL: @extract_strided_slice_to_extract_i8 +// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8> +// CHECK: return %[[EXTRACT]] : vector<8xi8> +func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> { + %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8> + %2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8> + return %2 : vector<8xi8> +} + +// CHECK-LABEL: @extract_strided_slice_to_extract_i32 +// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32> +// CHECK: return %[[EXTRACT]] : vector<4xi32> +func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> { + %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32> + %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32> + return %2 : vector<4xi32> +} + +// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1 +// CHECK: vector.extract_strided_slice +func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> { + %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32> + %2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32> + return %2 : vector<2xi32> +} + +// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2 +// CHECK: vector.extract_strided_slice +func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> { + %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32> + %2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32> + return %2 : vector<2xi32> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 72aaa7dc4f89..d91e955b7064 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -709,6 +709,27 @@ struct TestVectorExtractStridedSliceLowering } }; +struct TestVectorContiguousExtractStridedSliceToExtract + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestVectorExtractStridedSliceLowering) + + StringRef getArgument() const final { + return "test-vector-contiguous-extract-strided-slice-to-extract"; + } + StringRef getDescription() const final { + return "Test lowering patterns that rewrite simple cases of N-D " + "extract_strided_slice, where the slice is contiguous, into extract " + "and shape_cast"; + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestVectorBreakDownBitCast : public PassWrapper> { @@ -935,6 +956,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration(); + PassRegistration(); PassRegistration();