//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" namespace mlir { namespace tensor { namespace { static bool areAllConstantIntValue(ArrayRef ofrs, int64_t value) { return llvm::all_of( ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); }); } /// Packing one-dimensional tensor can be expressed as an expand shape op. struct SimplifyPackToExpandShape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; Value insertExpand(RewriterBase &rewriter, Location loc, Value operand, Type newOperandType, ArrayAttr reassociation) const { if (operand.getType() == newOperandType) return operand; return rewriter.create(loc, newOperandType, operand, reassociation); } LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { if (packOp.getPaddingValue()) return rewriter.notifyMatchFailure(packOp, "expects no padding value"); auto outerDimsPerm = packOp.getOuterDimsPerm(); if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { return rewriter.notifyMatchFailure( packOp, "expects outer_dims_perm is empty or an identity permutation"); } RankedTensorType sourceType = packOp.getSourceType(); RankedTensorType destType = packOp.getDestType(); ArrayRef dimsPos = packOp.getInnerDimsPos(); if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) { return rewriter.notifyMatchFailure( packOp, "expects packing at the innermost dimension"); } auto reassociation = getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) return failure(); Value expanded = insertExpand( rewriter, packOp.getLoc(), packOp.getSource(), destType, getReassociationIndicesAttribute(rewriter, *reassociation)); rewriter.replaceOp(packOp, expanded); return success(); } }; struct SimplifyUnPackToCollapseShape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand, Type newOperandType, ArrayAttr reassociation) const { if (operand.getType() == newOperandType) return operand; return rewriter.create(loc, newOperandType, operand, reassociation); } LogicalResult matchAndRewrite(UnPackOp unpackOp, PatternRewriter &rewriter) const override { auto outerDimsPerm = unpackOp.getOuterDimsPerm(); if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { return rewriter.notifyMatchFailure( unpackOp, "expects outer_dims_perm is empty or an identity permutation"); } RankedTensorType sourceType = unpackOp.getSourceType(); RankedTensorType destType = unpackOp.getDestType(); if (!sourceType.hasStaticShape() || !destType.hasStaticShape()) return rewriter.notifyMatchFailure(unpackOp, "expects static shapes"); ArrayRef dimsPos = unpackOp.getInnerDimsPos(); if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) { return rewriter.notifyMatchFailure( unpackOp, "expects unpacking at the innermost dimension"); } auto reassociation = getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) return failure(); Value collapsed = insertCollapse( rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType, getReassociationIndicesAttribute(rewriter, *reassociation)); rewriter.replaceOp(unpackOp, collapsed); return success(); } }; /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and /// the pad op has zero low paddings, or if `pack` has no padding values. struct FoldPadWithPackOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { auto padOp = packOp.getSource().getDefiningOp(); if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) return failure(); Value constantPaddingValue = padOp.getConstantPaddingValue(); if (!constantPaddingValue) return failure(); if (auto paddingValue = packOp.getPaddingValue()) if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue)) return failure(); rewriter.replaceOpWithNewOp( packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(), packOp.getMixedTiles(), constantPaddingValue, packOp.getOuterDimsPerm()); return success(); } }; /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already /// has extract_slice semantics. struct FoldUnpackWithExtractSliceOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { auto unpackOp = sliceOp.getSource().getDefiningOp(); if (!unpackOp) return failure(); if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) { return rewriter.notifyMatchFailure( sliceOp, "rank-reduced folding is not supported"); } // Check all offsets are zeros, and all strides are ones. if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) || !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) { return rewriter.notifyMatchFailure( sliceOp, "expects offsets to be 0s and strides to be 1s"); } // Create a new empty output tensor. Type elementType = unpackOp.getDestType().getElementType(); Value output = rewriter.create( sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); rewriter.replaceOpWithNewOp( sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(), unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); return success(); } }; /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose /// semantics. struct FoldProducerPackWithConsumerLinalgTransposeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override { auto packOp = transposeOp.getOperand(0).getDefiningOp(); if (!packOp) return failure(); auto innerDimsPos = packOp.getInnerDimsPos(); auto mixedInnerTiles = packOp.getMixedTiles(); auto outerDimsPerm = packOp.getOuterDimsPerm(); auto transposePerm = transposeOp.getPermutation(); SmallVector newOuterDimsPermVec; SmallVector newInnerDimsPosVec; SmallVector newMixedInnerTilesVec; int64_t srcRank = packOp.getSourceRank(); // Process transpose operation for non-tiled outer dimensions for (unsigned int i = 0; i < srcRank; ++i) { int64_t remappedPosition = transposePerm[i]; // If tensor.pack has outer_dims_perm attribute, then consider it during // index remapping. if (!outerDimsPerm.empty()) { if (transposePerm[i] >= srcRank) { return rewriter.notifyMatchFailure( transposeOp, "Cannot fold in tensor.pack if a tile dimension was transposed " "with a non-tile dimension in linalg.transpose."); } remappedPosition = outerDimsPerm[remappedPosition]; } newOuterDimsPermVec.push_back(remappedPosition); } // Process transpose operation for tiled inner dimensions for (unsigned int i = srcRank; i < transposePerm.size(); ++i) { int64_t remappedPosition = transposePerm[i] - srcRank; newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]); newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]); } Value output = packOp.createDestinationTensor( rewriter, transposeOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec); rewriter.replaceOpWithNewOp( transposeOp, packOp.getSource(), output, newInnerDimsPosVec, newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec); return success(); } }; /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose /// semantics. struct FoldConsumerPackWithProducerLinalgTransposeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { auto transposeOp = packOp.getSource().getDefiningOp(); if (!transposeOp) return failure(); auto transposePermutation = transposeOp.getPermutation(); auto outerDimsPerm = packOp.getOuterDimsPerm(); auto innerDimsPos = packOp.getInnerDimsPos(); SmallVector newInnerDimsPosVec; SmallVector newOuterDimsPermVec = llvm::to_vector(transposePermutation); if (!outerDimsPerm.empty()) applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); // Can't use applyPermutationToVector for newInnerDimsPosVec since input and // permutation rank won't necessarily be equal in all cases. for (auto dim : innerDimsPos) newInnerDimsPosVec.push_back(transposePermutation[dim]); Value output = packOp.createDestinationTensor( rewriter, packOp.getLoc(), transposeOp.getOperand(0), packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec); rewriter.replaceOpWithNewOp( packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec, packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec); return success(); } }; } // namespace void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) { patterns.insert( patterns.getContext()); } void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) { patterns.add( patterns.getContext()); } } // namespace tensor } // namespace mlir