Refactored @Max191's PR https://github.com/llvm/llvm-project/pull/94637 to move it to `Tensor`. From the original PR: > This PR adds fusion-by-expansion patterns to push a tensor.expand_shape up through a tensor.collapse_shape with non-intersecting reassociations. Sometimes parallel collapse_shape ops like this can block propagation of expand_shape ops, so this allows them to pass through each other. I'm not sure if I put the code/tests in the right places, so let me know where those go if they aren't. cc @MaheshRavishankar @hanhanW --------- Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
230 lines
9.9 KiB
C++
230 lines
9.9 KiB
C++
//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tensor;
|
|
|
|
namespace {
|
|
/// Fold expand_shape(extract_slice) ops that cancel itself out.
|
|
struct FoldExpandOfRankReducingExtract
|
|
: public OpRewritePattern<ExpandShapeOp> {
|
|
using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
RankedTensorType resultType = expandShapeOp.getResultType();
|
|
auto extractSliceOp =
|
|
expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
|
|
if (!extractSliceOp)
|
|
return failure();
|
|
RankedTensorType srcType = extractSliceOp.getSourceType();
|
|
|
|
// Only cases where the ExpandShapeOp can be folded away entirely are
|
|
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
|
|
// has no rank-reduction anymore are supported at the moment.
|
|
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
|
|
srcType, extractSliceOp.getStaticOffsets(),
|
|
extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
|
|
if (nonReducingExtractType != resultType)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
|
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
|
|
expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
|
|
mixedStrides);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold collapse_shape which only removes static dimensions of size `1`
|
|
/// into extract_slice.
|
|
struct FoldUnPaddingCollapseIntoExtract
|
|
: public OpRewritePattern<tensor::CollapseShapeOp> {
|
|
using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto extractSliceOp =
|
|
collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
|
|
// Collapse cannot be folded away with multiple users of the extract slice
|
|
// and it is not necessarily beneficial to only convert the collapse into
|
|
// another extract slice.
|
|
if (!extractSliceOp || !extractSliceOp->hasOneUse())
|
|
return failure();
|
|
|
|
// Only fold away simple collapse where all removed dimensions have static
|
|
// size `1`.
|
|
SliceVerificationResult res = isRankReducedType(
|
|
collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
|
|
if (res != SliceVerificationResult::Success)
|
|
return rewriter.notifyMatchFailure(collapseShapeOp,
|
|
"expected unpadding collapse");
|
|
|
|
Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
|
|
extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
|
|
extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
|
|
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
|
|
rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold insert_slice(collapse_shape) ops that cancel itself out.
|
|
template <typename OpTy>
|
|
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy insertSliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto collapseShapeOp =
|
|
insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
|
|
if (!collapseShapeOp)
|
|
return failure();
|
|
RankedTensorType srcType = collapseShapeOp.getSrcType();
|
|
|
|
// Only cases where the CollapseShapeOp can be folded away entirely are
|
|
// supported. Moreover, only simple cases where the resulting InsertSliceOp
|
|
// has no rank-reduction anymore are supported at the moment.
|
|
RankedTensorType nonReducingInsertType =
|
|
RankedTensorType::get(insertSliceOp.getStaticSizes(),
|
|
insertSliceOp.getDestType().getElementType());
|
|
if (nonReducingInsertType != srcType)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
|
|
rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
|
|
insertSliceOp.getDest(), mixedOffsets,
|
|
mixedSizes, mixedStrides);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold expand_shape which only adds static dimensions of size `1`
|
|
/// into insert_slice.
|
|
template <typename OpTy>
|
|
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy insertSliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto expandShapeOp = insertSliceOp.getSource()
|
|
.template getDefiningOp<tensor::ExpandShapeOp>();
|
|
if (!expandShapeOp)
|
|
return failure();
|
|
|
|
// Only fold away simple expansion where all added dimensions have static
|
|
// size `1`.
|
|
SliceVerificationResult res = isRankReducedType(
|
|
expandShapeOp.getResultType(), expandShapeOp.getSrcType());
|
|
if (res != SliceVerificationResult::Success)
|
|
return rewriter.notifyMatchFailure(insertSliceOp,
|
|
"expected rank increasing expansion");
|
|
|
|
rewriter.modifyOpInPlace(insertSliceOp, [&]() {
|
|
insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
|
|
});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Pattern to bubble up a tensor.expand_shape op through a producer
|
|
/// tensor.collapse_shape op that has non intersecting reassociations.
|
|
struct BubbleUpExpandThroughParallelCollapse
|
|
: public OpRewritePattern<tensor::ExpandShapeOp> {
|
|
using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto collapseOp =
|
|
expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
|
|
if (!collapseOp)
|
|
return failure();
|
|
auto expandReInds = expandOp.getReassociationIndices();
|
|
auto collapseReInds = collapseOp.getReassociationIndices();
|
|
|
|
// Reshapes are parallel to each other if none of the reassociation indices
|
|
// have greater than 1 index for both reshapes.
|
|
for (auto [expandReassociation, collapseReassociation] :
|
|
llvm::zip_equal(expandReInds, collapseReInds)) {
|
|
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
|
|
return failure();
|
|
}
|
|
|
|
// Compute new reassociation indices and expanded/collaped shapes.
|
|
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
|
|
Location loc = expandOp->getLoc();
|
|
SmallVector<OpFoldResult> collapseSizes =
|
|
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
|
|
SmallVector<OpFoldResult> expandSizes(getMixedValues(
|
|
expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
|
|
SmallVector<OpFoldResult> newExpandSizes;
|
|
int64_t index = 0, expandIndex = 0, collapseIndex = 0;
|
|
for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
|
|
if (collapseReassociation.size() != 1) {
|
|
ReassociationIndices newCollapseReassociation;
|
|
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
|
|
newCollapseReassociation.push_back(index);
|
|
newExpandReInds.push_back({index++});
|
|
newExpandSizes.push_back(collapseSizes[collapseIndex++]);
|
|
}
|
|
newCollapseReInds.push_back(newCollapseReassociation);
|
|
expandIndex++;
|
|
continue;
|
|
}
|
|
ReassociationIndices newExpandReassociation;
|
|
auto expandReassociation = expandReInds[idx];
|
|
for (size_t i = 0; i < expandReassociation.size(); ++i) {
|
|
newExpandReassociation.push_back(index);
|
|
newCollapseReInds.push_back({index++});
|
|
newExpandSizes.push_back(expandSizes[expandIndex++]);
|
|
}
|
|
newExpandReInds.push_back(newExpandReassociation);
|
|
collapseIndex++;
|
|
}
|
|
|
|
// Swap reshape order.
|
|
SmallVector<Value> dynamicSizes;
|
|
SmallVector<int64_t> staticSizes;
|
|
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
|
|
auto expandResultType = expandOp.getResultType().clone(staticSizes);
|
|
auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
|
|
loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
|
|
newExpandSizes);
|
|
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
|
|
expandOp, newExpand.getResult(), newCollapseReInds);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the patterns that fold tensor.expand_shape / tensor.collapse_shape
/// into neighbouring extract_slice, insert_slice and parallel_insert_slice
/// ops to `patterns`.
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Folds involving extract_slice.
  patterns.add<FoldExpandOfRankReducingExtract,
               FoldUnPaddingCollapseIntoExtract>(context);
  // Folds of collapse_shape feeding insert-like ops.
  patterns.add<FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
      context);
  // Folds of expand_shape feeding insert-like ops.
  patterns.add<FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
               FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
      context);
}
|
|
|
|
/// Adds the pattern that bubbles tensor.expand_shape up through a parallel
/// tensor.collapse_shape producer to `patterns`.
void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BubbleUpExpandThroughParallelCollapse>(context);
}
|