//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" namespace mlir { namespace tensor { namespace { static bool areAllConstantIntValue(ArrayRef ofrs, int64_t value) { return llvm::all_of( ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); }); } /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and /// the pad op has zero low paddings, or if `pack` has no padding values. struct FoldPadWithPackOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { auto padOp = packOp.getSource().getDefiningOp(); if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) return failure(); Value constantPaddingValue = padOp.getConstantPaddingValue(); if (!constantPaddingValue) return failure(); if (auto paddingValue = packOp.getPaddingValue()) if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue)) return failure(); rewriter.replaceOpWithNewOp( packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(), packOp.getMixedTiles(), constantPaddingValue, packOp.getOuterDimsPerm()); return success(); } }; /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already /// has extract_slice semantics. struct FoldUnpackWithExtractSliceOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { auto unpackOp = sliceOp.getSource().getDefiningOp(); if (!unpackOp) return failure(); if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) { return rewriter.notifyMatchFailure( sliceOp, "rank-reduced folding is not supported"); } // Check all offsets are zeros, and all strides are ones. if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) || !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) { return rewriter.notifyMatchFailure( sliceOp, "expects offsets to be 0s and strides to be 1s"); } // Create a new empty output tensor. Type elementType = unpackOp.getDestType().getElementType(); Value output = rewriter.create( sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); rewriter.replaceOpWithNewOp( sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(), unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); return success(); } }; /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose /// semantics. struct FoldProducerPackWithConsumerLinalgTransposeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override { auto packOp = transposeOp.getOperand(0).getDefiningOp(); if (!packOp) return failure(); auto innerDimsPos = packOp.getInnerDimsPos(); auto mixedInnerTiles = packOp.getMixedTiles(); auto outerDimsPerm = packOp.getOuterDimsPerm(); auto transposePerm = transposeOp.getPermutation(); SmallVector newOuterDimsPermVec; SmallVector newInnerDimsPosVec; SmallVector newMixedInnerTilesVec; int64_t srcRank = packOp.getSourceRank(); // Process transpose operation for non-tiled outer dimensions for (unsigned int i = 0; i < srcRank; ++i) { int64_t remappedPosition = transposePerm[i]; // If tensor.pack has outer_dims_perm attribute, then consider it during // index remapping. if (!outerDimsPerm.empty()) { if (transposePerm[i] >= srcRank) { return rewriter.notifyMatchFailure( transposeOp, "Cannot fold in tensor.pack if a tile dimension was transposed " "with a non-tile dimension in linalg.transpose."); } remappedPosition = outerDimsPerm[remappedPosition]; } newOuterDimsPermVec.push_back(remappedPosition); } // Process transpose operation for tiled inner dimensions for (unsigned int i = srcRank; i < transposePerm.size(); ++i) { int64_t remappedPosition = transposePerm[i] - srcRank; newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]); newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]); } Value output = packOp.createDestinationTensor( rewriter, transposeOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec); rewriter.replaceOpWithNewOp( transposeOp, packOp.getSource(), output, newInnerDimsPosVec, newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec); return success(); } }; } // namespace void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) { patterns.insert( patterns.getContext()); } } // namespace tensor } // namespace mlir