Files
clang-p2996/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
Han-Chung Wang c39915fa2e [mlir][NFC] Simplify constant checks with isOneInteger and renamed isZeroInteger. (#139340)
The revision adds the isOneInteger helper and simplifies the existing code
with the two methods. It removes some lambdas, which makes the code cleaner.

For downstream users, you can update the code with the below script.

```bash
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
2025-05-20 14:53:02 -07:00

62 lines
2.2 KiB
C++

//===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Swap a `tensor.extract_slice` with the producer of the source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling this
// effectively tiles + fuses the producer with its consumer.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
using namespace mlir;
/// Asks the op defining `producer` to generate the tile of that result
/// described by the offsets and sizes of `sliceOp`.
///
/// Returns failure if the defining op does not implement `TilingInterface`,
/// if any stride of `sliceOp` does not satisfy `isOneInteger`, or if the
/// producer fails to generate the requested tile.
FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
    OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
  // The producer must be tileable, and `TilingInterface` currently only
  // supports strides being 1. The stride check is skipped when the cast fails.
  auto tileableProducer = dyn_cast<TilingInterface>(producer.getOwner());
  if (!tileableProducer ||
      !llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
    return failure();

  // The slice's offsets and sizes describe the tile of the producer result.
  auto tileOffsets = sliceOp.getMixedOffsets();
  auto tileSizes = sliceOp.getMixedSizes();
  FailureOr<TilingResult> producerTile =
      tileableProducer.generateResultTileValue(
          builder, producer.getResultNumber(), tileOffsets, tileSizes);
  if (succeeded(producerTile))
    return *producerTile;
  return failure();
}
/// Asks the op owning the `consumer` operand to generate its tiled
/// implementation from the operand tile described by the offsets and sizes
/// of `sliceOp`.
///
/// Returns failure if the owning op does not implement `TilingInterface`,
/// if any stride of `sliceOp` does not satisfy `isOneInteger`, or if the
/// consumer fails to generate the tiled implementation.
FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
    OpOperand &consumer) {
  // The consumer must be tileable, and `TilingInterface` currently only
  // supports strides being 1. The stride check is skipped when the cast fails.
  auto tileableConsumer = dyn_cast<TilingInterface>(consumer.getOwner());
  if (!tileableConsumer ||
      !llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
    return failure();

  // The slice's offsets and sizes describe the tile of the consumer operand.
  auto operandOffsets = sliceOp.getMixedOffsets();
  auto operandSizes = sliceOp.getMixedSizes();
  FailureOr<TilingResult> consumerTile =
      tileableConsumer.getTiledImplementationFromOperandTile(
          builder, consumer.getOperandNumber(), operandOffsets, operandSizes);
  if (succeeded(consumerTile))
    return *consumerTile;
  return failure();
}