[mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (#111541)
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
This change is contained in the upstream LLVM repository.
@@ -235,6 +235,11 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
|
||||
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
|
||||
/// slice is contiguous, into extract and shape_cast.
|
||||
void populateVectorContiguousExtractStridedSliceToExtractPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit = 1);
|
||||
|
||||
/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
|
||||
/// based on the destination vector shape. Bitcasts from a lower bitwidth
|
||||
/// element type to a higher bitwidth one are extracted from the lower bitwidth
|
||||
|
||||
@@ -329,12 +329,70 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
|
||||
/// slice is contiguous, into extract and shape_cast.
|
||||
class ContiguousExtractStridedSliceToExtract final
|
||||
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op.hasNonUnitStrides()) {
|
||||
return failure();
|
||||
}
|
||||
Value source = op.getOperand();
|
||||
auto sourceType = cast<VectorType>(source.getType());
|
||||
if (sourceType.isScalable()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Compute the number of offsets to pass to ExtractOp::build. That is the
|
||||
// difference between the source rank and the desired slice rank. We walk
|
||||
// the dimensions from innermost out, and stop when the next slice dimension
|
||||
// is not full-size.
|
||||
SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
|
||||
int numOffsets;
|
||||
for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
|
||||
if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If not even the inner-most dimension is full-size, this op can't be
|
||||
// rewritten as an ExtractOp.
|
||||
if (numOffsets == sourceType.getRank()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Avoid generating slices that have unit outer dimensions. The shape_cast
|
||||
// op that we create below would take bad generic fallback patterns
|
||||
// (ShapeCastOpRewritePattern).
|
||||
while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
|
||||
++numOffsets;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
|
||||
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
|
||||
Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
|
||||
extractOffsets);
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // Register the decomposition patterns for insert/extract_strided_slice.
  MLIRContext *ctx = patterns.getContext();
  patterns.add<DecomposeDifferentRankInsertStridedSlice>(ctx, benefit);
  patterns.add<DecomposeNDExtractStridedSlice>(ctx, benefit);
}
|
||||
|
||||
void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // Rewrite contiguous extract_strided_slice ops into extract + shape_cast.
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ContiguousExtractStridedSliceToExtract>(ctx, benefit);
}
|
||||
|
||||
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
std::function<bool(ExtractStridedSliceOp)> controlFn,
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @extract_strided_slice_to_extract_i8
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
// CHECK: return %[[EXTRACT]] : vector<8xi8>
func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
  %slice = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
  %flat = vector.shape_cast %slice : vector<1x1x1x1x8xi8> to vector<8xi8>
  return %flat : vector<8xi8>
}
|
||||
|
||||
// CHECK-LABEL: @extract_strided_slice_to_extract_i32
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
// CHECK: return %[[EXTRACT]] : vector<4xi32>
func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
  %slice = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
  %flat = vector.shape_cast %slice : vector<1x1x1x1x1x4xi32> to vector<4xi32>
  return %flat : vector<4xi32>
}
|
||||
|
||||
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
  %slice = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
  %flat = vector.shape_cast %slice : vector<1x1x1x1x1x2xi32> to vector<2xi32>
  return %flat : vector<2xi32>
}
|
||||
|
||||
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
  %slice = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
  %flat = vector.shape_cast %slice : vector<1x1x2x1x1x1xi32> to vector<2xi32>
  return %flat : vector<2xi32>
}
|
||||
@@ -709,6 +709,27 @@ struct TestVectorExtractStridedSliceLowering
|
||||
}
|
||||
};
|
||||
|
||||
struct TestVectorContiguousExtractStridedSliceToExtract
|
||||
: public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
|
||||
OperationPass<func::FuncOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
||||
TestVectorExtractStridedSliceLowering)
|
||||
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-contiguous-extract-strided-slice-to-extract";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test lowering patterns that rewrite simple cases of N-D "
|
||||
"extract_strided_slice, where the slice is contiguous, into extract "
|
||||
"and shape_cast";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
struct TestVectorBreakDownBitCast
|
||||
: public PassWrapper<TestVectorBreakDownBitCast,
|
||||
OperationPass<func::FuncOp>> {
|
||||
@@ -935,6 +956,8 @@ void registerTestVectorLowerings() {
|
||||
|
||||
PassRegistration<TestVectorExtractStridedSliceLowering>();
|
||||
|
||||
PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
|
||||
|
||||
PassRegistration<TestVectorBreakDownBitCast>();
|
||||
|
||||
PassRegistration<TestCreateVectorBroadcast>();
|
||||
|
||||
Reference in New Issue
Block a user