The current patterns compared the reassociation indices for the two ops and failed if neither of them was of size 1. This patch relaxes that restriction by handling a new case where the reassociation indices may be of the same size. It also generalizes the generation of the swapped `tensor.expand_shape` -> `tensor.collapse_shape` pair so that, if one of them would be degenerate (a no-op reshape), it is not generated. Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
775 lines
35 KiB
C++
775 lines
35 KiB
C++
//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tensor;
|
|
|
|
namespace {
|
|
/// Fold expand_shape(extract_slice) ops that cancel itself out.
|
|
struct FoldExpandOfRankReducingExtract
|
|
: public OpRewritePattern<ExpandShapeOp> {
|
|
using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
RankedTensorType resultType = expandShapeOp.getResultType();
|
|
auto extractSliceOp =
|
|
expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
|
|
if (!extractSliceOp)
|
|
return failure();
|
|
RankedTensorType srcType = extractSliceOp.getSourceType();
|
|
|
|
// Only cases where the ExpandShapeOp can be folded away entirely are
|
|
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
|
|
// has no rank-reduction anymore are supported at the moment.
|
|
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
|
|
srcType, extractSliceOp.getStaticOffsets(),
|
|
extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
|
|
if (nonReducingExtractType != resultType)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
|
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
|
|
expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
|
|
mixedStrides);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold collapse_shape which only removes static dimensions of size `1`
|
|
/// into extract_slice.
|
|
struct FoldUnPaddingCollapseIntoExtract
|
|
: public OpRewritePattern<tensor::CollapseShapeOp> {
|
|
using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto extractSliceOp =
|
|
collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
|
|
// Collapse cannot be folded away with multiple users of the extract slice
|
|
// and it is not necessarily beneficial to only convert the collapse into
|
|
// another extract slice.
|
|
if (!extractSliceOp || !extractSliceOp->hasOneUse())
|
|
return failure();
|
|
|
|
// Only fold away simple collapse where all removed dimensions have static
|
|
// size `1`.
|
|
SliceVerificationResult res = isRankReducedType(
|
|
collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
|
|
if (res != SliceVerificationResult::Success)
|
|
return rewriter.notifyMatchFailure(collapseShapeOp,
|
|
"expected unpadding collapse");
|
|
|
|
Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
|
|
extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
|
|
extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
|
|
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
|
|
rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold insert_slice(collapse_shape) ops that cancel itself out.
|
|
template <typename OpTy>
|
|
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy insertSliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto collapseShapeOp =
|
|
insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
|
|
if (!collapseShapeOp)
|
|
return failure();
|
|
RankedTensorType srcType = collapseShapeOp.getSrcType();
|
|
|
|
// Only cases where the CollapseShapeOp can be folded away entirely are
|
|
// supported. Moreover, only simple cases where the resulting InsertSliceOp
|
|
// has no rank-reduction anymore are supported at the moment.
|
|
RankedTensorType nonReducingInsertType =
|
|
RankedTensorType::get(insertSliceOp.getStaticSizes(),
|
|
insertSliceOp.getDestType().getElementType());
|
|
if (nonReducingInsertType != srcType)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
|
|
rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
|
|
insertSliceOp.getDest(), mixedOffsets,
|
|
mixedSizes, mixedStrides);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold expand_shape which only adds static dimensions of size `1`
|
|
/// into insert_slice.
|
|
template <typename OpTy>
|
|
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy insertSliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto expandShapeOp = insertSliceOp.getSource()
|
|
.template getDefiningOp<tensor::ExpandShapeOp>();
|
|
if (!expandShapeOp)
|
|
return failure();
|
|
|
|
// Only fold away simple expansion where all added dimensions have static
|
|
// size `1`.
|
|
SliceVerificationResult res = isRankReducedType(
|
|
expandShapeOp.getResultType(), expandShapeOp.getSrcType());
|
|
if (res != SliceVerificationResult::Success)
|
|
return rewriter.notifyMatchFailure(insertSliceOp,
|
|
"expected rank increasing expansion");
|
|
|
|
rewriter.modifyOpInPlace(insertSliceOp, [&]() {
|
|
insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
|
|
});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non intersecting reassociations.
/// The swap turns `collapse_shape -> expand_shape` into
/// `expand_shape -> collapse_shape`; if either of the newly created reshapes
/// would be an identity (no-op) reshape, it is elided.
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    // Only fires when the expand's source is produced by a collapse_shape.
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // Special case where the collapsed tensor to expand is a 0-D tensor,
    // then the reassociation maps will be empty and not produce valid results.
    if (expandReInds.size() == 0) {
      return failure();
    }

    // Reshapes are parallel to each other (by construction the number of
    // reassociations specified in the collapse and expand are the same), if at
    // any position
    // 1. either the reassociation indices are of the same size, or
    // 2. either the reassociation in the collapse or the expand is of size 1.
    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() == expandReassociation.size()) {
        // Even if the reassociations are the same, the collapse/expand should
        // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
        // into 4x8x2 again. In presence of dynamic dimensions one can only
        // verify "equality" when there is only one dynamic dimension present,
        // and all other static dimensions are equal.
        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
            collapseReassociation.front(), collapseReassociation.size());
        int64_t numCollapsedDynamic =
            llvm::count_if(collapsedStaticShapes,
                           [](int64_t d) { return ShapedType::isDynamic(d); });
        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
            expandReassociation.front(), expandReassociation.size());
        int64_t numExpandedDynamic =
            llvm::count_if(expandedStaticShapes,
                           [](int64_t d) { return ShapedType::isDynamic(d); });
        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
            collapsedStaticShapes != expandedStaticShapes) {
          return failure();
        }
        continue;
      }
      // If the reassociations are not same, one or the other needs to be of
      // size one.
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> sourceSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
    SmallVector<OpFoldResult> newExpandSizes;

    // Four parallel cursors: positions in the new expand result, the new
    // collapse result, the original collapse source, and the original expand
    // result. Each case below advances the cursors it consumes from.
    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
            resultSizeIndex = 0;

    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
      auto &collapseReassociation = collapseReInds[idx];
      auto &expandReassociation = expandReInds[idx];

      // Case 1. The reassociations are same in the collapse producer
      // and expand consumer. In the swapped expand, each of the final
      // dimensions are kept as is in the expand and the collapse. So,
      // for every element in the `ReassocationIndices` vector add a new
      // `ReassociationIndices` vector for the swapped expand and collapse
      // (of size 1).
      if (collapseReassociation.size() == expandReassociation.size()) {
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReInds.push_back({newCollapseIndex++});
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
          sourceSizeIndex++;
        }
        continue;
      }

      // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
      // in the expand is of size == 1). In this case, the original dimensions
      // are preserved on expansion and collapsed subsequently.
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(newCollapseIndex++);
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
        }
        resultSizeIndex++;
        newCollapseReInds.push_back(newCollapseReassociation);
        continue;
      }

      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
      // in the collapse is of size == 1). In this case, the expansion happens
      // first and the expanded dimensions are preserved on collapse.
      ReassociationIndices newExpandReassociation;
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(newExpandIndex++);
        newCollapseReInds.push_back({newCollapseIndex++});
        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      sourceSizeIndex++;
    }

    // Swap reshape order: build the (possibly elided) expand first, then the
    // (possibly elided) collapse on top of it.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    Value newCollapseSrc = collapseOp.getSrc();
    // If the number of reassociation indices in the new `expand_shape` op
    // matches the number of dimensions of the result, then the expand_shape
    // is a no-op.
    if (newExpandReInds.size() != newExpandSizes.size()) {
      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
          loc, expandResultType, newCollapseSrc, newExpandReInds,
          newExpandSizes);
    }

    // If the number of reassociation indices in the new `collapse_shape` op
    // matches the number of dimensions of the source, then the collapse_shape
    // is a no-op.
    Value replacement = newCollapseSrc;
    if (newCollapseReInds.size() != newExpandSizes.size()) {
      replacement = rewriter.create<tensor::CollapseShapeOp>(
          loc, newCollapseSrc, newCollapseReInds);
    }
    rewriter.replaceOp(expandOp, replacement);
    return success();
  }
};
|
|
|
|
/// Converts `tensor.extract_slice(tensor.expand_shape)` to
|
|
/// `tensor.expand_shape(tensor.extract_slice)`.
|
|
///
|
|
/// For this transformation to be possible, the slice must be fully contiguous
|
|
/// within each reassociation group of the expand_shape. A slice is defined as
|
|
/// fully contiguous within a reassociation group if after flattening the
|
|
/// reassociation group to a single 1D range, then the slice taken out of the
|
|
/// group could be defined as a single contiguous subrange within that range.
|
|
///
|
|
/// Rank reducing slices are not supported.
|
|
///
|
|
/// Example:
|
|
/// The transformation is possible because each reassociation group has a
|
|
/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
|
|
/// ```
|
|
/// BEFORE:
|
|
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
|
|
/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
|
|
/// %slice = tensor.extract_slice %reshape ...
|
|
/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
|
|
///
|
|
/// AFTER:
|
|
/// %slice = tensor.extract_slice %in ...
|
|
/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
|
|
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
|
|
/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
|
|
/// ```
|
|
///
|
|
/// Note - this pattern could be extended to be a swap pattern between
|
|
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
|
|
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
|
|
struct BubbleUpExpandShapeThroughExtractSlice
|
|
: public OpRewritePattern<tensor::ExtractSliceOp> {
|
|
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto expandShapeOp =
|
|
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
|
|
|
|
if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
|
|
rewriter)
|
|
.failed())
|
|
return failure();
|
|
|
|
// The tensor.extract_slice before applying the pattern works on the result
|
|
// of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
|
|
// referring to the state before applying the pattern are named with the
|
|
// prefix "expanded", and ones referring to the state after applying the
|
|
// pattern are named with the prefix "collapsed".
|
|
SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> expandedShape =
|
|
getMixedValues(expandShapeOp.getStaticOutputShape(),
|
|
expandShapeOp.getOutputShape(), rewriter);
|
|
|
|
// Helper variables and function for accumulating the size values.
|
|
Location loc = expandShapeOp->getLoc();
|
|
AffineExpr d0, d1, d2;
|
|
bindDims(rewriter.getContext(), d0, d1, d2);
|
|
// Multiply two integers.
|
|
auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
|
|
auto mulMap = AffineMap::get(2, 0, {d0 * d1});
|
|
return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
|
|
{v1, v2});
|
|
};
|
|
|
|
// Compute new offsets, sizes, and strides for tensor.extract_slice.
|
|
// The new tensor.extract_slice will work on a tensor that has has a rank of
|
|
// ReassociationIndices.size(). In the loop a single offset, size, and
|
|
// stride value is computed per reassociation group.
|
|
SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
|
|
collapsedStrides;
|
|
for (const ReassociationIndices &indices :
|
|
expandShapeOp.getReassociationIndices()) {
|
|
// collapsedSize will hold the size of the single dim that represents the
|
|
// reassociation group in the non expanded tensor.
|
|
OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
|
|
// The reassocGroupSizes and reassocGroupOffsets are used to create an
|
|
// affine.linearize_index op to linearize the single offset value required
|
|
// for this reassociation group.
|
|
SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
|
|
|
|
for (long expandedDim : indices) {
|
|
// reassocGroupSizes and reassocGroupOffsets can be obtained directly
|
|
// from the expanded state, but the collapsed size requires calculation
|
|
// as it did not previously exist.
|
|
reassocGroupSizes.push_back(expandedShape[expandedDim]);
|
|
reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
|
|
collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
|
|
}
|
|
|
|
SmallVector<Value> offsetVals =
|
|
llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
|
|
return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
|
|
});
|
|
OpFoldResult collapsedOffset =
|
|
rewriter
|
|
.create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
|
|
reassocGroupSizes,
|
|
/*disjoint=*/true)
|
|
.getResult();
|
|
collapsedOffsets.push_back(collapsedOffset);
|
|
collapsedSizes.push_back(collapsedSize);
|
|
|
|
// Only unit stride is supported.
|
|
collapsedStrides.push_back(rewriter.getIndexAttr(1));
|
|
}
|
|
|
|
// The shape of the result can be obtained from the sizes passed in.
|
|
SmallVector<Value> dynDims;
|
|
SmallVector<int64_t> shape;
|
|
dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
|
|
RankedTensorType resultType = RankedTensorType::get(
|
|
shape, expandShapeOp.getResultType().getElementType());
|
|
|
|
// Create a new ExtractSliceOp and ExpandShapeOp.
|
|
Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
|
loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
|
|
collapsedStrides);
|
|
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
|
sliceOp, resultType, newSliceOp,
|
|
expandShapeOp.getReassociationIndices(), expandedSizes);
|
|
return success();
|
|
}
|
|
|
|
// Helper function to check if all the required conditions for the
|
|
// tensor.extract_slice to be bubbled up through the tensor.expand_shape are
|
|
// met.
|
|
LogicalResult
|
|
checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
|
|
tensor::ExpandShapeOp expandShapeOp,
|
|
PatternRewriter &rewriter) const {
|
|
|
|
if (!expandShapeOp) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "tensor.extract_slice source not produced by expand_shape");
|
|
}
|
|
|
|
if (!sliceOp.hasUnitStride()) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
|
|
"be supported in this transformation.");
|
|
}
|
|
|
|
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
|
|
|
|
if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
|
|
sizes.size()) {
|
|
return rewriter.notifyMatchFailure(sliceOp,
|
|
"unimplemented: rank reducing slice");
|
|
}
|
|
|
|
SmallVector<OpFoldResult> outputShape =
|
|
getMixedValues(expandShapeOp.getStaticOutputShape(),
|
|
expandShapeOp.getOutputShape(), rewriter);
|
|
|
|
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
|
|
isZeroOffsetAndFullSize =
|
|
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
|
|
if (!isConstantIntValue(offset, 0))
|
|
return false;
|
|
FailureOr<bool> maybeEqual =
|
|
ValueBoundsConstraintSet::areEqual(sliceSize, size);
|
|
return llvm::succeeded(maybeEqual) && maybeEqual.value();
|
|
};
|
|
|
|
// Check that the slice is contiguous within each reassociation group.
|
|
// The slice is contiguous only if after the first dimension where a non
|
|
// unit slice is taken, the slice size on all subsequent dimensions of the
|
|
// group is equal to the entire size of the dimension.
|
|
// Examples of contiguous slices:
|
|
// full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
|
|
// full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
|
|
// Examples of non contiguous slices:
|
|
// full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
|
|
// full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
|
|
for (const ReassociationIndices &indices :
|
|
expandShapeOp.getReassociationIndices()) {
|
|
int64_t i = 0;
|
|
int64_t e = indices.size();
|
|
// Find the first expanded dim after the first dim with non-unit extracted
|
|
// size.
|
|
for (; i < e; ++i) {
|
|
if (!isConstantIntValue(sizes[indices[i]], 1)) {
|
|
// +1 to skip the first non-unit size dim.
|
|
i++;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Verify that all subsequent dimensions extract the full size of the
|
|
// source tensor.
|
|
for (; i < e; ++i) {
|
|
int64_t expandedDim = indices[i];
|
|
if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
|
|
outputShape[expandedDim])) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "Not a contiguous slice of the expanded tensor.");
|
|
}
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
/// `tensor.collapse_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible - after bubbling up, the extraction
/// of the contiguous slice must be representable as a single slice obtained via
/// tensor.extract_slice within each reassociation group of the src.
///
/// In case the size and offset extracted are static then this is possible if
/// the following conditions are met within each reassociation group:
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
/// shape of a desired slice. A slice of shape S can be extracted as a
/// contiguous span of elements if and only if there exists an index k in {0, 1,
/// ..., n} such that:
///   S_i = 1 for all i < k (that is, all leading dimensions are singleton),
///   1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
///                    one dimension),
///   S_i = A_i for all i > k (that is, all trailing dimensions are preserved
///   in full).
/// In other words, the slice shape S must be of the form:
///   [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ..., An ]
///
/// In case the size and/or offset extracted are dynamic then this is possible
/// only if there is single dimension in the reassociation group that has a size
/// not equal to 1.
/// In other words, the tensor shape must be of the form:
///   [ 1, 1, ..., 1, A, 1, ..., 1 ]
/// Note - it might be possible to enable this pattern for more cases when the
/// size/offset are dynamic via performing an analysis of the possible values
/// that could be given to the size/offset.
///
/// Example:
/// The transformation is possible because each reassociation group can be
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
/// [20->10]).
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
///
/// Negative example:
/// The transformation is not possible because we cannot use a single slice to
/// represent the reassociation group [2x3x10->???]. If we would want the
/// collapse to be after the extraction, we would need to extract multiple
/// slices and concat them together.
/// ```
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
/// tensor<60xf32> to tensor<15xf32>
/// ```
/// If we would want the collapse to be after the extraction, a possible
/// alternate transformation could be to extract multiple slices and concat them
/// together:
/// ```
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
/// tensor<2x3x10xf32> to tensor <1x1x10xf32>
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
/// tensor<2x3x10xf32> to tensor <1x1x5xf32>
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
/// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
/// to tensor<15xf32>
/// ```
/// But this is not the intended purpose of the transformation.
struct BubbleUpCollapseShapeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "tensor.extract_slice source not produced by tensor.collapse_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    // The tensor.extract_slice before applying the pattern works on the result
    // of the tensor.collapse_shape, so variables (i.e. inputs for
    // ExtractSliceOp) referring to the state before applying the pattern are
    // named with the prefix "collapsed", and ones referring to the state after
    // applying the pattern are named with the prefix "expanded".
    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        collapsedSizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
    SmallVector<ReassociationIndices, 4> reassociationIndices =
        collapseShapeOp.getReassociationIndices();

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank
    // equal to the rank of the src of the collapse_shape. In each iteration of
    // the loop, the offsets and sizes will be computed per reassociation group.
    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
                                              rewriter.getIndexAttr(1));

    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
         llvm::zip_equal(collapsedSizes, collapsedOffsets,
                         collapseShapeOp.getReassociationIndices())) {
      // CASE #1 - size and/or offset are dynamic.
      // In this case, the slice can be represented as a contiguous slice only
      // if there is a single dimension in the reassociation group that has a
      // size not equal to 1.
      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
        int nonUnitSizeCount = 0;
        for (int64_t expandedShapeIdx : reassocIndices) {
          if (srcShape[expandedShapeIdx] != 1) {
            // The single non-unit dim receives the full dynamic size/offset.
            nonUnitSizeCount++;
            expandedSizes.push_back(collapsedSize);
            expandedOffsets.push_back(collapsedOffset);
            continue;
          }

          // Unit dims contribute a trivial [offset 0, size 1] slice.
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          expandedOffsets.push_back(rewriter.getIndexAttr(0));
        }

        if (nonUnitSizeCount != 1) {
          return rewriter.notifyMatchFailure(
              sliceOp,
              "unsupported: slice cannot be verified to be contiguous");
        }
        continue;
      }

      // CASE #2 - size and offset are static.
      // Verify that the slice can be represented as a contiguous slice of the
      // src of the collapse_shape.
      // Checking this is done on order of most internal dimensions first,
      // so traversal is done in reverse order of the reassociation group.
      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
      // ...,An] then we first find the size and offset for n...k+1 then for k
      // and then for k-1...0.

      // currentCollapsedsize and currentCollapsedOffset are initialized with
      // the original collapsed size and offset and divided by the expanded
      // shape size in each dimension as we go along the reassociation group.
      // In essence we are spreading the original collapsed size and offset over
      // the various expanded slice dimensions.
      // The variables are used both to check the validity of the slice and to
      // compute the expanded sizes and offsets.
      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
      int64_t currentCollapsedOffset =
          getConstantIntValue(collapsedOffset).value();

      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;

      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
                                                  reassocIndices.rend());
      int64_t idx = 0;
      int64_t reassocGroupSize = reassocIndices.size();

      // First handle the trailing dimensions where the slice size should be
      // equal to the tensor shape and the offset should be 0 (n...k+1).
      for (; idx < reassocGroupSize; ++idx) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];

        if (currentCollapsedsize < expandedShapeSize)
          break;

        // We need to make sure that the slice size can be set to the shape size
        // and the offset to 0.
        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
            (currentCollapsedOffset % expandedShapeSize) != 0) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
                       "of the src of the collapse_shape");
        }

        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));

        currentCollapsedsize /= expandedShapeSize;
        currentCollapsedOffset /= expandedShapeSize;
      }

      // Now handle the first dim where slicing occurs on (k).
      if (idx < reassocGroupSize) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
        // We need to make sure that the slice size in this dim + offset will
        // not exceed the shape size.
        // NOTE(review): this check looks conservative — it also rejects slices
        // that end exactly at the dimension boundary
        // (offsetInDim + size == expandedShapeSize), which are contiguous.
        // Confirm whether `>` was intended before relaxing.
        if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
          return rewriter.notifyMatchFailure(
              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
                       "slice of the src of the collapse_shape");
        }

        groupExpandedSizes.push_back(
            rewriter.getIndexAttr(currentCollapsedsize));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));

        currentCollapsedOffset /= expandedShapeSize;
      }

      // Now handle the leading dimensions where the slice size is equal to 1
      // (k-1...0).
      // The size for these dimensions must be 1 because of how we constructed
      // the slice size of the expanded shape. We spread the original collapsed
      // size over the expanded shape sizes until we reached dimension k where
      // the remaining size was smaller than the expanded shape size, and spread
      // the remaining size on it. So, now we are left with only 1s.
      for (idx++; idx < reassocGroupSize; ++idx) {
        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
        currentCollapsedOffset /= expandedShapeSize;
      }

      // The group was traversed innermost-first, so reverse back to the
      // source dimension order before appending.
      expandedSizes.append(groupExpandedSizes.rbegin(),
                           groupExpandedSizes.rend());
      expandedOffsets.append(groupExpandedOffsets.rbegin(),
                             groupExpandedOffsets.rend());
    }

    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
        expandedSizes, expandedStrides);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        sliceOp, sliceOp.getResultType(), newSliceOp,
        collapseShapeOp.getReassociationIndices());

    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  // Register the reshape <-> (rank-reducing) slice folding patterns; the
  // insert-side patterns cover both tensor.insert_slice and
  // tensor.parallel_insert_slice.
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldExpandOfRankReducingExtract>(context);
  patterns.add<FoldUnPaddingCollapseIntoExtract>(context);
  patterns.add<FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>>(context);
  patterns.add<FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
      context);
  patterns.add<FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>>(context);
  patterns.add<FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
      context);
}
|
|
|
|
void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  // Registers the pattern that swaps a collapse_shape -> expand_shape pair
  // into expand_shape -> collapse_shape.
  MLIRContext *context = patterns.getContext();
  patterns.add<BubbleUpExpandThroughParallelCollapse>(context);
}
|
|
|
|
void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  // Registers the patterns that move tensor.extract_slice above its producing
  // tensor.expand_shape / tensor.collapse_shape.
  MLIRContext *context = patterns.getContext();
  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(context);
  patterns.add<BubbleUpCollapseShapeThroughExtractSlice>(context);
}
|