//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) ops that cancel themselves out.
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) ops that cancel themselves out.
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};
} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldExpandOfRankReducingExtract,
               FoldInsertOfRankReducingInsert<InsertSliceOp>,
               FoldInsertOfRankReducingInsert<ParallelInsertSliceOp>>(
      patterns.getContext());
}
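
// Minimal usage sketch (illustrative, not a definitive driver): callers
// typically gather these patterns into a RewritePatternSet and run them with
// the greedy rewrite driver from "mlir/Transforms/GreedyPatternRewriteDriver.h".
// The helper name `foldReassociativeReshapes` below is hypothetical.
//
//   static LogicalResult foldReassociativeReshapes(Operation *root) {
//     RewritePatternSet patterns(root->getContext());
//     tensor::populateReassociativeReshapeFoldingPatterns(patterns);
//     // Apply the folds repeatedly until a fixed point is reached.
//     return applyPatternsAndFoldGreedily(root, std::move(patterns));
//   }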