clang-p2996/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
Matthias Springer 0a8e3dd432 [mlir][Interfaces] DestinationStyleOpInterface: Rename hasTensor/BufferSemantics (#77574)
Rename interface functions as follows:
* `hasTensorSemantics` -> `hasPureTensorSemantics`
* `hasBufferSemantics` -> `hasPureBufferSemantics`

These two functions return "true" if the op has tensor (respectively
buffer) operands but no buffer (respectively tensor) operands.

Also drop the "ranked" part from the interface, i.e., do not distinguish
between ranked/unranked types.

The new function names describe the functions more accurately. They also
align their semantics with the notion of "tensor semantics" in the
bufferization framework. (An op is supposed to be bufferized if it has
tensor operands, and we don't care if it also has memref operands.)

This change is in preparation of #75273, which adds
`BufferizableOpInterface::hasTensorSemantics`. By renaming the functions
in the `DestinationStyleOpInterface`, we can avoid name clashes between
the two interfaces.
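
As a minimal sketch of the renamed queries (the `dpsOp` variable is
hypothetical; any op implementing `DestinationStyleOpInterface` would do):

```c++
// `dpsOp` is a hypothetical op implementing DestinationStyleOpInterface.
if (dpsOp.hasPureTensorSemantics()) {
  // The op has tensor operands and no memref operands.
} else if (dpsOp.hasPureBufferSemantics()) {
  // The op has memref operands and no tensor operands.
}
// An op with mixed tensor/memref operands satisfies neither predicate.
```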
2024-01-12 10:02:54 +01:00

139 lines
5.3 KiB
C++

//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
// the computation for the linalg op.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Bubble up extract_slice above Linalg operation.
///
/// A sequence of operations
///
/// ```mlir
/// %0 = linalg.<op> ... %arg0, %arg1, ...
/// %1 = tensor.extract_slice %0 ...
/// ```
///
/// can be replaced with
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0
/// %1 = tensor.extract_slice %arg1
/// %2 = linalg.<op> ... %0, %1, ...
/// ```
///
/// This reduces the amount of computation performed by the linalg op.
///
struct BubbleUpExtractSliceOpPattern
    : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Value source = sliceOp.getSource();
    auto linalgOp = source.getDefiningOp<LinalgOp>();
    if (!linalgOp) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected source to be linalg op");
    }

    // TODO: we might relax this if we want heuristics to detect that all uses
    // access only a small portion of the output.
    if (!linalgOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single use of linalg op");
    }

    if (linalgOp.getNumDpsInits() != 1) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single output of linalg op");
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected tensor of linalg op");
    }

    if (!sliceOp.hasUnitStride())
      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");

    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
    }

    OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
    if (!indexingMap.isProjectedPermutation()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected a projected permutation for output");
    }
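
    // Derive the loop bounds of the linalg op by composing the
    // shapes-to-loops map over the flattened list of operand dimensions.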
    auto linalgLoc = linalgOp.getLoc();
    SmallVector<OpFoldResult> allShapeSizes =
        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
    if (!shapeSizesToLoopsMap) {
      return rewriter.notifyMatchFailure(
          linalgOp, "failed to get loops map from shape sizes");
    }
    SmallVector<OpFoldResult> sizeBounds =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);

    // The offsets and sizes from the slice operation only describe the tile
    // of the output. Use them to compute the tile sizes and offsets of the
    // loops. For loops not used to access the output, set the tile size to
    // the loop bound and the offset to 0.
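    // For example, for a matmul whose output indexing map is
    // (d0, d1, d2) -> (d0, d1), the slice fixes loops d0 and d1, while the
    // reduction loop d2 keeps offset 0 and its full loop bound as size.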
    SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
                                          rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> tileSizes = sizeBounds;
    for (auto const &result : enumerate(indexingMap.getResults())) {
      unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
      tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
      tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
    }
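
    // Take matching slices of all operands (inputs and inits) using the tile
    // offsets and sizes computed above.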
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands =
        makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
                        tileOffsets, tileSizes, sizeBounds,
                        /*omitPartialTileCheck=*/true);
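
    // The result types of the new op are the types of the tiled init
    // operands.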
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand &opOperand : linalgOp.getDpsInitsMutable())
      resultTensorTypes.push_back(
          tiledOperands[opOperand.getOperandNumber()].getType());
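
    // Clone the linalg op onto the tiled operands and use the clone's
    // results in place of the slice; the original op had a single use and
    // becomes dead.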
    Operation *newOp =
        clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
    rewriter.replaceOp(sliceOp, newOp->getResults());
    return success();
  }
};
} // namespace

void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<BubbleUpExtractSliceOpPattern>(context);
}