[mlir][TilingInterface] Handle multi operand consumer fusion. (#145193)

For consumer fusion cases of this form

```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {

  tensor.parallel_insert_slice ... into %arg0
  tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```

the current consumer fusion that handles one slice at a time cannot fuse
the consumer into the loop, since fusing along one slice will create an
SSA violation on the other use from the `scf.forall`. The solution is to
allow consumer fusion to consider multiple slices at once. This
PR changes the `TilingInterface` methods related to consumer fusion,
i.e.

- `getTiledImplementationFromOperandTile`
- `getIterationDomainFromOperandTile`

to allow fusion while considering multiple operands. It is up to the
`TilingInterface` implementation to return an error if a list of tiles
of the operands cannot result in a consistent implementation of the
tiled operation.

The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes, handling the cases where the
operand tiles can result in a consistent tiling implementation.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
This commit is contained in:
MaheshRavishankar
2025-06-25 11:54:38 -07:00
committed by GitHub
parent 28f6f87061
commit c873e5f87d
11 changed files with 703 additions and 205 deletions

View File

@@ -313,19 +313,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);
/// Fuse the consumer of the source of `candidateSliceOp` by computing the
/// required slice of the consumer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
/// value but does not delete the slice operation.
/// Fuse the consumer `candidateSlices` by computing the required slice of the
/// consumer in-place. All the entries of `candidateSlices` are expected to map
/// to the same consumer. The method returns an error if the consumer cannot be
/// tiled in a manner that is consistent for all the passed slices. Note that
/// the method replaces the uses of `candidateSlices` with the tiled and fused
/// consumer value but does not delete the slice operations.
struct SCFFuseConsumerOfSliceResult {
OpOperand *origConsumerOperand; // Original untiled consumer's operand.
OpOperand
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
// Original untiled consumer operands.
SmallVector<OpOperand *> origConsumerOperands;
// Tiled and fused consumer operands.
SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops);
tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.

View File

@@ -31,12 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
/// Method to swap an `tensor.insert_slice` with its consumer when the
/// consumer implements the `TilingInterface`.
/// Method to swap `tensor.insert_slice`s with their consumers when the
/// consumer implements the `TilingInterface`. The size of `sliceOps` and
/// `consumerOperands` is expected to be the same. Every entry in
/// `consumerOperands` represents a use of the corresponding
/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` are
/// expected to be uses in the same consumer.
FailureOr<TilingResult>
replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
OffsetSizeAndStrideOpInterface sliceOp,
OpOperand &consumerOp);
replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
ArrayRef<tensor::InsertSliceOp> sliceOps,
ArrayRef<OpOperand *> consumerOperands);
//===----------------------------------------------------------------------===//
// Populate functions.

View File

@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
using PointerUnion<Attribute, Value>::PointerUnion;
public:
void dump() const { llvm::errs() << *this << "\n"; }
LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
MLIRContext *getContext() const {
PointerUnion pu = *this;

View File

@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation that uses
exactly a tile of the given operand.
the exact tiles of the given operands.
This method is required to allow operations to be "tiled and fused"
with an (already tiled) producer. Given a tile of the producer, this
method generates the tile of the consumer that uses exactly this
produced tile. In some sense it is the "reverse" of
with an (already tiled) producer. Given tiles of the producer, this
method generates the tile of the consumer that uses exactly these
produced tiles. In some sense it is the "reverse" of
`generateResultTileValue`.
- `operandNumber` is the result of the producer used by the consumer.
- `offsets` is the offset of the slice of the producer result used by
the tiled implementation of the consumer.
- `sizes` is the size of the slice of the producer result used by the
- `operandNumbers` is the list of operands whose tiles are "producers".
- `allOffsets` is the offset of the slice of the producer used by the
tiled implementation of the consumer.
- `allSizes` is the size of the slice of the producer used by the
consumer.
If it is illegal to fuse with a producer along the given operand for
If it is illegal to fuse with a producer along the given operand tiles for
an operation, the implementation should return a failure.
}],
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementationFromOperandTile",
/*methodName=*/"getTiledImplementationFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"unsigned":$operandNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
"::mlir::ArrayRef<unsigned>":$operandNumbers,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
tile of the operand.
This method is required to allow operations to be "tiled and fused"
with an (already tiled) producer. Given a tile of an operand,
returns the tile of the iteration space that uses this tile.
- `operandNumber` is the result of the producer used by the consumer.
- `offsets` is the offset of the slice of the producer result used by
with an (already tiled) producer. Given tiles of operands,
returns the tile of the iteration space that uses these tiles.
- `operandNumbers` is the list of operands whose tiles are "produced"
by the producer(s).
- `allOffsets` is the offset of the slice of the producers used by
the tiled implementation of the consumer.
- `sizes` is the size of the slice of the producer result used by the
- `allSizes` is the size of the slice of the producers used by the
consumer.
If it is illegal to fuse with a producer along the given operand for
an operation, or if this mapping cannot be computed, the
implementation should return a failure.
If it is illegal to fuse with the producer slices for an operation,
or if this mapping cannot be computed, the implementation should
return a failure.
Note that unlike the "tile consumer and fuse producer" case, the
"tile producer and fuse consumer" requires an additional method to get
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformation. It does not provide guarantees on whether such a
transformation is profitable.
For most cases `getTiledImplementationFromOperandTile` could be a
implemented using `getIterationDomainTileFromOperandTile` +
For most cases `getTiledImplementationFromOperandTiles` could be
implemented using `getIterationDomainTileFromOperandTiles` +
`getTiledImplementation` methods.
}],
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"getIterationDomainTileFromOperandTile",
/*methodName=*/"getIterationDomainTileFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"unsigned":$operandNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
"::mlir::ArrayRef<unsigned>":$operandNumbers,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",

View File

@@ -22,8 +22,11 @@
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "linalg-tiling-interface-impl"
using namespace mlir;
using namespace mlir::linalg;
@@ -148,55 +151,82 @@ struct LinalgOpTilingInterface
/// Utility to fetch the offsets and sizes when applied as per the indexing
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
/// a given slice op.
void
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVectorImpl<OpFoldResult> &mappedOffsets,
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
unsigned numLoops = linalgOp.getNumLoops();
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
mappedOffsets.resize(numLoops);
mappedSizes.resize(numLoops);
if (!indexingMap.isPermutation()) {
SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(b);
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
mappedOffsets[index] = value.offset;
mappedSizes[index] = value.size;
static LogicalResult
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
ArrayRef<AffineMap> indexingMaps,
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
for (auto [indexingMap, offsets, sizes] :
llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
for (auto [resultExpr, offset, size] :
llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
if (!dimExpr)
continue;
unsigned position = dimExpr.getPosition();
auto it = mappedOffsets.find(position);
if (it != mappedOffsets.end()) {
OpFoldResult seenOffset = it->second;
OpFoldResult seenSize = mappedSizes.lookup(position);
if (seenOffset != offset || seenSize != size) {
LLVM_DEBUG({
llvm::dbgs() << "inconsistent iteration space mapping from "
"offsets/sizes of operands/results";
});
return failure();
}
} else {
mappedOffsets[position] = offset;
mappedSizes[position] = size;
}
}
}
for (const auto &&[index, value] :
llvm::enumerate(indexingMap.getResults())) {
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
mappedOffsets[dimPosition] = offsets[index];
mappedSizes[dimPosition] = sizes[index];
// Aggregate from the given operand offsets and sizes, or default to
// iteration space values.
SmallVector<Range> iterationDomain =
cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
mappedOffsetsVec.resize(iterationDomain.size());
mappedSizesVec.resize(iterationDomain.size());
for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
auto it = mappedOffsets.find(index);
if (it != mappedOffsets.end()) {
mappedOffsetsVec[index] = it->second;
mappedSizesVec[index] = mappedSizes.lookup(index);
continue;
}
mappedOffsetsVec[index] = domain.offset;
mappedSizesVec[index] = domain.size;
}
return success();
}
/// Method to return the position of the result tile computed by the tiled
/// operation.
LogicalResult getIterationDomainTileFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
LogicalResult getIterationDomainTileFromOperandTiles(
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the operand is a projected
// permutation. This could be relaxed with a more general approach that can
// map the offsets and sizes from the operand to iteration space tiles
// (filling in full extent for dimensions not used to access the result).
AffineMap indexingMap =
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
if (!indexingMap.isProjectedPermutation()) {
return op->emitError()
<< "unhandled get iter domain position when operand is not "
"accessed using a permuted projection";
std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
iterationSpaceSizes;
SmallVector<AffineMap> indexingMaps =
llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
return linalgOp.getMatchingIndexingMap(&opOperand);
});
if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
allSizes, iterDomainOffsets,
iterDomainSizes))) {
return failure();
}
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
iterDomainOffsets, iterDomainSizes);
return success();
}
@@ -247,8 +277,13 @@ struct LinalgOpTilingInterface
"accessed using a permuted projection");
}
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
iterDomainOffsets, iterDomainSizes);
SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
auto status =
getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
{allSizes}, iterDomainOffsets, iterDomainSizes);
(void)status;
assert(succeeded(status) && "unexpected error in offset calculation");
return success();
}
@@ -279,12 +314,13 @@ struct LinalgOpTilingInterface
/// Method to generate the tiled implementation of an operation from the tile
/// of the operand.
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
if (failed(getIterationDomainTileFromOperandTile(
op, b, operandNumber, offsets, sizes, mappedOffsets,
if (failed(getIterationDomainTileFromOperandTiles(
op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
mappedSizes))) {
return failure();
}
@@ -837,13 +873,20 @@ struct PackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
/// `resultSizes` only cover outer dimensions.
LogicalResult getIterationDomainTileFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
LogicalResult getIterationDomainTileFromOperandTiles(
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
if (operandNumber != 0)
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
LLVM_DEBUG(
{ llvm::dbgs() << "unsupported operands for consumer fusion"; });
return failure();
}
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
// It is not trivial to infer dest tile from source tile if `packOp` has
@@ -904,11 +947,18 @@ struct PackOpTiling
}
/// Method to return the tiled implementation of tensor.pack as a consumer.
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
if (operandNumber != 0)
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
LLVM_DEBUG(
{ llvm ::dbgs() << "unhandled operands for consumer fusion"; });
return failure();
}
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
Location loc = packOp.getLoc();
@@ -923,8 +973,8 @@ struct PackOpTiling
tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
if (failed(getIterationDomainTileFromOperandTile(
op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
if (failed(getIterationDomainTileFromOperandTiles(
op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
outerDimSizes)))
return failure();
@@ -1182,12 +1232,21 @@ struct UnPackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation.
LogicalResult getIterationDomainTileFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
LogicalResult getIterationDomainTileFromOperandTiles(
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
if (operandNumbers.size() != 1) {
LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
return failure();
}
auto unPackOp = cast<UnPackOp>(op);
unsigned operandNumber = operandNumbers[0];
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
// If the operand tile is the dest, then no adjustment is needed.
if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
resultOffsets = llvm::to_vector(offsets);
@@ -1241,10 +1300,18 @@ struct UnPackOpTiling
}
/// Method to return the tiled implementation of tensor.unpack as a consumer.
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
return failure();
}
auto unPackOp = cast<UnPackOp>(op);
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
ArrayRef<OpFoldResult> sizes(allSizes[0]);
// tensor.unpack op is fusible (as a consumer) only if inner dims are not
// tiled.
int64_t numTiles = unPackOp.getInnerDimsPos().size();
@@ -1259,8 +1326,8 @@ struct UnPackOpTiling
// Fetch offset/size for creating the slice of the dest operand of
// unpack op.
SmallVector<OpFoldResult> outputOffsets, outputSizes;
if (failed(getIterationDomainTileFromOperandTile(
op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
if (failed(getIterationDomainTileFromOperandTiles(
op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
outputSizes)))
return failure();

View File

@@ -2047,53 +2047,119 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "unexpected empty loops");
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
} else if (auto parallelInsertSlice =
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
} else {
return failure();
assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
SmallVector<OpOperand *> fusedOperands;
for (auto sliceOp : sliceOps) {
FailureOr<OpOperand *> fusedOperand =
TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
.Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
[&](auto op) {
return getUntiledConsumerFromSlice(rewriter, op, loops);
})
.Default([&](Operation *op) {
return rewriter.notifyMatchFailure(op, "unhandled slice type");
});
if (failed(fusedOperand)) {
return failure();
}
if (!fusedOperands.empty() &&
fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
return rewriter.notifyMatchFailure(
fusedOperand.value()->getOwner(),
"all candidate slices must be to the same consumer");
}
fusedOperands.push_back(fusedOperand.value());
}
return fusedOperands;
}
template <typename InsertSliceOpTy>
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
InsertSliceOpTy sliceOp);
template <>
tensor::InsertSliceOp
cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
tensor::InsertSliceOp insertSliceOp) {
return cast<tensor::InsertSliceOp>(
rewriter.clone(*insertSliceOp.getOperation()));
}
template <>
tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
return rewriter.create<tensor::InsertSliceOp>(
insertSliceOp->getLoc(), insertSliceOp.getSource(),
insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
}
static SmallVector<tensor::InsertSliceOp>
cloneAsInsertSlices(RewriterBase &rewriter,
ArrayRef<Operation *> candidateSlices) {
assert(!candidateSlices.empty() &&
"unexpected empty list of slices to clone");
SmallVector<tensor::InsertSliceOp> clonedSlices;
for (auto sliceOp : candidateSlices) {
TypeSwitch<Operation *>(sliceOp)
.Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
[&](auto op) {
auto clonedOp = cloneAsInsertSlice(rewriter, op);
clonedSlices.push_back(clonedOp);
})
.Default([&](Operation *op) {
// Assert here assuming this has already been checked.
assert(0 && "unexpected slice type while cloning as insert slice");
});
}
return clonedSlices;
}
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(
RewriterBase &rewriter, Operation *candidateSliceOp,
mlir::scf::tileAndFuseConsumerOfSlices(
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops) {
if (candidateSlices.empty()) {
return rewriter.notifyMatchFailure(
rewriter.getUnknownLoc(),
"no candidate slices provided for consumer fusion");
}
// Return if `loops` is empty, return an error for now. Caller is expected
// to handle this case.
if (loops.empty()) {
return candidateSliceOp->emitOpError(
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"cannot call tile and fuse consumer with an empty loop nest");
}
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))
return failure();
if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
llvm::all_of(candidateSlices,
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"candidates slices need to be all `tensor.extract_slice`s or "
"`tensor.parallel_insert_slice`s");
}
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand =
getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
}
OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
Operation *consumerOp = consumerOpOperand->getOwner();
unsigned operandNumber = consumerOpOperand->getOperandNumber();
unsigned resultNumber = 0;
if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
resultNumber = producerResult.getResultNumber();
} else {
return rewriter.notifyMatchFailure(
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
SmallVector<OpOperand *> consumerOpOperands;
Operation *consumerOp;
{
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSlices.front(),
"could not fetch consumer to fuse");
}
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
consumerOp = consumerOpOperands.front()->getOwner();
}
LoopLikeOpInterface outerMostLoop = loops.front();
@@ -2113,16 +2179,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
if (!dstOp)
return rewriter.notifyMatchFailure(consumerOp,
"consumer op is not DPS operation");
SmallVector<Value> dpsInits =
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
return dstOp.isDpsInit(opOperand);
})) {
return rewriter.notifyMatchFailure(
consumerOp,
"consumer op taking the result of scf.for as init is not supported");
}
SmallVector<Value> newInits = dpsInits;
Location loc = outerMostLoop->getLoc();
SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
// 3. Move the whole loop structure right before firstUserOfLoop, the
// dominance should be already ensured by `checkAssumptionForLoop`.
@@ -2137,43 +2201,52 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
tensor::InsertSliceOp clonedInsertSliceOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
rewriter.setInsertionPoint(newForallOp.getTerminator());
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else {
rewriter.setInsertionPoint(candidateSliceOp);
clonedInsertSliceOp =
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
rewriter.setInsertionPoint(candidateSlices.front());
}
// 5.a. Clone all the candidate slices as equivalent insert slice ops.
SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
cloneAsInsertSlices(rewriter, candidateSlices);
// 5.a. Clone consumer op.
// 5.b. Clone consumer op.
auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
SmallVector<unsigned> operandNumbers =
llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
return opOperand->getOperandNumber();
});
SmallVector<OpOperand *> clonedOpFusedOperandsList =
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
return &clonedConsumerOp->getOpOperand(operandNum);
});
// 5.b. Replace all uses of the loop result with the result of the cloned
// 5.c. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
operandToReplace.set(clonedInsertSliceOp.getResult());
for (auto [operandToReplace, clonedSliceOp] :
llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
operandToReplace->set(clonedSliceOp.getResult());
}
});
// 6. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
auto ossSliceOp =
cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
clonedOpFusedOperandsList);
if (failed(tileAndFuseResult)) {
return failure();
}
auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
clonedInsertSliceOp.getSource());
for (auto [operandNum, clonedSliceOp] :
llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
clonedSliceOp.getSource());
}
// 7. Reconstruct [nested] loop with new inits.
YieldTiledValuesFn newYieldValuesFn =
@@ -2185,14 +2258,20 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// 8. Set inner insertPoint right before tiled consumer op.
innerRewriter.setInsertionPoint(tiledConsumerOp);
SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
for (auto candidateSliceOp : clonedInsertSlices) {
SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
// 9. Check all insert stride is 1.
if (!llvm::all_of(strides, isOneInteger)) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "containingOp's result yield with stride");
// 9. Check all insert stride is 1.
if (!llvm::all_of(strides, isOneInteger)) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "containingOp's result yield with stride");
}
allOffsets.emplace_back(std::move(offsets));
allSizes.emplace_back(std::move(sizes));
}
// 10. Try to get iter domain position from input position. Use
@@ -2202,8 +2281,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// tiledConsumerOp could lead to some chained unnecessary extra index
// computation.
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
iterDomainSizes))) {
return rewriter.notifyMatchFailure(
clonedConsumerOp,
@@ -2279,10 +2358,13 @@ mlir::scf::tileAndFuseConsumerOfSlice(
// 16. Need to erase the old scf loop and the cloned consumer op.
rewriter.eraseOp(clonedConsumerOp);
SmallVector<OpOperand *> tiledAndFusedOpOperands =
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
});
return scf::SCFFuseConsumerOfSliceResult{
consumerOpOperand,
&(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
tileAndFuseResult->tiledOps};
std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
std::move(tileAndFuseResult->tiledOps)};
}
//===----------------------------------------------------------------------===//

View File

@@ -17,6 +17,9 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "tensor-swap-slices"
using namespace mlir;
@@ -39,21 +42,55 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return *tiledResult;
}
FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
OpOperand &consumer) {
auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
ArrayRef<OpOperand *> consumerOperands) {
if (sliceOps.empty()) {
LLVM_DEBUG(
{ llvm::dbgs() << "expected candidate slices list to be non-empty"; });
return failure();
}
if (sliceOps.size() != consumerOperands.size()) {
LLVM_DEBUG({
llvm::dbgs()
<< "expected as many operands as the number of slices passed";
});
return failure();
}
auto consumerOp =
dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
if (!consumerOp)
return failure();
for (auto opOperand : consumerOperands.drop_front()) {
if (opOperand->getOwner() != consumerOp) {
LLVM_DEBUG({
llvm::dbgs()
<< "expected all consumer operands to be from the same operation";
});
return failure();
}
}
// `TilingInterface` currently only supports strides being 1.
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
return failure();
auto consumerOperandNums = llvm::map_to_vector(
consumerOperands, [](OpOperand *opOperand) -> unsigned {
return opOperand->getOperandNumber();
});
SmallVector<SmallVector<OpFoldResult>> allOffsets;
SmallVector<SmallVector<OpFoldResult>> allSizes;
for (auto sliceOp : sliceOps) {
// `TilingInterface` currently only supports strides being 1.
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
return failure();
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
allOffsets.emplace_back(std::move(offsets));
allSizes.emplace_back(std::move(sizes));
}
FailureOr<TilingResult> tiledResult =
consumerOp.getTiledImplementationFromOperandTile(
builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes());
consumerOp.getTiledImplementationFromOperandTiles(
builder, consumerOperandNums, allOffsets, allSizes);
if (failed(tiledResult))
return failure();

View File

@@ -653,6 +653,7 @@ module {
%5 = affine.min #map2(%i)[%d0, %idx]
%6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK: linalg.generic
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
#map = affine_map<(d0) -> (d0)>
module {
@@ -620,3 +620,294 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// -----
// Multi-operand consumer fusion: the consumer `linalg.generic` below reads
// BOTH results of the `scf.forall`, so it can only be fused by considering
// both `tensor.parallel_insert_slice` ops at once — fusing along a single
// slice would create an SSA violation on the other use of the loop result.
func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
%tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
%arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
// Single tiled producer with two results; both results are yielded
// through the two parallel_insert_slice ops below with identical
// offsets/sizes ([%iv0] [%tilesize]).
%generic:2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.mulf %b0, %b1 : f32
%1 = arith.addf %b0, %b2 : f32
linalg.yield %0, %1 : f32, f32
} -> (tensor<?xf32>, tensor<?xf32>)
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
}
}
%empty = tensor.empty(%dim0) : tensor<?xf32>
// Consumer reading both loop results — the candidate for multi-slice
// consumer fusion into the scf.forall above.
%result = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
// CHECK-LABEL: func @multi_slice_fusion1(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
// CHECK: %[[TILESIZE:.+]] = affine.min
// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: %[[FUSED:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: return %[[RESULT]]#2
// Driver: match both parallel_insert_slice ops, split the handle, and fuse
// the consumer through both slices simultaneously via
// transform.test.fuse_consumer (variadic slice operands).
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
// Check fusion when the consumer operands are produced by two different
// operations inside the loop; both slices use the same offsets/sizes, so the
// operand tiles are consistent and fusion succeeds.
// Same structure as @multi_slice_fusion1, but the two fused operand tiles
// come from two separate `linalg.generic` producers. Both insert slices use
// identical offsets/sizes ([%iv0] [%tilesize]), so the consumer can still be
// tiled consistently along both operands.
func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
%tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
%arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
// First producer: row-wise reduction (mulf) into %init0.
%generic0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.mulf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
// Second producer: independent row-wise reduction (addf) into %init1.
%generic1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0: f32
} -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
}
}
%empty = tensor.empty(%dim0) : tensor<?xf32>
// Consumer of both loop results; fused into the loop by the transform below.
%result = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
// CHECK-LABEL: func @multi_slice_fusion2(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
// CHECK: %[[TILESIZE:.+]] = affine.min
// CHECK: %[[GENERIC0:.+]] = linalg.generic
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: %[[FUSED:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
// CHECK: return %[[RESULT]]#2
// Driver: fuse the consumer along both parallel_insert_slice ops at once.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
// Multi-slice consumer fusion where the operand tiles have different ranks:
// %loop#0 is inserted with a 2-D tile [%iv0, %iv1] while %loop#1 uses a 1-D
// tile [%iv0]. The consumer indexes the second operand with (d0) only
// (a broadcast along d1), so both tiles still resolve to one consistent
// iteration-domain tile.
func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
%arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
%loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
%tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
%tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
%arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
%init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> to tensor<?x?xf32>
// 2-D producer: reduces the innermost dimension of the 3-D input.
%generic0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.mulf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
// 1-D producer: further reduces %generic0 along d1; its insert slice below
// is indexed only by %iv0.
%generic1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0: f32
} -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
}
}
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
// Consumer broadcasting the 1-D operand (%loop#1, map (d0)) against the
// 2-D operand (%loop#0, map (d0, d1)).
%result = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}
// Driver: fuse the broadcast consumer along both slices at once.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
// CHECK: %[[GENERIC0:.+]] = linalg.generic
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
// CHECK: %[[FUSED:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
// CHECK: return %[[RESULT]]#2
// -----
// Negative test: the consumer indexes %loop#0 with (d0, d1) but %loop#1 with
// the transposed map (d1, d0). The two operand tiles [%iv0, %iv1] therefore
// imply conflicting iteration-domain tiles, and fusion is expected to fail
// (see the expected-error inside scf.forall.in_parallel).
func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
%arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
%loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
%tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
%arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
%init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> to tensor<?x?xf32>
%generic0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.mulf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
%init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> to tensor<?x?xf32>
%generic1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0: f32
} -> tensor<?x?xf32>
scf.forall.in_parallel {
// expected-error @below {{failed to fuse consumer of slice}}
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
}
}
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
// Consumer with transposed access on the second operand: (d0, d1) vs
// (d1, d0) — this is what makes the operand tiles inconsistent.
%result = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}
// Driver: attempt the multi-slice fusion; expected to emit the diagnostic
// matched by the expected-error above.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

View File

@@ -21,6 +21,9 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "test-tiling-interface"
#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.h.inc"
@@ -168,29 +171,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
/// Apply fusing of consumer transformation to all payload ops and store both
/// the original consumer operation as well as the fused consumer operation.
template <typename Range>
static LogicalResult applyFuseConsumer(
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
TransformResults &transformResults) {
RewriterBase &rewriter, Operation *transformOp,
ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
uint32_t numConsumerToFuse, TransformResults &transformResults) {
SmallVector<Operation *> originalConsumerOps;
SmallVector<Operation *> fusedConsumerOps;
for (Operation *target : payloadOps) {
rewriter.setInsertionPoint(target);
rewriter.setInsertionPoint(slices.front());
while (numConsumerToFuse--) {
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
while (numConsumerToFuse--) {
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
if (failed(fuseConsumerResults))
return failure();
if (failed(fuseConsumerResults))
return slices.front()->emitOpError("failed to fuse consumer of slice");
// Report back the relevant handles to the transform op.
originalConsumerOps.push_back(
fuseConsumerResults->origConsumerOperand->getOwner());
fusedConsumerOps.push_back(
fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
// Report back the relevant handles to the transform op.
for (OpOperand *origConsumerOperand :
fuseConsumerResults->origConsumerOperands) {
originalConsumerOps.push_back(origConsumerOperand->getOwner());
}
for (OpOperand *tiledAndFusedConsumerOperand :
fuseConsumerResults->tiledAndFusedConsumerOperands) {
fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
}
}
@@ -203,6 +207,12 @@ DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) {
SmallVector<Operation *> slices;
for (auto op : getTargets()) {
auto sliceOp = *state.getPayloadOps(op).begin();
slices.push_back(sliceOp);
}
SmallVector<LoopLikeOpInterface> loops;
for (auto op : llvm::reverse(getLoops())) {
auto loopLikeOp =
@@ -212,16 +222,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
}
loops.push_back(loopLikeOp);
}
LogicalResult result = applyFuseConsumer(
rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
getNumConsumerToFuse(), transformResults);
LogicalResult result =
applyFuseConsumer(rewriter, getOperation(), slices, loops,
getNumConsumerToFuse(), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
void transform::TestFuseConsumerOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
consumesHandle(getTargetsMutable(), effects);
consumesHandle(getLoopsMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);

View File

@@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
[DeclareOpInterfaceMethods<TransformOpInterface>,
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
@@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
}];
let arguments = (ins
TransformHandleTypeInterface:$target,
Variadic<TransformHandleTypeInterface>:$targets,
Variadic<TransformHandleTypeInterface>:$loops,
DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
let results = (outs TransformHandleTypeInterface:$consumer,
TransformHandleTypeInterface:$fused_consumer);
let assemblyFormat = [{
$target `in` `(` $loops `)`
$targets `in` `(` $loops `)`
(`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
attr-dict `:` functional-type(operands, results)
}];