[mlir][TilingInterface] Handle multi operand consumer fusion. (#145193)

For consumer fusion cases of this form

```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg0 = ...) {

  tensor.parallel_insert_slice ... into %arg0
  tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```

the current consumer fusion that handles one slice at a time cannot fuse
the consumer into the loop, since fusing along one slice will create an
SSA violation on the other use from the `scf.forall`. The solution is to
allow consumer fusion to consider multiple slices at once. This
PR changes the `TilingInterface` methods related to consumer fusion,
i.e.

- `getTiledImplementationFromOperandTile`
- `getIterationDomainFromOperandTile`

to allow fusion while considering multiple operands. It is up to the
`TilingInterface` implementation to return an error if a list of tiles
of the operands cannot result in a consistent implementation of the
tiled operation.

The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes and allow cases where operand
tiles that can result in a consistent tiling implementation are handled.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
This commit is contained in:
MaheshRavishankar
2025-06-25 11:54:38 -07:00
committed by GitHub
parent 28f6f87061
commit c873e5f87d
11 changed files with 703 additions and 205 deletions

View File

@@ -313,19 +313,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer, TilingInterface consumer,
const SCFTileAndFuseOptions &options); const SCFTileAndFuseOptions &options);
/// Fuse the consumer of the source of `candidateSliceOp` by computing the /// Fuse the consumer `candidateSlices` by computing the required slice of the
/// required slice of the consumer in-place. Note that the method /// consumer in-place. All the entries of `candidateSlices` are expected to map
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer /// to the same consumer. The method returns an error if the consumer cannot be
/// value but does not delete the slice operation. /// tiled in a manner that is consistent for all the passed slices. Note that
/// the method replaces the uses of `candidateSlices` with the tiled and fused
/// consumer value but does not delete the slice operations.
struct SCFFuseConsumerOfSliceResult { struct SCFFuseConsumerOfSliceResult {
OpOperand *origConsumerOperand; // Original untiled consumer's operand. // Original untiled consumer operands.
OpOperand SmallVector<OpOperand *> origConsumerOperands;
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand. // Tiled and fused consumer operands.
SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
SmallVector<Operation *> tiledOps; SmallVector<Operation *> tiledOps;
}; };
FailureOr<scf::SCFFuseConsumerOfSliceResult> FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
MutableArrayRef<LoopLikeOpInterface> loops); ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to /// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars. /// loops/scalars.

View File

@@ -31,12 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer( FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp); OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
/// Method to swap an `tensor.insert_slice` with its consumer when the /// Method to swap `tensor.insert_slice`s with their consumers when the
/// consumer implements the `TilingInterface`. /// consumer implements the `TilingInterface`. The size of `sliceOps` and
/// `consumerOperands` is expected to be the same. Every entry in
/// `consumerOperands` represents a use of the corresponding
/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` are
/// expected to be uses in the same consumer.
FailureOr<TilingResult> FailureOr<TilingResult>
replaceInsertSliceWithTiledConsumer(OpBuilder &builder, replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
OffsetSizeAndStrideOpInterface sliceOp, ArrayRef<tensor::InsertSliceOp> sliceOps,
OpOperand &consumerOp); ArrayRef<OpOperand *> consumerOperands);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Populate functions. // Populate functions.

View File

@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
using PointerUnion<Attribute, Value>::PointerUnion; using PointerUnion<Attribute, Value>::PointerUnion;
public: public:
void dump() const { llvm::errs() << *this << "\n"; } LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
MLIRContext *getContext() const { MLIRContext *getContext() const {
PointerUnion pu = *this; PointerUnion pu = *this;

View File

@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
InterfaceMethod< InterfaceMethod<
/*desc=*/[{ /*desc=*/[{
Method to generate the tiled implementation of an operation that uses Method to generate the tiled implementation of an operation that uses
exactly a tile of the given operand. the exact tiles of the given operands.
This method is required to allow operations to be "tiled and fused" This method is required to allow operations to be "tiled and fused"
with an (already tiled) producer. Given a tile of the producer, this with an (already tiled) producer. Given tiles of the producer, this
method generates the tile of the consumer that uses exactly this method generates the tile of the consumer that uses exactly these
produced tile. In some sense it is the "reverse" of produced tiles. In some sense it is the "reverse" of
`generateResultTileValue`. `generateResultTileValue`.
- `operandNumber` is the result of the producer used by the consumer. - `operandNumbers` is the list of operands whose tiles are "producers".
- `offsets` is the offset of the slice of the producer result used by - `allOffsets` is the offset of the slice of the producer used by the
the tiled implementation of the consumer. tiled implementation of the consumer.
- `sizes` is the size of the slice of the producer result used by the - `allSizes` is the size of the slice of the producer used by the
consumer. consumer.
If it is illegal to fuse with a producer along the given operand for If it is illegal to fuse with a producer along the given operand tiles for
an operation, the implementation should return a failure. an operation, the implementation should return a failure.
}], }],
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>", /*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementationFromOperandTile", /*methodName=*/"getTiledImplementationFromOperandTiles",
/*args=*/(ins /*args=*/(ins
"::mlir::OpBuilder &":$b, "::mlir::OpBuilder &":$b,
"unsigned":$operandNumber, "::mlir::ArrayRef<unsigned>":$operandNumbers,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets, "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes), "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
/*methodBody=*/"", /*methodBody=*/"",
/*defaultImplementation=*/[{ /*defaultImplementation=*/[{
return failure(); return failure();
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
tile of the operand. tile of the operand.
This method is required to allow operations to be "tiled and fused" This method is required to allow operations to be "tiled and fused"
with an (already tiled) producer. Given a tile of an operand, with an (already tiled) producer. Given tiles of operands,
returns the tile of the iteration space that uses this tile. returns the tile of the iteration space that uses these tiles.
- `operandNumber` is the result of the producer used by the consumer. - `operandNumbers` is the list of operands whose tiles are "produced"
- `offsets` is the offset of the slice of the producer result used by by the producer(s).
- `allOffsets` is the offset of the slice of the producers used by
the tiled implementation of the consumer. the tiled implementation of the consumer.
- `sizes` is the size of the slice of the producer result used by the - `allSizes` is the size of the slice of the producers used by the
consumer. consumer.
If it is illegal to fuse with a producer along the given operand for If it is illegal to fuse with the producer slices for an operation,
an operation, or if this mapping cannot be computed, the or if this mapping cannot be computed, the implementation should
implementation should return a failure. return a failure.
Note that unlike the "tile consumer and fuse producer" case, the Note that unlike the "tile consumer and fuse producer" case, the
"tile producer and fuse consumer" requires an additional method to get "tile producer and fuse consumer" requires an additional method to get
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformation. It does not provide guarantees on whether such a transformation. It does not provide guarantees on whether such a
transformation is profitable. transformation is profitable.
For most cases `getTiledImplementationFromOperandTile` could be a For most cases `getTiledImplementationFromOperandTiles` could be a
implemented using `getIterationDomainTileFromOperandTile` + implemented using `getIterationDomainTileFromOperandTiles` +
`getTiledImplementation` methods. `getTiledImplementation` methods.
}], }],
/*retType=*/"::llvm::LogicalResult", /*retType=*/"::llvm::LogicalResult",
/*methodName=*/"getIterationDomainTileFromOperandTile", /*methodName=*/"getIterationDomainTileFromOperandTiles",
/*args=*/(ins /*args=*/(ins
"::mlir::OpBuilder &":$b, "::mlir::OpBuilder &":$b,
"unsigned":$operandNumber, "::mlir::ArrayRef<unsigned>":$operandNumbers,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets, "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes, "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets, "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes), "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"", /*methodBody=*/"",

View File

@@ -22,8 +22,11 @@
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"
#include <optional> #include <optional>
#define DEBUG_TYPE "linalg-tiling-interface-impl"
using namespace mlir; using namespace mlir;
using namespace mlir::linalg; using namespace mlir::linalg;
@@ -148,55 +151,82 @@ struct LinalgOpTilingInterface
/// Utility to fetch the offsets and sizes when applied as per the indexing /// Utility to fetch the offsets and sizes when applied as per the indexing
/// map of the linalg op. This helps in fusing the linalg op as a consumer of /// map of the linalg op. This helps in fusing the linalg op as a consumer of
/// a given slice op. /// a given slice op.
void static LogicalResult
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap, getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
ArrayRef<OpFoldResult> offsets, ArrayRef<AffineMap> indexingMaps,
ArrayRef<OpFoldResult> sizes, ArrayRef<SmallVector<OpFoldResult>> allOffsets,
SmallVectorImpl<OpFoldResult> &mappedOffsets, ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &mappedSizes) const { SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
unsigned numLoops = linalgOp.getNumLoops(); SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation()); DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
mappedOffsets.resize(numLoops);
mappedSizes.resize(numLoops); for (auto [indexingMap, offsets, sizes] :
if (!indexingMap.isPermutation()) { llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
SmallVector<Range> iterationDomain = for (auto [resultExpr, offset, size] :
tilingInterfaceOp.getIterationDomain(b); llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) { auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
mappedOffsets[index] = value.offset; if (!dimExpr)
mappedSizes[index] = value.size; continue;
unsigned position = dimExpr.getPosition();
auto it = mappedOffsets.find(position);
if (it != mappedOffsets.end()) {
OpFoldResult seenOffset = it->second;
OpFoldResult seenSize = mappedSizes.lookup(position);
if (seenOffset != offset || seenSize != size) {
LLVM_DEBUG({
llvm::dbgs() << "inconsistent iteration space mapping from "
"offsets/sizes of operands/results";
});
return failure();
}
} else {
mappedOffsets[position] = offset;
mappedSizes[position] = size;
}
} }
} }
for (const auto &&[index, value] :
llvm::enumerate(indexingMap.getResults())) { // Aggregate from the given operand offsets and sizes, or default to
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition(); // iteration space values.
mappedOffsets[dimPosition] = offsets[index]; SmallVector<Range> iterationDomain =
mappedSizes[dimPosition] = sizes[index]; cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
mappedOffsetsVec.resize(iterationDomain.size());
mappedSizesVec.resize(iterationDomain.size());
for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
auto it = mappedOffsets.find(index);
if (it != mappedOffsets.end()) {
mappedOffsetsVec[index] = it->second;
mappedSizesVec[index] = mappedSizes.lookup(index);
continue;
}
mappedOffsetsVec[index] = domain.offset;
mappedSizesVec[index] = domain.size;
} }
return success();
} }
/// Method to return the position of the result tile computed by the tiled /// Method to return the position of the result tile computed by the tiled
/// operation. /// operation.
LogicalResult getIterationDomainTileFromOperandTile( LogicalResult getIterationDomainTileFromOperandTiles(
Operation *op, OpBuilder &b, unsigned operandNumber, Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets, SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op); auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the operand is a projected std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
// permutation. This could be relaxed with a more general approach that can iterationSpaceSizes;
// map the offsets and sizes from the operand to iteration space tiles SmallVector<AffineMap> indexingMaps =
// (filling in full extent for dimensions not used to access the result). llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
AffineMap indexingMap = OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber)); return linalgOp.getMatchingIndexingMap(&opOperand);
if (!indexingMap.isProjectedPermutation()) { });
return op->emitError() if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
<< "unhandled get iter domain position when operand is not " allSizes, iterDomainOffsets,
"accessed using a permuted projection"; iterDomainSizes))) {
return failure();
} }
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
iterDomainOffsets, iterDomainSizes);
return success(); return success();
} }
@@ -247,8 +277,13 @@ struct LinalgOpTilingInterface
"accessed using a permuted projection"); "accessed using a permuted projection");
} }
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
iterDomainOffsets, iterDomainSizes); SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
auto status =
getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
{allSizes}, iterDomainOffsets, iterDomainSizes);
(void)status;
assert(succeeded(status) && "unexpected error in offset calculation");
return success(); return success();
} }
@@ -279,12 +314,13 @@ struct LinalgOpTilingInterface
/// Method to generate the tiled implementation of an operation from the tile /// Method to generate the tiled implementation of an operation from the tile
/// of the operand. /// of the operand.
FailureOr<TilingResult> getTiledImplementationFromOperandTile( FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
Operation *op, OpBuilder &b, unsigned operandNumber, Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes; SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
if (failed(getIterationDomainTileFromOperandTile( if (failed(getIterationDomainTileFromOperandTiles(
op, b, operandNumber, offsets, sizes, mappedOffsets, op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
mappedSizes))) { mappedSizes))) {
return failure(); return failure();
} }
@@ -837,13 +873,20 @@ struct PackOpTiling
/// Method to return the position of iteration domain tile computed by the /// Method to return the position of iteration domain tile computed by the
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
/// `resultSizes` only cover outer dimensions. /// `resultSizes` only cover outer dimensions.
LogicalResult getIterationDomainTileFromOperandTile( LogicalResult getIterationDomainTileFromOperandTiles(
Operation *op, OpBuilder &b, unsigned operandNumber, Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets, SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const { SmallVectorImpl<OpFoldResult> &resultSizes) const {
if (operandNumber != 0) if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
LLVM_DEBUG(
{ llvm::dbgs() << "unsupported operands for consumer fusion"; });
return failure(); return failure();
}
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op); auto packOp = cast<PackOp>(op);
// It is not trivial to infer dest tile from source tile if `packOp` has // It is not trivial to infer dest tile from source tile if `packOp` has
@@ -904,11 +947,18 @@ struct PackOpTiling
} }
/// Method to return the tiled implementation of tensor.pack as a consumer. /// Method to return the tiled implementation of tensor.pack as a consumer.
FailureOr<TilingResult> getTiledImplementationFromOperandTile( FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
Operation *op, OpBuilder &b, unsigned operandNumber, Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { ArrayRef<SmallVector<OpFoldResult>> allOffsets,
if (operandNumber != 0) ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
LLVM_DEBUG(
{ llvm ::dbgs() << "unhandled operands for consumer fusion"; });
return failure(); return failure();
}
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op); auto packOp = cast<PackOp>(op);
Location loc = packOp.getLoc(); Location loc = packOp.getLoc();
@@ -923,8 +973,8 @@ struct PackOpTiling
tiledOperands.push_back(sourceSlice); tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
if (failed(getIterationDomainTileFromOperandTile( if (failed(getIterationDomainTileFromOperandTiles(
op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets, op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
outerDimSizes))) outerDimSizes)))
return failure(); return failure();
@@ -1182,12 +1232,21 @@ struct UnPackOpTiling
/// Method to return the position of iteration domain tile computed by the /// Method to return the position of iteration domain tile computed by the
/// tiled operation. /// tiled operation.
LogicalResult getIterationDomainTileFromOperandTile( LogicalResult getIterationDomainTileFromOperandTiles(
Operation *op, OpBuilder &b, unsigned operandNumber, Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets, SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const { SmallVectorImpl<OpFoldResult> &resultSizes) const {
if (operandNumbers.size() != 1) {
LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
return failure();
}
auto unPackOp = cast<UnPackOp>(op); auto unPackOp = cast<UnPackOp>(op);
unsigned operandNumber = operandNumbers[0];
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
// If the operand tile is the dest, then no adjustment is needed. // If the operand tile is the dest, then no adjustment is needed.
if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) { if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
resultOffsets = llvm::to_vector(offsets); resultOffsets = llvm::to_vector(offsets);
@@ -1241,10 +1300,18 @@ struct UnPackOpTiling
} }
/// Method to return the tiled implementation of tensor.unpack as a consumer. /// Method to return the tiled implementation of tensor.unpack as a consumer.
FailureOr<TilingResult> getTiledImplementationFromOperandTile( FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
Operation *op, OpBuilder &b, unsigned operandNumber, Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
return failure();
}
auto unPackOp = cast<UnPackOp>(op); auto unPackOp = cast<UnPackOp>(op);
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
// tensor.unpack op is fusible (as a consumer) only if inner dims are not // tensor.unpack op is fusible (as a consumer) only if inner dims are not
// tiled. // tiled.
int64_t numTiles = unPackOp.getInnerDimsPos().size(); int64_t numTiles = unPackOp.getInnerDimsPos().size();
@@ -1259,8 +1326,8 @@ struct UnPackOpTiling
// Fetch offset/size for creating the slice of the dest operand of // Fetch offset/size for creating the slice of the dest operand of
// unpack op. // unpack op.
SmallVector<OpFoldResult> outputOffsets, outputSizes; SmallVector<OpFoldResult> outputOffsets, outputSizes;
if (failed(getIterationDomainTileFromOperandTile( if (failed(getIterationDomainTileFromOperandTiles(
op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets, op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
outputSizes))) outputSizes)))
return failure(); return failure();

View File

@@ -2047,53 +2047,119 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
/// A utility to fetch an untiled consumer of /// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice. /// tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr<OpOperand *> static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp, RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
MutableArrayRef<LoopLikeOpInterface> loops) { MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "unexpected empty loops"); assert(!loops.empty() && "unexpected empty loops");
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) { assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
return getUntiledConsumerFromSlice(rewriter, insertSlice, loops); SmallVector<OpOperand *> fusedOperands;
} else if (auto parallelInsertSlice = for (auto sliceOp : sliceOps) {
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) { FailureOr<OpOperand *> fusedOperand =
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops); TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
} else { .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
return failure(); [&](auto op) {
return getUntiledConsumerFromSlice(rewriter, op, loops);
})
.Default([&](Operation *op) {
return rewriter.notifyMatchFailure(op, "unhandled slice type");
});
if (failed(fusedOperand)) {
return failure();
}
if (!fusedOperands.empty() &&
fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
return rewriter.notifyMatchFailure(
fusedOperand.value()->getOwner(),
"all candidate slices must be to the same consumer");
}
fusedOperands.push_back(fusedOperand.value());
} }
return fusedOperands;
}
template <typename InsertSliceOpTy>
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
InsertSliceOpTy sliceOp);
template <>
tensor::InsertSliceOp
cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
tensor::InsertSliceOp insertSliceOp) {
return cast<tensor::InsertSliceOp>(
rewriter.clone(*insertSliceOp.getOperation()));
}
template <>
tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
return rewriter.create<tensor::InsertSliceOp>(
insertSliceOp->getLoc(), insertSliceOp.getSource(),
insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
}
static SmallVector<tensor::InsertSliceOp>
cloneAsInsertSlices(RewriterBase &rewriter,
ArrayRef<Operation *> candidateSlices) {
assert(!candidateSlices.empty() &&
"unexpected empty list of slices to clone");
SmallVector<tensor::InsertSliceOp> clonedSlices;
for (auto sliceOp : candidateSlices) {
TypeSwitch<Operation *>(sliceOp)
.Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
[&](auto op) {
auto clonedOp = cloneAsInsertSlice(rewriter, op);
clonedSlices.push_back(clonedOp);
})
.Default([&](Operation *op) {
// Assert here assuming this has already been checked.
assert(0 && "unexpected slice type while cloning as insert slice");
});
}
return clonedSlices;
} }
/// Implementation of fusing consumer of a single slice by computing the /// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop. /// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult> FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice( mlir::scf::tileAndFuseConsumerOfSlices(
RewriterBase &rewriter, Operation *candidateSliceOp, RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops) { MutableArrayRef<LoopLikeOpInterface> loops) {
if (candidateSlices.empty()) {
return rewriter.notifyMatchFailure(
rewriter.getUnknownLoc(),
"no candidate slices provided for consumer fusion");
}
// Return if `loops` is empty, return an error for now. Caller is expected // Return if `loops` is empty, return an error for now. Caller is expected
// to handle this case. // to handle this case.
if (loops.empty()) { if (loops.empty()) {
return candidateSliceOp->emitOpError( return rewriter.notifyMatchFailure(
candidateSlices.front(),
"cannot call tile and fuse consumer with an empty loop nest"); "cannot call tile and fuse consumer with an empty loop nest");
} }
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp)) if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
return failure(); llvm::all_of(candidateSlices,
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"candidates slices need to be all `tensor.extract_slice`s or "
"`tensor.parallel_insert_slice`s");
}
// 1. Get the consumer of scf.for for the result yielded by // 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice. // tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand = SmallVector<OpOperand *> consumerOpOperands;
getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops); Operation *consumerOp;
if (failed(maybeConsumerOpOperand)) { {
return rewriter.notifyMatchFailure(candidateSliceOp, FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
"could not fetch consumer to fuse"); getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
} if (failed(maybeConsumerOpOperand)) {
OpOperand *consumerOpOperand = *maybeConsumerOpOperand; return rewriter.notifyMatchFailure(candidateSlices.front(),
Operation *consumerOp = consumerOpOperand->getOwner(); "could not fetch consumer to fuse");
unsigned operandNumber = consumerOpOperand->getOperandNumber(); }
unsigned resultNumber = 0; std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) { consumerOp = consumerOpOperands.front()->getOwner();
resultNumber = producerResult.getResultNumber();
} else {
return rewriter.notifyMatchFailure(
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
} }
LoopLikeOpInterface outerMostLoop = loops.front(); LoopLikeOpInterface outerMostLoop = loops.front();
@@ -2113,16 +2179,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
if (!dstOp) if (!dstOp)
return rewriter.notifyMatchFailure(consumerOp, return rewriter.notifyMatchFailure(consumerOp,
"consumer op is not DPS operation"); "consumer op is not DPS operation");
SmallVector<Value> dpsInits = if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; }); return dstOp.isDpsInit(opOperand);
if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) { })) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
consumerOp, consumerOp,
"consumer op taking the result of scf.for as init is not supported"); "consumer op taking the result of scf.for as init is not supported");
} }
SmallVector<Value> newInits = dpsInits; SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
Location loc = outerMostLoop->getLoc();
// 3. Move the whole loop structure right before firstUserOfLoop, the // 3. Move the whole loop structure right before firstUserOfLoop, the
// dominance should be already ensured by `checkAssumptionForLoop`. // dominance should be already ensured by `checkAssumptionForLoop`.
@@ -2137,43 +2201,52 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// tensor.insert_slice. In the scf.for case this is a clone of the // tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the // candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice. // operands of tensor.parallel_insert_slice.
tensor::InsertSliceOp clonedInsertSliceOp;
if (auto sliceOp = if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) { dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation()); auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
rewriter.setInsertionPoint(newForallOp.getTerminator()); rewriter.setInsertionPoint(newForallOp.getTerminator());
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else { } else {
rewriter.setInsertionPoint(candidateSliceOp); rewriter.setInsertionPoint(candidateSlices.front());
clonedInsertSliceOp =
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
} }
// 5.a. Clone all the candidate slices as equivalent insert slice ops.
SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
cloneAsInsertSlices(rewriter, candidateSlices);
// 5.a. Clone consumer op. // 5.b. Clone consumer op.
auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp)); auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
SmallVector<unsigned> operandNumbers =
llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
return opOperand->getOperandNumber();
});
SmallVector<OpOperand *> clonedOpFusedOperandsList =
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
return &clonedConsumerOp->getOpOperand(operandNum);
});
// 5.b. Replace all uses of the loop result with the result of the cloned // 5.c. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice. // tensor.insert_slice.
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
operandToReplace.set(clonedInsertSliceOp.getResult()); for (auto [operandToReplace, clonedSliceOp] :
llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
operandToReplace->set(clonedSliceOp.getResult());
}
}); });
// 6. Perform tiling of the cloned consumer and replace the operand at // 6. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op. // `operandNumber` with the source of the cloned tensor.insert_slice op.
auto ossSliceOp =
cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
FailureOr<TilingResult> tileAndFuseResult = FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer( tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber)); clonedOpFusedOperandsList);
if (failed(tileAndFuseResult)) { if (failed(tileAndFuseResult)) {
return failure(); return failure();
} }
auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]); auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber), for (auto [operandNum, clonedSliceOp] :
clonedInsertSliceOp.getSource()); llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
clonedSliceOp.getSource());
}
// 7. Reconstruct [nested] loop with new inits. // 7. Reconstruct [nested] loop with new inits.
YieldTiledValuesFn newYieldValuesFn = YieldTiledValuesFn newYieldValuesFn =
@@ -2185,14 +2258,20 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// 8. Set inner insertPoint right before tiled consumer op. // 8. Set inner insertPoint right before tiled consumer op.
innerRewriter.setInsertionPoint(tiledConsumerOp); innerRewriter.setInsertionPoint(tiledConsumerOp);
SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets(); SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes(); for (auto candidateSliceOp : clonedInsertSlices) {
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides(); SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
// 9. Check all insert stride is 1. // 9. Check all insert stride is 1.
if (!llvm::all_of(strides, isOneInteger)) { if (!llvm::all_of(strides, isOneInteger)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
candidateSliceOp, "containingOp's result yield with stride"); candidateSliceOp, "containingOp's result yield with stride");
}
allOffsets.emplace_back(std::move(offsets));
allSizes.emplace_back(std::move(sizes));
} }
// 10. Try to get iter domain position from input position. Use // 10. Try to get iter domain position from input position. Use
@@ -2202,8 +2281,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// tiledConsumerOp could lead to some chained unnecessary extra index // tiledConsumerOp could lead to some chained unnecessary extra index
// computation. // computation.
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes; SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile( if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
rewriter, operandNumber, offsets, sizes, iterDomainOffsets, rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
iterDomainSizes))) { iterDomainSizes))) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
clonedConsumerOp, clonedConsumerOp,
@@ -2279,10 +2358,13 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// 16. Need to erase the old scf loop and the cloned consumer op. // 16. Need to erase the old scf loop and the cloned consumer op.
rewriter.eraseOp(clonedConsumerOp); rewriter.eraseOp(clonedConsumerOp);
SmallVector<OpOperand *> tiledAndFusedOpOperands =
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
});
return scf::SCFFuseConsumerOfSliceResult{ return scf::SCFFuseConsumerOfSliceResult{
consumerOpOperand, std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
&(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)), std::move(tileAndFuseResult->tiledOps)};
tileAndFuseResult->tiledOps};
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@@ -17,6 +17,9 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "tensor-swap-slices"
using namespace mlir; using namespace mlir;
@@ -39,21 +42,55 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return *tiledResult; return *tiledResult;
} }
FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer( FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
OpOperand &consumer) { ArrayRef<OpOperand *> consumerOperands) {
auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner()); if (sliceOps.empty()) {
LLVM_DEBUG(
{ llvm::dbgs() << "expected candidate slices list to be non-empty"; });
return failure();
}
if (sliceOps.size() != consumerOperands.size()) {
LLVM_DEBUG({
llvm::dbgs()
<< "expected as many operands as the number of slices passed";
});
return failure();
}
auto consumerOp =
dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
if (!consumerOp) if (!consumerOp)
return failure(); return failure();
for (auto opOperand : consumerOperands.drop_front()) {
if (opOperand->getOwner() != consumerOp) {
LLVM_DEBUG({
llvm::dbgs()
<< "expected all consumer operands to be from the same operation";
});
return failure();
}
}
// `TilingInterface` currently only supports strides being 1. auto consumerOperandNums = llvm::map_to_vector(
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) consumerOperands, [](OpOperand *opOperand) -> unsigned {
return failure(); return opOperand->getOperandNumber();
});
SmallVector<SmallVector<OpFoldResult>> allOffsets;
SmallVector<SmallVector<OpFoldResult>> allSizes;
for (auto sliceOp : sliceOps) {
// `TilingInterface` currently only supports strides being 1.
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
return failure();
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
allOffsets.emplace_back(std::move(offsets));
allSizes.emplace_back(std::move(sizes));
}
FailureOr<TilingResult> tiledResult = FailureOr<TilingResult> tiledResult =
consumerOp.getTiledImplementationFromOperandTile( consumerOp.getTiledImplementationFromOperandTiles(
builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(), builder, consumerOperandNums, allOffsets, allSizes);
sliceOp.getMixedSizes());
if (failed(tiledResult)) if (failed(tiledResult))
return failure(); return failure();

View File

@@ -653,6 +653,7 @@ module {
%5 = affine.min #map2(%i)[%d0, %idx] %5 = affine.min #map2(%i)[%d0, %idx]
%6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32> %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK: linalg.generic
// CHECK: %[[T1:.*]] = linalg.generic {{.*}} // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
// CHECK: %[[T2:.*]] = linalg.generic {{.*}} // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32> %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s // RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
#map = affine_map<(d0) -> (d0)> #map = affine_map<(d0) -> (d0)>
module { module {
@@ -620,3 +620,294 @@ module attributes {transform.with_named_sequence} {
transform.yield transform.yield
} }
} }
// -----
func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
%tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
%arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
%generic:2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.mulf %b0, %b1 : f32
%1 = arith.addf %b0, %b2 : f32
linalg.yield %0, %1 : f32, f32
} -> (tensor<?xf32>, tensor<?xf32>)
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
}
}
%empty = tensor.empty(%dim0) : tensor<?xf32>
%result = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
// CHECK-LABEL: func @multi_slice_fusion1(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
// CHECK: %[[TILESIZE:.+]] = affine.min
// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: %[[FUSED:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: return %[[RESULT]]#2
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
// Check that when the given operand tiles are inconsistent, tiling fails.
func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
%tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
%arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
%generic0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.mulf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
%generic1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0: f32
} -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
}
}
%empty = tensor.empty(%dim0) : tensor<?xf32>
%result = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
// CHECK-LABEL: func @multi_slice_fusion2(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
// CHECK: %[[TILESIZE:.+]] = affine.min
// CHECK: %[[GENERIC0:.+]] = linalg.generic
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: %[[FUSED:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: return %[[RESULT]]#2
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
%arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
%loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
%tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
%tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
%arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
%init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> to tensor<?x?xf32>
%generic0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.mulf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
%generic1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0: f32
} -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
}
}
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
%result = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
// CHECK: %[[GENERIC0:.+]] = linalg.generic
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
// CHECK: %[[FUSED:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
// CHECK: return %[[RESULT]]#2
// -----
func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
%arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
%loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
%tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
%arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
%init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> to tensor<?x?xf32>
%generic0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.mulf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> to tensor<?x?xf32>
%generic1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0: f32
} -> tensor<?x?xf32>
scf.forall.in_parallel {
// expected-error @below {{failed to fuse consumer of slice}}
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
}
}
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
%result = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

View File

@@ -21,6 +21,9 @@
#include "mlir/IR/Dominance.h" #include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "test-tiling-interface"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.h.inc" #include "TestTilingInterfaceTransformOps.h.inc"
@@ -168,29 +171,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
/// Apply fusing of consumer transformation to all payload ops and store both /// Apply fusing of consumer transformation to all payload ops and store both
/// the original consumer operation as well as the fused consumer operation. /// the original consumer operation as well as the fused consumer operation.
template <typename Range>
static LogicalResult applyFuseConsumer( static LogicalResult applyFuseConsumer(
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, RewriterBase &rewriter, Operation *transformOp,
MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse, ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
TransformResults &transformResults) { uint32_t numConsumerToFuse, TransformResults &transformResults) {
SmallVector<Operation *> originalConsumerOps; SmallVector<Operation *> originalConsumerOps;
SmallVector<Operation *> fusedConsumerOps; SmallVector<Operation *> fusedConsumerOps;
for (Operation *target : payloadOps) { rewriter.setInsertionPoint(slices.front());
rewriter.setInsertionPoint(target);
while (numConsumerToFuse--) { while (numConsumerToFuse--) {
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
scf::tileAndFuseConsumerOfSlice(rewriter, target, loops); scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
if (failed(fuseConsumerResults)) if (failed(fuseConsumerResults))
return failure(); return slices.front()->emitOpError("failed to fuse consumer of slice");
// Report back the relevant handles to the transform op. // Report back the relevant handles to the transform op.
originalConsumerOps.push_back( for (OpOperand *origConsumerOperand :
fuseConsumerResults->origConsumerOperand->getOwner()); fuseConsumerResults->origConsumerOperands) {
fusedConsumerOps.push_back( originalConsumerOps.push_back(origConsumerOperand->getOwner());
fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner()); }
for (OpOperand *tiledAndFusedConsumerOperand :
fuseConsumerResults->tiledAndFusedConsumerOperands) {
fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
} }
} }
@@ -203,6 +207,12 @@ DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
TransformResults &transformResults, TransformResults &transformResults,
TransformState &state) { TransformState &state) {
SmallVector<Operation *> slices;
for (auto op : getTargets()) {
auto sliceOp = *state.getPayloadOps(op).begin();
slices.push_back(sliceOp);
}
SmallVector<LoopLikeOpInterface> loops; SmallVector<LoopLikeOpInterface> loops;
for (auto op : llvm::reverse(getLoops())) { for (auto op : llvm::reverse(getLoops())) {
auto loopLikeOp = auto loopLikeOp =
@@ -212,16 +222,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
} }
loops.push_back(loopLikeOp); loops.push_back(loopLikeOp);
} }
LogicalResult result = applyFuseConsumer( LogicalResult result =
rewriter, getOperation(), state.getPayloadOps(getTarget()), loops, applyFuseConsumer(rewriter, getOperation(), slices, loops,
getNumConsumerToFuse(), transformResults); getNumConsumerToFuse(), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success(); : DiagnosedSilenceableFailure::success();
} }
void transform::TestFuseConsumerOp::getEffects( void transform::TestFuseConsumerOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects); consumesHandle(getTargetsMutable(), effects);
consumesHandle(getLoopsMutable(), effects); consumesHandle(getLoopsMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects); producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects); modifiesPayload(effects);

View File

@@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
} }
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer", def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
[DeclareOpInterfaceMethods<TransformOpInterface>, [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ReportTrackingListenerFailuresOpTrait]> { ReportTrackingListenerFailuresOpTrait]> {
let description = [{ let description = [{
@@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
}]; }];
let arguments = (ins let arguments = (ins
TransformHandleTypeInterface:$target, Variadic<TransformHandleTypeInterface>:$targets,
Variadic<TransformHandleTypeInterface>:$loops, Variadic<TransformHandleTypeInterface>:$loops,
DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse); DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
let results = (outs TransformHandleTypeInterface:$consumer, let results = (outs TransformHandleTypeInterface:$consumer,
TransformHandleTypeInterface:$fused_consumer); TransformHandleTypeInterface:$fused_consumer);
let assemblyFormat = [{ let assemblyFormat = [{
$target `in` `(` $loops `)` $targets `in` `(` $loops `)`
(`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
attr-dict `:` functional-type(operands, results) attr-dict `:` functional-type(operands, results)
}]; }];