diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 453fa73429dd..fa2912a3e577 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -389,6 +389,13 @@ void populateVectorLinearizeTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth); +/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D +/// vector shuffle operations. +void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target, + unsigned targetBitWidth); + } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index b59e9062e5a0..69999f0918c1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -13,9 +13,16 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include using namespace mlir; @@ -103,6 +110,251 @@ public: return success(); } +private: + unsigned targetVectorBitWidth; +}; + +/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works +/// on a linearized vector. +/// Following, +/// vector.extract_strided_slice %source +/// { offsets = [..], strides = [..], sizes = [..] } +/// is converted to : +/// %source_1d = vector.shape_cast %source +/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] +/// %out_nd = vector.shape_cast %out_1d +/// `shuffle_indices_1d` is computed using the offsets and sizes of the +/// extraction. +struct LinearizeVectorExtractStridedSlice final + : public mlir::OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorExtractStridedSlice( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = std::numeric_limits::max(), + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + + LogicalResult + matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstType = getTypeConverter()->convertType(extractOp.getType()); + assert(!(extractOp.getVector().getType().isScalable() || + dstType.cast().isScalable()) && + "scalable vectors are not supported."); + if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) + return rewriter.notifyMatchFailure( + extractOp, "Can't flatten since targetBitWidth <= OpSize"); + + ArrayAttr offsets = extractOp.getOffsets(); + ArrayAttr sizes = extractOp.getSizes(); + ArrayAttr strides = extractOp.getStrides(); + if (!isConstantIntValue(strides[0], 1)) + return rewriter.notifyMatchFailure( + extractOp, "Strided slice with stride != 1 is not supported."); + Value srcVector = adaptor.getVector(); + // If kD offsets are specified for nD source vector (n > k), the granularity + // of the extraction is greater than 1. In this case last (n-k) dimensions + // form the extraction granularity. + // Example : + // vector.extract_strided_slice %src { + // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : + // vector<4x8x8xf32> to vector<2x2x8xf32> + // Here, extraction granularity is 8. + int64_t extractGranularitySize = 1; + int64_t nD = extractOp.getSourceVectorType().getRank(); + int64_t kD = (int64_t)offsets.size(); + int64_t k = kD; + while (k < nD) { + extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k]; + ++k; + } + // Get total number of extracted slices. + int64_t nExtractedSlices = 1; + for (Attribute size : sizes) { + nExtractedSlices *= size.cast().getInt(); + } + // Compute the strides of the source vector considering first k dimensions. + llvm::SmallVector sourceStrides(kD, extractGranularitySize); + for (int i = kD - 2; i >= 0; --i) { + sourceStrides[i] = sourceStrides[i + 1] * + extractOp.getSourceVectorType().getShape()[i + 1]; + } + // Final shuffle indices has nExtractedSlices * extractGranularitySize + // elements. + llvm::SmallVector indices(nExtractedSlices * + extractGranularitySize); + // Compute the strides of the extracted kD vector. + llvm::SmallVector extractedStrides(kD, 1); + // Compute extractedStrides. + for (int i = kD - 2; i >= 0; --i) { + extractedStrides[i] = + extractedStrides[i + 1] * sizes[i + 1].cast().getInt(); + } + // Iterate over all extracted slices from 0 to nExtractedSlices - 1 + // and compute the multi-dimensional index and the corresponding linearized + // index within the source vector. + for (int64_t i = 0; i < nExtractedSlices; ++i) { + int64_t index = i; + // Compute the corresponding multi-dimensional index. + llvm::SmallVector multiDimIndex(kD, 0); + for (int64_t j = 0; j < kD; ++j) { + multiDimIndex[j] = (index / extractedStrides[j]); + index -= multiDimIndex[j] * extractedStrides[j]; + } + // Compute the corresponding linearized index in the source vector + // i.e. shift the multiDimIndex by the offsets. + int64_t linearizedIndex = 0; + for (int64_t j = 0; j < kD; ++j) { + linearizedIndex += + (offsets[j].cast().getInt() + multiDimIndex[j]) * + sourceStrides[j]; + } + // Fill the indices array form linearizedIndex to linearizedIndex + + // extractGranularitySize. + for (int64_t j = 0; j < extractGranularitySize; ++j) { + indices[i * extractGranularitySize + j] = linearizedIndex + j; + } + } + // Perform a shuffle to extract the kD vector. + rewriter.replaceOpWithNewOp( + extractOp, dstType, srcVector, srcVector, + rewriter.getI64ArrayAttr(indices)); + return success(); + } + +private: + unsigned targetVectorBitWidth; +}; + +/// This pattern converts the ShuffleOp that works on nD (n > 1) +/// vectors to a ShuffleOp that works on linearized vectors. +/// Following, +/// vector.shuffle %v1, %v2 [ shuffle_indices ] +/// is converted to : +/// %v1_1d = vector.shape_cast %v1 +/// %v2_1d = vector.shape_cast %v2 +/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] +/// %out_nd = vector.shape_cast %out_1d +// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices` +/// of the original shuffle operation. +struct LinearizeVectorShuffle final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorShuffle( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = std::numeric_limits::max(), + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + + LogicalResult + matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstType = getTypeConverter()->convertType(shuffleOp.getType()); + assert(!(shuffleOp.getV1VectorType().isScalable() || + shuffleOp.getV2VectorType().isScalable() || + dstType.cast().isScalable()) && + "scalable vectors are not supported."); + if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) + return rewriter.notifyMatchFailure( + shuffleOp, "Can't flatten since targetBitWidth <= OpSize"); + + Value vec1 = adaptor.getV1(); + Value vec2 = adaptor.getV2(); + int shuffleSliceLen = 1; + int rank = shuffleOp.getV1().getType().getRank(); + + // If rank > 1, we need to do the shuffle in the granularity of slices + // instead of scalars. Size of the slice is equal to the rank-1 innermost + // dims. Mask of the shuffle op specifies which slice to take from the + // outermost dim. + if (rank > 1) { + llvm::ArrayRef shape = shuffleOp.getV1().getType().getShape(); + for (unsigned i = 1; i < shape.size(); ++i) { + shuffleSliceLen *= shape[i]; + } + } + + // For each value in the mask, we generate the indices of the source vectors + // that needs to be shuffled to the destination vector. If shuffleSliceLen > + // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of + // elements) instead of scalars. + ArrayAttr mask = shuffleOp.getMask(); + int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; + llvm::SmallVector indices(totalSizeOfShuffledElmnts); + for (auto [i, value] : + llvm::enumerate(mask.getAsValueRange())) { + + int64_t v = value.getZExtValue(); + std::iota(indices.begin() + shuffleSliceLen * i, + indices.begin() + shuffleSliceLen * (i + 1), + shuffleSliceLen * v); + } + + rewriter.replaceOpWithNewOp( + shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices)); + return success(); + } + +private: + unsigned targetVectorBitWidth; +}; + +/// This pattern converts the ExtractOp to a ShuffleOp that works on a +/// linearized vector. +/// Following, +/// vector.extract %source [ position ] +/// is converted to : +/// %source_1d = vector.shape_cast %source +/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] +/// %out_nd = vector.shape_cast %out_1d +/// `shuffle_indices_1d` is computed using the position of the original extract. +struct LinearizeVectorExtract final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorExtract( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = std::numeric_limits::max(), + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + LogicalResult + matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = getTypeConverter()->convertType(extractOp.getType()); + assert(!(extractOp.getVector().getType().isScalable() || + dstTy.cast().isScalable()) && + "scalable vectors are not supported."); + if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) + return rewriter.notifyMatchFailure( + extractOp, "Can't flatten since targetBitWidth <= OpSize"); + + // Dynamic position is not supported. + if (extractOp.hasDynamicPosition()) + return rewriter.notifyMatchFailure(extractOp, + "dynamic position is not supported."); + + llvm::ArrayRef shape = extractOp.getVector().getType().getShape(); + int64_t size = extractOp.getVector().getType().getNumElements(); + + // Compute linearized offset. + int64_t linearizedOffset = 0; + llvm::ArrayRef offsets = extractOp.getStaticPosition(); + for (auto [i, off] : llvm::enumerate(offsets)) { + size /= shape[i]; + linearizedOffset += offsets[i] * size; + } + + llvm::SmallVector indices(size); + std::iota(indices.begin(), indices.end(), linearizedOffset); + rewriter.replaceOpWithNewOp( + extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), + rewriter.getI64ArrayAttr(indices)); + + return success(); + } + private: unsigned targetVectorBitWidth; }; @@ -145,3 +397,21 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( patterns.add( typeConverter, patterns.getContext(), targetBitWidth); } + +void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, unsigned int targetBitWidth) { + target.addDynamicallyLegalOp( + [=](vector::ShuffleOp shuffleOp) -> bool { + return isLessThanTargetBitWidth(shuffleOp, targetBitWidth) + ? (typeConverter.isLegal(shuffleOp) && + shuffleOp.getResult() + .getType() + .cast() + .getRank() == 1) + : true; + }); + patterns.add( + typeConverter, patterns.getContext(), targetBitWidth); +} diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 22be78cd6820..b29ceab5783d 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -153,3 +153,95 @@ func.func @test_0d_vector() -> vector { // ALL: return %[[CST]] return %0 : vector } + +// ----- +// ALL-LABEL: test_extract_strided_slice_1 +// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> { +func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> { + // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32> + // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] + // DEFAULT-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32> + // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32> + // DEFAULT: return %[[RES]] : vector<2x2xf32 + + // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32> + // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] + // BW-128-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32> + // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32> + // BW-128: return %[[RES]] : vector<2x2xf32> + + // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32> + // BW-0: return %[[RES]] : vector<2x2xf32> + %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]} + : vector<4x8xf32> to vector<2x2xf32> + return %0 : vector<2x2xf32> +} + +// ----- +// ALL-LABEL: test_extract_strided_slice_2 +// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> { +func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> { + // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32> + // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] + // DEFAULT-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32> + // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32> + // DEFAULT: return %[[RES]] : vector<1x4x2xf32> + + // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32> + // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] + // BW-128-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32> + // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32> + // BW-128: return %[[RES]] : vector<1x4x2xf32> + + // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32> + // BW-0: return %[[RES]] : vector<1x4x2xf32> + %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] } + : vector<2x8x2xf32> to vector<1x4x2xf32> + return %0 : vector<1x4x2xf32> +} + +// ----- +// ALL-LABEL: test_vector_shuffle +// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> { +func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> { + // DEFAULT: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32> + // DEFAULT: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32> + // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]] + // DEFAULT-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> + // DEFAULT: return %[[RES]] : vector<8x2xf32> + + // BW-128: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32> + // BW-128: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32> + // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]] + // BW-128-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> + // BW-128: return %[[RES]] : vector<8x2xf32> + + // BW-0: %[[RES:.*]] = vector.shuffle %[[ORIG_ARG0]], %[[ORIG_ARG1]] [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32> + // BW-0: return %[[RES]] : vector<8x2xf32> + %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32> + return %0 : vector<8x2xf32> +} + +// ----- +// ALL-LABEL: test_vector_extract +// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> { +func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { + // DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32> + // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] + // DEFAULT-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32> + // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> + // DEFAULT: return %[[RES]] : vector<8x2xf32> + + // BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32> + // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] + // BW-128-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32> + // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> + // BW-128: return %[[RES]] : vector<8x2xf32> + + // BW-0: %[[RES:.*]] = vector.extract %[[ORIG_ARG]][1] : vector<8x2xf32> from vector<2x8x2xf32> + // BW-0: return %[[RES]] : vector<8x2xf32> + %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32> + return %0 : vector<8x2xf32> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 006225999105..c978699e179f 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -867,6 +867,8 @@ struct TestVectorLinearize final vector::populateVectorLinearizeTypeConversionsAndLegality( typeConverter, patterns, target, targetVectorBitwidth); + vector::populateVectorLinearizeShuffleLikeOpsPatterns( + typeConverter, patterns, target, targetVectorBitwidth); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure();