Files
clang-p2996/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp

155 lines
5.8 KiB
C++

//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
namespace mlir {
namespace tensor {
namespace {
/// Returns true when `isConstantIntValue(ofr, value)` holds for every entry
/// `ofr` of `ofrs`. An empty `ofrs` yields true.
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
  for (OpFoldResult entry : ofrs) {
    // Bail out on the first entry that is not the requested constant.
    if (!isConstantIntValue(entry, value))
      return false;
  }
  return true;
}
/// Folds a `pad` -> `pack` chain into a single `pack`.
///
/// The pattern applies when the source of the `pack` is produced by a `pad`
/// that is not marked `nofold`, has zero low padding, and pads with a constant
/// value. If the `pack` carries its own padding value, it must be equal to the
/// pad's constant; a `pack` without a padding value is always accepted.
///
/// The replacement `pack` reads directly from the pad's source, keeps the
/// original destination, inner dims, tiles and outer permutation, and uses the
/// pad's constant as its padding value.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // The pack must consume a foldable pad with no low padding.
    auto producerPad = packOp.getSource().getDefiningOp<PadOp>();
    if (!producerPad)
      return failure();
    if (producerPad.getNofold())
      return failure();
    if (!producerPad.hasZeroLowPad())
      return failure();

    // Only pads that yield a single constant value can be folded.
    Value padFill = producerPad.getConstantPaddingValue();
    if (!padFill)
      return failure();

    // When the pack pads as well, both must agree on the fill value.
    Value packFill = packOp.getPaddingValue();
    if (packFill && !isEqualConstantIntOrValue(packFill, padFill))
      return failure();

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, producerPad.getSource(), packOp.getDest(),
        packOp.getInnerDimsPos(), packOp.getMixedTiles(), padFill,
        packOp.getOuterDimsPerm());
    return success();
  }
};
/// Folds an `unpack` -> `extract_slice` chain into a single `unpack`, relying
/// on `unpack` already having extract_slice semantics.
///
/// The pattern applies when the slice source is produced by an `unpack`, the
/// slice is not rank-reducing, every slice offset is the constant 0 and every
/// slice stride is the constant 1. The replacement `unpack` writes into a
/// fresh `tensor.empty` whose sizes are the slice sizes.
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto producerUnpack = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!producerUnpack)
      return failure();

    auto unpackedType = producerUnpack.getDestType();
    if (sliceOp.getResultType().getRank() != unpackedType.getRank())
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");

    // The slice must start at the origin and be unit-strided.
    if (!(areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) &&
          areAllConstantIntValue(sliceOp.getMixedStrides(), 1)))
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");

    // Materialize a destination tensor shaped like the slice result.
    Type elementType = unpackedType.getElementType();
    Value slicedDest = rewriter.create<EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, producerUnpack.getSource(), slicedDest,
        producerUnpack.getInnerDimsPos(), producerUnpack.getMixedTiles(),
        producerUnpack.getOuterDimsPerm());
    return success();
  }
};
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
///
/// The first `srcRank` entries of the transpose permutation are interpreted as
/// a permutation of the pack's outer (non-tile) dimensions and are composed
/// with the pack's `outer_dims_perm` when one is present. The remaining
/// entries are interpreted as a permutation of the tile dimensions and are
/// used to reorder `inner_dims_pos` and the inner tiles.
///
/// The fold is rejected when the transpose moves a tile dimension into an
/// outer position (or vice versa), because such a layout cannot be expressed
/// by a single 'pack'.
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
    if (!packOp)
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = transposeOp.getPermutation();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();
    int64_t permRank = transposePerm.size();

    // Process transpose operation for non-tiled outer dimensions.
    for (int64_t i = 0; i < srcRank; ++i) {
      int64_t remappedPosition = transposePerm[i];

      // This check must run whether or not the pack has an outer_dims_perm.
      // Otherwise an entry >= srcRank would be recorded as an outer dimension
      // and, since the permutation is a bijection, some entry in the tile
      // range below would be < srcRank and produce a negative tile index.
      if (remappedPosition >= srcRank) {
        return rewriter.notifyMatchFailure(
            transposeOp,
            "Cannot fold in tensor.pack if a tile dimension was transposed "
            "with a non-tile dimension in linalg.transpose.");
      }

      // If tensor.pack has outer_dims_perm attribute, then consider it during
      // index remapping.
      if (!outerDimsPerm.empty())
        remappedPosition = outerDimsPerm[remappedPosition];

      newOuterDimsPermVec.push_back(remappedPosition);
    }

    // Process transpose operation for tiled inner dimensions. The loop above
    // guarantees every entry here is >= srcRank, so the index is non-negative.
    for (int64_t i = srcRank; i < permRank; ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    // Build a destination matching the folded layout, then replace the
    // transpose with a pack that reads the original pack's source.
    Value output = packOp.createDestinationTensor(
        rewriter, transposeOp.getLoc(), packOp.getSource(),
        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
    return success();
  }
};
} // namespace
/// Adds the folding patterns defined in this file to `patterns`: the
/// unpack/extract_slice fold, the pad/pack fold, and the pack/transpose fold.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp>(context);
}
} // namespace tensor
} // namespace mlir