[mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (#111541)

Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
This commit is contained in:
Benoit Jacob
2024-10-08 11:51:01 -04:00
committed by GitHub
parent d079743fe6
commit 10054ba4ac
4 changed files with 121 additions and 0 deletions

View File

@@ -235,6 +235,11 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
PatternBenefit benefit = 1);
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
void populateVectorContiguousExtractStridedSliceToExtractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth

View File

@@ -329,12 +329,70 @@ public:
}
};
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
class ContiguousExtractStridedSliceToExtract final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
if (op.hasNonUnitStrides()) {
return failure();
}
Value source = op.getOperand();
auto sourceType = cast<VectorType>(source.getType());
if (sourceType.isScalable()) {
return failure();
}
// Compute the number of offsets to pass to ExtractOp::build. That is the
// difference between the source rank and the desired slice rank. We walk
// the dimensions from innermost out, and stop when the next slice dimension
// is not full-size.
SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
int numOffsets;
for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
break;
}
}
// If not even the inner-most dimension is full-size, this op can't be
// rewritten as an ExtractOp.
if (numOffsets == sourceType.getRank()) {
return failure();
}
// Avoid generating slices that have unit outer dimensions. The shape_cast
// op that we create below would take bad generic fallback patterns
// (ShapeCastOpRewritePattern).
while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
++numOffsets;
}
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
extractOffsets);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
return success();
}
};
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DecomposeDifferentRankInsertStridedSlice,
DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}
void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
benefit);
}
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
RewritePatternSet &patterns,
std::function<bool(ExtractStridedSliceOp)> controlFn,

View File

@@ -0,0 +1,35 @@
// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
// CHECK-LABEL: @extract_strided_slice_to_extract_i8
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
// CHECK: return %[[EXTRACT]] : vector<8xi8>
func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
%2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
return %2 : vector<8xi8>
}
// CHECK-LABEL: @extract_strided_slice_to_extract_i32
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
// CHECK: return %[[EXTRACT]] : vector<4xi32>
func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
return %2 : vector<4xi32>
}
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
return %2 : vector<2xi32>
}
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
%2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
return %2 : vector<2xi32>
}

View File

@@ -709,6 +709,27 @@ struct TestVectorExtractStridedSliceLowering
}
};
struct TestVectorContiguousExtractStridedSliceToExtract
: public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorExtractStridedSliceLowering)
StringRef getArgument() const final {
return "test-vector-contiguous-extract-strided-slice-to-extract";
}
StringRef getDescription() const final {
return "Test lowering patterns that rewrite simple cases of N-D "
"extract_strided_slice, where the slice is contiguous, into extract "
"and shape_cast";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorBreakDownBitCast
: public PassWrapper<TestVectorBreakDownBitCast,
OperationPass<func::FuncOp>> {
@@ -935,6 +956,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorExtractStridedSliceLowering>();
PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
PassRegistration<TestVectorBreakDownBitCast>();
PassRegistration<TestCreateVectorBroadcast>();