Currently the `getTiledImplementation` and `generateResultTileValue` methods
return just `SmallVector<Operation *>` and `FailureOr<Value>`, respectively.
- For `getTiledImplementation`, returning an empty vector implies tiling
  wasn't done. There is also an implicit assumption that the results of the
  tiled operation correspond to the tiled values of the results of the
  original operation. This cannot handle cases where the tiled implementation
  uses multiple operations to compute the tiled values for the results of the
  untiled operation. Sometimes the tiled operation does not directly produce
  the tiled values and needs casts, etc. to produce a replacement.
- For `generateResultTileValue`, it is assumed that the op defining the
  returned `Value` is the operation that represents the tiled computation.
  Again, the presence of casts, etc. violates this.
Instead, make these methods return `FailureOr<TilingResult>`, where
```
struct TilingResult {
  SmallVector<Operation *> tiledOps;
  SmallVector<Value> tiledValues;
};
```
`tiledOps` holds the generated operations that are relevant for subsequent
transformations. `tiledValues` holds the tiled values that replace the
results of the original operation. Together they describe the state of the
transformed IR explicitly instead of leaving callers to infer it.
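The following is a minimal stand-alone sketch of the new contract and does not depend on MLIR. `Val`, `Op`, and `tileWithCast` are hypothetical stand-ins for `Value`, `Operation`, and a tiled implementation that has to append a cast. Only the field names `tiledOps` and `tiledValues` come from this change. The sketch shows that reading the replacement from the tiled op's own result (the old implicit assumption) picks the un-cast value, while `tiledValues` carries the correct one.

```cpp
#include <cassert>
#include <deque>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Value: only records the value's type.
struct Val {
  std::string type;
};

// Hypothetical stand-in for mlir::Operation.
struct Op {
  std::string name;
  std::vector<Val> results;
};

// Same shape as the TilingResult struct introduced by this change.
struct TilingResult {
  std::vector<Op *> tiledOps;   // ops relevant to later transformations
  std::vector<Val> tiledValues; // replacements for the untiled op's results
};

// Hypothetical tiled implementation: the tiled computation yields a
// dynamically shaped tile, so a cast is appended to recover the static type.
// `ir` owns the created ops; std::deque keeps references stable on append.
TilingResult tileWithCast(std::deque<Op> &ir) {
  Op &compute = ir.emplace_back(Op{"tiled.compute", {Val{"tensor<4x?xf32>"}}});
  Op &cast = ir.emplace_back(Op{"tensor.cast", {Val{"tensor<4x8xf32>"}}});
  // Report the compute op for later transformations, but hand back the
  // cast's result as the replacement value.
  return TilingResult{{&compute}, {cast.results[0]}};
}

// Old implicit assumption: the tiled op's results are the replacements.
std::string replacementUnderOldAssumption(const TilingResult &r) {
  return r.tiledOps.front()->results.front().type;
}

// New contract: the replacement is carried explicitly.
std::string replacementUnderNewContract(const TilingResult &r) {
  return r.tiledValues.front().type;
}
```

The implementation decides which ops go into `tiledOps`. This sketch reports only the compute op, which is its own assumption and not a rule stated by the patch.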
As a consequence, the following methods now also return
`FailureOr<TilingResult>`:
- `tensor::replaceExtractSliceWithTiledProducer`
- `tensor::bubbleUpPadSlice`
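On the caller side, the change amounts to checking for failure first and then taking the replacements from `tiledValues`. The sketch below is hypothetical. `std::optional` stands in for `mlir::FailureOr`, and plain strings stand in for ops and values. `tileProducer` and `replaceSliceUses` are invented helpers, not the signatures of `tensor::replaceExtractSliceWithTiledProducer` or `tensor::bubbleUpPadSlice`.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in; std::optional plays the role of mlir::FailureOr.
struct TilingResult {
  std::vector<std::string> tiledOps;    // names of generated ops
  std::vector<std::string> tiledValues; // replacement values
};

// Hypothetical producer-side helper: returns std::nullopt on failure,
// e.g. when the producer cannot be tiled.
std::optional<TilingResult> tileProducer(bool tileable) {
  if (!tileable)
    return std::nullopt;
  return TilingResult{{"tiled.compute"}, {"%cast"}};
}

// Hypothetical caller: rewrites every use of the slice with the explicit
// replacement value instead of guessing it from the tiled ops.
bool replaceSliceUses(std::vector<std::string> &sliceUses, bool tileable) {
  std::optional<TilingResult> tiled = tileProducer(tileable);
  if (!tiled)
    return false; // propagate failure; uses are left untouched
  for (std::string &use : sliceUses)
    use = tiled->tiledValues.front(); // single-result slice in this sketch
  return true;
}
```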
Differential Revision: https://reviews.llvm.org/D145133
```cpp
//===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Swap a `tensor.extract_slice` with the producer of the source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling this
// effectively tiles + fuses the producer with its consumer.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;

FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
    OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
  auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
  if (!producerOp)
    return failure();

  // `TilingInterface` currently only supports strides being 1.
  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
        return !isConstantIntValue(ofr, 1);
      }))
    return failure();

  FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
      builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  if (failed(tiledResult))
    return failure();

  return *tiledResult;
}
```
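The stride guard in `replaceExtractSliceWithTiledProducer` rejects the slice unless every stride is known to be the constant 1. A stride given as a runtime value with no known constant therefore fails, even if it would be 1 at runtime. The following self-contained sketch of that check uses two hypothetical stand-ins: a `FoldResult` variant for `OpFoldResult`, and `isConstantOne` for `isConstantIntValue(ofr, 1)`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for OpFoldResult: either a static constant or the
// name of a dynamic SSA value whose runtime value is unknown.
using FoldResult = std::variant<int64_t, std::string>;

// Mirrors the intent of isConstantIntValue(ofr, 1): true only for a static
// constant equal to 1. A non-constant SSA value is rejected because it cannot
// be proven to be 1. (The real helper also recognizes Values defined by
// constant ops; this sketch does not model that.)
bool isConstantOne(const FoldResult &ofr) {
  const int64_t *c = std::get_if<int64_t>(&ofr);
  return c != nullptr && *c == 1;
}

// Same shape as the guard above: fail if any stride is not the constant 1.
bool hasOnlyUnitStrides(const std::vector<FoldResult> &strides) {
  return !std::any_of(strides.begin(), strides.end(),
                      [](const FoldResult &s) { return !isConstantOne(s); });
}
```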