155 lines
5.8 KiB
C++
155 lines
5.8 KiB
C++
//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
namespace mlir {
|
|
namespace tensor {
|
|
namespace {
|
|
|
|
/// Returns true when every entry of `ofrs` is a constant integer equal to
/// `value`, as judged by `isConstantIntValue`. An empty list yields true.
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
  for (OpFoldResult entry : ofrs) {
    if (!isConstantIntValue(entry, value))
      return false;
  }
  return true;
}
|
|
|
|
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
|
|
/// the pad op has zero low paddings, or if `pack` has no padding values.
|
|
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
|
|
using OpRewritePattern<PackOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(PackOp packOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto padOp = packOp.getSource().getDefiningOp<PadOp>();
|
|
|
|
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
|
|
return failure();
|
|
|
|
Value constantPaddingValue = padOp.getConstantPaddingValue();
|
|
if (!constantPaddingValue)
|
|
return failure();
|
|
|
|
if (auto paddingValue = packOp.getPaddingValue())
|
|
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<PackOp>(
|
|
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
|
|
packOp.getMixedTiles(), constantPaddingValue,
|
|
packOp.getOuterDimsPerm());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
|
|
/// has extract_slice semantics.
|
|
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
|
|
using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
|
|
if (!unpackOp)
|
|
return failure();
|
|
|
|
if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "rank-reduced folding is not supported");
|
|
}
|
|
|
|
// Check all offsets are zeros, and all strides are ones.
|
|
if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
|
|
!areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "expects offsets to be 0s and strides to be 1s");
|
|
}
|
|
|
|
// Create a new empty output tensor.
|
|
Type elementType = unpackOp.getDestType().getElementType();
|
|
Value output = rewriter.create<EmptyOp>(
|
|
sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
|
|
rewriter.replaceOpWithNewOp<UnPackOp>(
|
|
sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
|
|
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
|
|
/// semantics.
|
|
struct FoldProducerPackWithConsumerLinalgTransposeOp
|
|
: public OpRewritePattern<linalg::TransposeOp> {
|
|
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
|
|
|
|
if (!packOp)
|
|
return failure();
|
|
|
|
auto innerDimsPos = packOp.getInnerDimsPos();
|
|
auto mixedInnerTiles = packOp.getMixedTiles();
|
|
auto outerDimsPerm = packOp.getOuterDimsPerm();
|
|
auto transposePerm = transposeOp.getPermutation();
|
|
SmallVector<int64_t> newOuterDimsPermVec;
|
|
SmallVector<int64_t> newInnerDimsPosVec;
|
|
SmallVector<OpFoldResult> newMixedInnerTilesVec;
|
|
int64_t srcRank = packOp.getSourceRank();
|
|
|
|
// Process transpose operation for non-tiled outer dimensions
|
|
for (unsigned int i = 0; i < srcRank; ++i) {
|
|
int64_t remappedPosition = transposePerm[i];
|
|
|
|
// If tensor.pack has outer_dims_perm attribute, then consider it during
|
|
// index remapping.
|
|
if (!outerDimsPerm.empty()) {
|
|
if (transposePerm[i] >= srcRank) {
|
|
return rewriter.notifyMatchFailure(
|
|
transposeOp,
|
|
"Cannot fold in tensor.pack if a tile dimension was transposed "
|
|
"with a non-tile dimension in linalg.transpose.");
|
|
}
|
|
remappedPosition = outerDimsPerm[remappedPosition];
|
|
}
|
|
|
|
newOuterDimsPermVec.push_back(remappedPosition);
|
|
}
|
|
|
|
// Process transpose operation for tiled inner dimensions
|
|
for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
|
|
int64_t remappedPosition = transposePerm[i] - srcRank;
|
|
newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
|
|
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
|
|
}
|
|
|
|
Value output = packOp.createDestinationTensor(
|
|
rewriter, transposeOp.getLoc(), packOp.getSource(),
|
|
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
|
|
|
|
rewriter.replaceOpWithNewOp<PackOp>(
|
|
transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
|
|
newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Adds to `patterns` the rewrites that fold neighbouring ops into
/// `pack`/`unpack`:
///   - extract_slice of unpack;
///   - pack of pad;
///   - transpose of pack.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
               FoldProducerPackWithConsumerLinalgTransposeOp>(context);
}
|
|
|
|
} // namespace tensor
|
|
} // namespace mlir
|