[mlir][TilingInterface] Handle multi operand consumer fusion. (#145193)
For consumer fusion cases of this form
```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {
tensor.parallel_insert_slice ... into %arg0
tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```
the current consumer fusion that handles one slice at a time cannot fuse
the consumer into the loop, since fusing along one slice would create an
SSA violation on the other use from the `scf.forall`. The solution is to
allow consumer fusion to consider multiple slices at once. This
PR changes the `TilingInterface` methods related to consumer fusion,
i.e.
- `getTiledImplementationFromOperandTile`
- `getIterationDomainFromOperandTile`
to allow fusion while considering multiple operands. It is up to the
`TilingInterface` implementation to return an error if a list of tiles
of the operands cannot result in a consistent implementation of the
tiled operation.
The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes and allow cases where operand
tiles that can result in a consistent tiling implementation are handled.
---------
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
This commit is contained in:
committed by
GitHub
parent
28f6f87061
commit
c873e5f87d
@@ -313,19 +313,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
|
|||||||
TilingInterface consumer,
|
TilingInterface consumer,
|
||||||
const SCFTileAndFuseOptions &options);
|
const SCFTileAndFuseOptions &options);
|
||||||
|
|
||||||
/// Fuse the consumer of the source of `candidateSliceOp` by computing the
|
/// Fuse the consumer `candidateSlices` by computing the required slice of the
|
||||||
/// required slice of the consumer in-place. Note that the method
|
/// consumer in-place. All the entries of `candidateSlices` are expected to map
|
||||||
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
|
/// to the same consumer. The method returns an error if the consumer cannot be
|
||||||
/// value but does not delete the slice operation.
|
/// tiled in a manner that is consistent for all the passed slices. Note that
|
||||||
|
/// the method replaces the uses of `candidateSlices` with the tiled and fused
|
||||||
|
/// consumer value but does not delete the slice operations.
|
||||||
struct SCFFuseConsumerOfSliceResult {
|
struct SCFFuseConsumerOfSliceResult {
|
||||||
OpOperand *origConsumerOperand; // Original untiled consumer's operand.
|
// Original untiled consumer operands.
|
||||||
OpOperand
|
SmallVector<OpOperand *> origConsumerOperands;
|
||||||
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
|
// Tiled and fused consumer operands.
|
||||||
|
SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
|
||||||
SmallVector<Operation *> tiledOps;
|
SmallVector<Operation *> tiledOps;
|
||||||
};
|
};
|
||||||
FailureOr<scf::SCFFuseConsumerOfSliceResult>
|
FailureOr<scf::SCFFuseConsumerOfSliceResult>
|
||||||
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
|
tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
|
||||||
MutableArrayRef<LoopLikeOpInterface> loops);
|
ArrayRef<Operation *> candidateSlices,
|
||||||
|
MutableArrayRef<LoopLikeOpInterface> loops);
|
||||||
|
|
||||||
/// Method to lower an `op` that implements the `TilingInterface` to
|
/// Method to lower an `op` that implements the `TilingInterface` to
|
||||||
/// loops/scalars.
|
/// loops/scalars.
|
||||||
|
|||||||
@@ -31,12 +31,16 @@ namespace tensor {
|
|||||||
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
|
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
|
||||||
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
|
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
|
||||||
|
|
||||||
/// Method to swap an `tensor.insert_slice` with its consumer when the
|
/// Method to swap `tensor.insert_slice`s with their consumers when the
|
||||||
/// consumer implements the `TilingInterface`.
|
/// consumer implements the `TilingInterface`. The size of `sliceOps` and
|
||||||
|
/// `consumerOperands` is expected to be the same. Every entry in
|
||||||
|
/// `consumerOperands` represents a use of the corresponding
|
||||||
|
/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` are
|
||||||
|
/// expected to be uses in the same consumer.
|
||||||
FailureOr<TilingResult>
|
FailureOr<TilingResult>
|
||||||
replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
|
replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
|
||||||
OffsetSizeAndStrideOpInterface sliceOp,
|
ArrayRef<tensor::InsertSliceOp> sliceOps,
|
||||||
OpOperand &consumerOp);
|
ArrayRef<OpOperand *> consumerOperands);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Populate functions.
|
// Populate functions.
|
||||||
|
|||||||
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
|
|||||||
using PointerUnion<Attribute, Value>::PointerUnion;
|
using PointerUnion<Attribute, Value>::PointerUnion;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void dump() const { llvm::errs() << *this << "\n"; }
|
LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
|
||||||
|
|
||||||
MLIRContext *getContext() const {
|
MLIRContext *getContext() const {
|
||||||
PointerUnion pu = *this;
|
PointerUnion pu = *this;
|
||||||
|
|||||||
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
|||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
/*desc=*/[{
|
/*desc=*/[{
|
||||||
Method to generate the tiled implementation of an operation that uses
|
Method to generate the tiled implementation of an operation that uses
|
||||||
exactly a tile of the given operand.
|
the exact tiles of the given operands.
|
||||||
|
|
||||||
This method is required to allow operations to be "tiled and fused"
|
This method is required to allow operations to be "tiled and fused"
|
||||||
with an (already tiled) producer. Given a tile of the producer, this
|
with an (already tiled) producer. Given tiles of the producer, this
|
||||||
method generates the tile of the consumer that uses exactly this
|
method generates the tile of the consumer that uses exactly these
|
||||||
produced tile. In some sense it is the "reverse" of
|
produced tiles. In some sense it is the "reverse" of
|
||||||
`generateResultTileValue`.
|
`generateResultTileValue`.
|
||||||
- `operandNumber` is the result of the producer used by the consumer.
|
- `operandNumbers` is the list of operands whose tiles are "producers".
|
||||||
- `offsets` is the offset of the slice of the producer result used by
|
- `allOffsets` is the offset of the slice of the producer used by the
|
||||||
the tiled implementation of the consumer.
|
tiled implementation of the consumer.
|
||||||
- `sizes` is the size of the slice of the producer result used by the
|
- `allSizes` is the size of the slice of the producer used by the
|
||||||
consumer.
|
consumer.
|
||||||
If it is illegal to fuse with a producer along the given operand for
|
If it is illegal to fuse with a producer along the given operand tiles for
|
||||||
an operation, the implementation should return a failure.
|
an operation, the implementation should return a failure.
|
||||||
}],
|
}],
|
||||||
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
|
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
|
||||||
/*methodName=*/"getTiledImplementationFromOperandTile",
|
/*methodName=*/"getTiledImplementationFromOperandTiles",
|
||||||
/*args=*/(ins
|
/*args=*/(ins
|
||||||
"::mlir::OpBuilder &":$b,
|
"::mlir::OpBuilder &":$b,
|
||||||
"unsigned":$operandNumber,
|
"::mlir::ArrayRef<unsigned>":$operandNumbers,
|
||||||
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
|
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
|
||||||
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
|
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
return failure();
|
return failure();
|
||||||
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
|||||||
tile of the operand.
|
tile of the operand.
|
||||||
|
|
||||||
This method is required to allow operations to be "tiled and fused"
|
This method is required to allow operations to be "tiled and fused"
|
||||||
with an (already tiled) producer. Given a tile of an operand,
|
with an (already tiled) producer. Given tiles of operands,
|
||||||
returns the tile of the iteration space that uses this tile.
|
returns the tile of the iteration space that uses these tiles.
|
||||||
- `operandNumber` is the result of the producer used by the consumer.
|
- `operandNumbers` is the list of operands whose tiles are "produced"
|
||||||
- `offsets` is the offset of the slice of the producer result used by
|
by the producer(s).
|
||||||
|
- `allOffsets` is the offset of the slice of the producers used by
|
||||||
the tiled implementation of the consumer.
|
the tiled implementation of the consumer.
|
||||||
- `sizes` is the size of the slice of the producer result used by the
|
- `allSizes` is the size of the slice of the producers used by the
|
||||||
consumer.
|
consumer.
|
||||||
If it is illegal to fuse with a producer along the given operand for
|
If it is illegal to fuse with the producer slices for an operation,
|
||||||
an operation, or if this mapping cannot be computed, the
|
or if this mapping cannot be computed, the implementation should
|
||||||
implementation should return a failure.
|
return a failure.
|
||||||
|
|
||||||
Note that unlike the "tile consumer and fuse producer" case, the
|
Note that unlike the "tile consumer and fuse producer" case, the
|
||||||
"tile producer and fuse consumer" requires an additional method to get
|
"tile producer and fuse consumer" requires an additional method to get
|
||||||
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
|||||||
transformation. It does not provide guarantees on whether such a
|
transformation. It does not provide guarantees on whether such a
|
||||||
transformation is profitable.
|
transformation is profitable.
|
||||||
|
|
||||||
For most cases `getTiledImplementationFromOperandTile` could be a
|
For most cases `getTiledImplementationFromOperandTiles` could be a
|
||||||
implemented using `getIterationDomainTileFromOperandTile` +
|
implemented using `getIterationDomainTileFromOperandTiles` +
|
||||||
`getTiledImplementation` methods.
|
`getTiledImplementation` methods.
|
||||||
}],
|
}],
|
||||||
/*retType=*/"::llvm::LogicalResult",
|
/*retType=*/"::llvm::LogicalResult",
|
||||||
/*methodName=*/"getIterationDomainTileFromOperandTile",
|
/*methodName=*/"getIterationDomainTileFromOperandTiles",
|
||||||
/*args=*/(ins
|
/*args=*/(ins
|
||||||
"::mlir::OpBuilder &":$b,
|
"::mlir::OpBuilder &":$b,
|
||||||
"unsigned":$operandNumber,
|
"::mlir::ArrayRef<unsigned>":$operandNumbers,
|
||||||
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
|
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
|
||||||
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
|
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
|
||||||
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
|
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
|
||||||
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
|
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
|
|||||||
@@ -22,8 +22,11 @@
|
|||||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||||
#include "mlir/Interfaces/TilingInterface.h"
|
#include "mlir/Interfaces/TilingInterface.h"
|
||||||
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "linalg-tiling-interface-impl"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::linalg;
|
using namespace mlir::linalg;
|
||||||
|
|
||||||
@@ -148,55 +151,82 @@ struct LinalgOpTilingInterface
|
|||||||
/// Utility to fetch the offsets and sizes when applied as per the indexing
|
/// Utility to fetch the offsets and sizes when applied as per the indexing
|
||||||
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
|
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
|
||||||
/// a given slice op.
|
/// a given slice op.
|
||||||
void
|
static LogicalResult
|
||||||
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
|
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
|
||||||
ArrayRef<OpFoldResult> offsets,
|
ArrayRef<AffineMap> indexingMaps,
|
||||||
ArrayRef<OpFoldResult> sizes,
|
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||||
SmallVectorImpl<OpFoldResult> &mappedOffsets,
|
ArrayRef<SmallVector<OpFoldResult>> allSizes,
|
||||||
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
|
SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
|
||||||
unsigned numLoops = linalgOp.getNumLoops();
|
SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
|
||||||
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
|
DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
|
||||||
mappedOffsets.resize(numLoops);
|
|
||||||
mappedSizes.resize(numLoops);
|
for (auto [indexingMap, offsets, sizes] :
|
||||||
if (!indexingMap.isPermutation()) {
|
llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
|
||||||
SmallVector<Range> iterationDomain =
|
for (auto [resultExpr, offset, size] :
|
||||||
tilingInterfaceOp.getIterationDomain(b);
|
llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
|
||||||
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
|
auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
|
||||||
mappedOffsets[index] = value.offset;
|
if (!dimExpr)
|
||||||
mappedSizes[index] = value.size;
|
continue;
|
||||||
|
unsigned position = dimExpr.getPosition();
|
||||||
|
auto it = mappedOffsets.find(position);
|
||||||
|
if (it != mappedOffsets.end()) {
|
||||||
|
OpFoldResult seenOffset = it->second;
|
||||||
|
OpFoldResult seenSize = mappedSizes.lookup(position);
|
||||||
|
if (seenOffset != offset || seenSize != size) {
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs() << "inconsistent iteration space mapping from "
|
||||||
|
"offsets/sizes of operands/results";
|
||||||
|
});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mappedOffsets[position] = offset;
|
||||||
|
mappedSizes[position] = size;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const auto &&[index, value] :
|
|
||||||
llvm::enumerate(indexingMap.getResults())) {
|
// Aggregate from the given operand offsets and sizes, or default to
|
||||||
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
|
// iteration space values.
|
||||||
mappedOffsets[dimPosition] = offsets[index];
|
SmallVector<Range> iterationDomain =
|
||||||
mappedSizes[dimPosition] = sizes[index];
|
cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
|
||||||
|
mappedOffsetsVec.resize(iterationDomain.size());
|
||||||
|
mappedSizesVec.resize(iterationDomain.size());
|
||||||
|
for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
|
||||||
|
auto it = mappedOffsets.find(index);
|
||||||
|
if (it != mappedOffsets.end()) {
|
||||||
|
mappedOffsetsVec[index] = it->second;
|
||||||
|
mappedSizesVec[index] = mappedSizes.lookup(index);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
mappedOffsetsVec[index] = domain.offset;
|
||||||
|
mappedSizesVec[index] = domain.size;
|
||||||
}
|
}
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Method to return the position of the result tile computed by the tiled
|
/// Method to return the position of the result tile computed by the tiled
|
||||||
/// operation.
|
/// operation.
|
||||||
LogicalResult getIterationDomainTileFromOperandTile(
|
LogicalResult getIterationDomainTileFromOperandTiles(
|
||||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||||
|
ArrayRef<SmallVector<OpFoldResult>> allSizes,
|
||||||
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
|
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
|
||||||
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
|
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
|
||||||
auto linalgOp = cast<LinalgOp>(op);
|
auto linalgOp = cast<LinalgOp>(op);
|
||||||
|
|
||||||
// Check that the indexing map used for the operand is a projected
|
std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
|
||||||
// permutation. This could be relaxed with a more general approach that can
|
iterationSpaceSizes;
|
||||||
// map the offsets and sizes from the operand to iteration space tiles
|
SmallVector<AffineMap> indexingMaps =
|
||||||
// (filling in full extent for dimensions not used to access the result).
|
llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
|
||||||
AffineMap indexingMap =
|
OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
|
||||||
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
|
return linalgOp.getMatchingIndexingMap(&opOperand);
|
||||||
if (!indexingMap.isProjectedPermutation()) {
|
});
|
||||||
return op->emitError()
|
if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
|
||||||
<< "unhandled get iter domain position when operand is not "
|
allSizes, iterDomainOffsets,
|
||||||
"accessed using a permuted projection";
|
iterDomainSizes))) {
|
||||||
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
|
|
||||||
iterDomainOffsets, iterDomainSizes);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,8 +277,13 @@ struct LinalgOpTilingInterface
|
|||||||
"accessed using a permuted projection");
|
"accessed using a permuted projection");
|
||||||
}
|
}
|
||||||
|
|
||||||
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
|
SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
|
||||||
iterDomainOffsets, iterDomainSizes);
|
SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
|
||||||
|
auto status =
|
||||||
|
getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
|
||||||
|
{allSizes}, iterDomainOffsets, iterDomainSizes);
|
||||||
|
(void)status;
|
||||||
|
assert(succeeded(status) && "unexpected error in offset calculation");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,12 +314,13 @@ struct LinalgOpTilingInterface
|
|||||||
|
|
||||||
/// Method to generate the tiled implementation of an operation from the tile
|
/// Method to generate the tiled implementation of an operation from the tile
|
||||||
/// of the operand.
|
/// of the operand.
|
||||||
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
|
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
|
||||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
|
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||||
|
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
|
||||||
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
|
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
|
||||||
if (failed(getIterationDomainTileFromOperandTile(
|
if (failed(getIterationDomainTileFromOperandTiles(
|
||||||
op, b, operandNumber, offsets, sizes, mappedOffsets,
|
op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
|
||||||
mappedSizes))) {
|
mappedSizes))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
@@ -837,13 +873,20 @@ struct PackOpTiling
|
|||||||
/// Method to return the position of iteration domain tile computed by the
|
/// Method to return the position of iteration domain tile computed by the
|
||||||
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
|
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
|
||||||
/// `resultSizes` only cover outer dimensions.
|
/// `resultSizes` only cover outer dimensions.
|
||||||
LogicalResult getIterationDomainTileFromOperandTile(
|
LogicalResult getIterationDomainTileFromOperandTiles(
|
||||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||||
|
ArrayRef<SmallVector<OpFoldResult>> allSizes,
|
||||||
SmallVectorImpl<OpFoldResult> &resultOffsets,
|
SmallVectorImpl<OpFoldResult> &resultOffsets,
|
||||||
SmallVectorImpl<OpFoldResult> &resultSizes) const {
|
SmallVectorImpl<OpFoldResult> &resultSizes) const {
|
||||||
if (operandNumber != 0)
|
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
|
||||||
|
LLVM_DEBUG(
|
||||||
|
{ llvm::dbgs() << "unsupported operands for consumer fusion"; });
|
||||||
return failure();
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
|
||||||
|
ArrayRef<OpFoldResult> sizes(allSizes[0]);
|
||||||
|
|
||||||
auto packOp = cast<PackOp>(op);
|
auto packOp = cast<PackOp>(op);
|
||||||
// It is not trivial to infer dest tile from source tile if `packOp` has
|
// It is not trivial to infer dest tile from source tile if `packOp` has
|
||||||
@@ -904,11 +947,18 @@ struct PackOpTiling
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Method to return the tiled implementation of tensor.pack as a consumer.
|
/// Method to return the tiled implementation of tensor.pack as a consumer.
|
||||||
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
|
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
|
||||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
|
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||||
if (operandNumber != 0)
|
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
|
||||||
|
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
|
||||||
|
LLVM_DEBUG(
|
||||||
|
{ llvm ::dbgs() << "unhandled operands for consumer fusion"; });
|
||||||
return failure();
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
|
||||||
|
ArrayRef<OpFoldResult> sizes(allSizes[0]);
|
||||||
|
|
||||||
auto packOp = cast<PackOp>(op);
|
auto packOp = cast<PackOp>(op);
|
||||||
Location loc = packOp.getLoc();
|
Location loc = packOp.getLoc();
|
||||||
@@ -923,8 +973,8 @@ struct PackOpTiling
|
|||||||
tiledOperands.push_back(sourceSlice);
|
tiledOperands.push_back(sourceSlice);
|
||||||
|
|
||||||
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
|
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
|
||||||
if (failed(getIterationDomainTileFromOperandTile(
|
if (failed(getIterationDomainTileFromOperandTiles(
|
||||||
op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
|
op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
|
||||||
outerDimSizes)))
|
outerDimSizes)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -1182,12 +1232,21 @@ struct UnPackOpTiling
|
|||||||
|
|
||||||
/// Method to return the position of iteration domain tile computed by the
|
/// Method to return the position of iteration domain tile computed by the
|
||||||
/// tiled operation.
|
/// tiled operation.
|
||||||
LogicalResult getIterationDomainTileFromOperandTile(
|
LogicalResult getIterationDomainTileFromOperandTiles(
|
||||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||||
|
ArrayRef<SmallVector<OpFoldResult>> allSizes,
|
||||||
SmallVectorImpl<OpFoldResult> &resultOffsets,
|
SmallVectorImpl<OpFoldResult> &resultOffsets,
|
||||||
SmallVectorImpl<OpFoldResult> &resultSizes) const {
|
SmallVectorImpl<OpFoldResult> &resultSizes) const {
|
||||||
|
if (operandNumbers.size() != 1) {
|
||||||
|
LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
auto unPackOp = cast<UnPackOp>(op);
|
auto unPackOp = cast<UnPackOp>(op);
|
||||||
|
unsigned operandNumber = operandNumbers[0];
|
||||||
|
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
|
||||||
|
ArrayRef<OpFoldResult> sizes(allSizes[0]);
|
||||||
|
|
||||||
// If the operand tile is the dest, then no adjustment is needed.
|
// If the operand tile is the dest, then no adjustment is needed.
|
||||||
if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
|
if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
|
||||||
resultOffsets = llvm::to_vector(offsets);
|
resultOffsets = llvm::to_vector(offsets);
|
||||||
@@ -1241,10 +1300,18 @@ struct UnPackOpTiling
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Method to return the tiled implementation of tensor.unpack as a consumer.
|
/// Method to return the tiled implementation of tensor.unpack as a consumer.
|
||||||
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
|
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
|
||||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
|
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||||
|
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
|
||||||
|
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
|
||||||
|
LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
auto unPackOp = cast<UnPackOp>(op);
|
auto unPackOp = cast<UnPackOp>(op);
|
||||||
|
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
|
||||||
|
ArrayRef<OpFoldResult> sizes(allSizes[0]);
|
||||||
|
|
||||||
// tensor.unpack op is fusible (as a consumer) only if inner dims are not
|
// tensor.unpack op is fusible (as a consumer) only if inner dims are not
|
||||||
// tiled.
|
// tiled.
|
||||||
int64_t numTiles = unPackOp.getInnerDimsPos().size();
|
int64_t numTiles = unPackOp.getInnerDimsPos().size();
|
||||||
@@ -1259,8 +1326,8 @@ struct UnPackOpTiling
|
|||||||
// Fetch offset/size for creating the slice of the dest operand of
|
// Fetch offset/size for creating the slice of the dest operand of
|
||||||
// unpack op.
|
// unpack op.
|
||||||
SmallVector<OpFoldResult> outputOffsets, outputSizes;
|
SmallVector<OpFoldResult> outputOffsets, outputSizes;
|
||||||
if (failed(getIterationDomainTileFromOperandTile(
|
if (failed(getIterationDomainTileFromOperandTiles(
|
||||||
op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
|
op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
|
||||||
outputSizes)))
|
outputSizes)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
|||||||
@@ -2047,53 +2047,119 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
|
|||||||
|
|
||||||
/// A utility to fetch an untiled consumer of
|
/// A utility to fetch an untiled consumer of
|
||||||
/// tensor.insert_slice/tensor.parallel_insert_slice.
|
/// tensor.insert_slice/tensor.parallel_insert_slice.
|
||||||
static FailureOr<OpOperand *>
|
static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
|
||||||
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
|
RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
|
||||||
MutableArrayRef<LoopLikeOpInterface> loops) {
|
MutableArrayRef<LoopLikeOpInterface> loops) {
|
||||||
assert(!loops.empty() && "unexpected empty loops");
|
assert(!loops.empty() && "unexpected empty loops");
|
||||||
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
|
assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
|
||||||
return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
|
SmallVector<OpOperand *> fusedOperands;
|
||||||
} else if (auto parallelInsertSlice =
|
for (auto sliceOp : sliceOps) {
|
||||||
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
|
FailureOr<OpOperand *> fusedOperand =
|
||||||
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
|
TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
|
||||||
} else {
|
.Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
|
||||||
return failure();
|
[&](auto op) {
|
||||||
|
return getUntiledConsumerFromSlice(rewriter, op, loops);
|
||||||
|
})
|
||||||
|
.Default([&](Operation *op) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "unhandled slice type");
|
||||||
|
});
|
||||||
|
if (failed(fusedOperand)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!fusedOperands.empty() &&
|
||||||
|
fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
fusedOperand.value()->getOwner(),
|
||||||
|
"all candidate slices must be to the same consumer");
|
||||||
|
}
|
||||||
|
fusedOperands.push_back(fusedOperand.value());
|
||||||
}
|
}
|
||||||
|
return fusedOperands;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InsertSliceOpTy>
|
||||||
|
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
|
||||||
|
InsertSliceOpTy sliceOp);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
tensor::InsertSliceOp
|
||||||
|
cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
|
||||||
|
tensor::InsertSliceOp insertSliceOp) {
|
||||||
|
return cast<tensor::InsertSliceOp>(
|
||||||
|
rewriter.clone(*insertSliceOp.getOperation()));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
|
||||||
|
RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
|
||||||
|
return rewriter.create<tensor::InsertSliceOp>(
|
||||||
|
insertSliceOp->getLoc(), insertSliceOp.getSource(),
|
||||||
|
insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
|
||||||
|
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<tensor::InsertSliceOp>
|
||||||
|
cloneAsInsertSlices(RewriterBase &rewriter,
|
||||||
|
ArrayRef<Operation *> candidateSlices) {
|
||||||
|
assert(!candidateSlices.empty() &&
|
||||||
|
"unexpected empty list of slices to clone");
|
||||||
|
SmallVector<tensor::InsertSliceOp> clonedSlices;
|
||||||
|
for (auto sliceOp : candidateSlices) {
|
||||||
|
TypeSwitch<Operation *>(sliceOp)
|
||||||
|
.Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
|
||||||
|
[&](auto op) {
|
||||||
|
auto clonedOp = cloneAsInsertSlice(rewriter, op);
|
||||||
|
clonedSlices.push_back(clonedOp);
|
||||||
|
})
|
||||||
|
.Default([&](Operation *op) {
|
||||||
|
// Assert here assuming this has already been checked.
|
||||||
|
assert(0 && "unexpected slice type while cloning as insert slice");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return clonedSlices;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation of fusing consumer of a single slice by computing the
|
/// Implementation of fusing consumer of a single slice by computing the
|
||||||
/// slice of the consumer in-place for scf loop.
|
/// slice of the consumer in-place for scf loop.
|
||||||
FailureOr<scf::SCFFuseConsumerOfSliceResult>
|
FailureOr<scf::SCFFuseConsumerOfSliceResult>
|
||||||
mlir::scf::tileAndFuseConsumerOfSlice(
|
mlir::scf::tileAndFuseConsumerOfSlices(
|
||||||
RewriterBase &rewriter, Operation *candidateSliceOp,
|
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
|
||||||
MutableArrayRef<LoopLikeOpInterface> loops) {
|
MutableArrayRef<LoopLikeOpInterface> loops) {
|
||||||
|
if (candidateSlices.empty()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
rewriter.getUnknownLoc(),
|
||||||
|
"no candidate slices provided for consumer fusion");
|
||||||
|
}
|
||||||
// Return if `loops` is empty, return an error for now. Caller is expected
|
// Return if `loops` is empty, return an error for now. Caller is expected
|
||||||
// to handle this case.
|
// to handle this case.
|
||||||
if (loops.empty()) {
|
if (loops.empty()) {
|
||||||
return candidateSliceOp->emitOpError(
|
return rewriter.notifyMatchFailure(
|
||||||
|
candidateSlices.front(),
|
||||||
"cannot call tile and fuse consumer with an empty loop nest");
|
"cannot call tile and fuse consumer with an empty loop nest");
|
||||||
}
|
}
|
||||||
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
|
|
||||||
candidateSliceOp))
|
if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
|
||||||
return failure();
|
llvm::all_of(candidateSlices,
|
||||||
|
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
candidateSlices.front(),
|
||||||
|
"candidates slices need to be all `tensor.extract_slice`s or "
|
||||||
|
"`tensor.parallel_insert_slice`s");
|
||||||
|
}
|
||||||
|
|
||||||
// 1. Get the consumer of scf.for for the result yielded by
|
// 1. Get the consumer of scf.for for the result yielded by
|
||||||
// tensor.insert_slice/parallel_insert_slice.
|
// tensor.insert_slice/parallel_insert_slice.
|
||||||
FailureOr<OpOperand *> maybeConsumerOpOperand =
|
SmallVector<OpOperand *> consumerOpOperands;
|
||||||
getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
|
Operation *consumerOp;
|
||||||
if (failed(maybeConsumerOpOperand)) {
|
{
|
||||||
return rewriter.notifyMatchFailure(candidateSliceOp,
|
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
|
||||||
"could not fetch consumer to fuse");
|
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
|
||||||
}
|
if (failed(maybeConsumerOpOperand)) {
|
||||||
OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
|
return rewriter.notifyMatchFailure(candidateSlices.front(),
|
||||||
Operation *consumerOp = consumerOpOperand->getOwner();
|
"could not fetch consumer to fuse");
|
||||||
unsigned operandNumber = consumerOpOperand->getOperandNumber();
|
}
|
||||||
unsigned resultNumber = 0;
|
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
|
||||||
if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
|
consumerOp = consumerOpOperands.front()->getOwner();
|
||||||
resultNumber = producerResult.getResultNumber();
|
|
||||||
} else {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LoopLikeOpInterface outerMostLoop = loops.front();
|
LoopLikeOpInterface outerMostLoop = loops.front();
|
||||||
@@ -2113,16 +2179,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
|||||||
if (!dstOp)
|
if (!dstOp)
|
||||||
return rewriter.notifyMatchFailure(consumerOp,
|
return rewriter.notifyMatchFailure(consumerOp,
|
||||||
"consumer op is not DPS operation");
|
"consumer op is not DPS operation");
|
||||||
SmallVector<Value> dpsInits =
|
if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
|
||||||
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
|
return dstOp.isDpsInit(opOperand);
|
||||||
if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
|
})) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
consumerOp,
|
consumerOp,
|
||||||
"consumer op taking the result of scf.for as init is not supported");
|
"consumer op taking the result of scf.for as init is not supported");
|
||||||
}
|
}
|
||||||
SmallVector<Value> newInits = dpsInits;
|
SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
|
||||||
|
|
||||||
Location loc = outerMostLoop->getLoc();
|
|
||||||
|
|
||||||
// 3. Move the whole loop structure right before firstUserOfLoop, the
|
// 3. Move the whole loop structure right before firstUserOfLoop, the
|
||||||
// dominance should be already ensured by `checkAssumptionForLoop`.
|
// dominance should be already ensured by `checkAssumptionForLoop`.
|
||||||
@@ -2137,43 +2201,52 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
|||||||
// tensor.insert_slice. In the scf.for case this is a clone of the
|
// tensor.insert_slice. In the scf.for case this is a clone of the
|
||||||
// candidateSliceOp whereas in the scf.forall case this is created from the
|
// candidateSliceOp whereas in the scf.forall case this is created from the
|
||||||
// operands of tensor.parallel_insert_slice.
|
// operands of tensor.parallel_insert_slice.
|
||||||
tensor::InsertSliceOp clonedInsertSliceOp;
|
|
||||||
if (auto sliceOp =
|
if (auto sliceOp =
|
||||||
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
|
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
|
||||||
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
|
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
|
||||||
rewriter.setInsertionPoint(newForallOp.getTerminator());
|
rewriter.setInsertionPoint(newForallOp.getTerminator());
|
||||||
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
|
|
||||||
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
|
|
||||||
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
|
|
||||||
} else {
|
} else {
|
||||||
rewriter.setInsertionPoint(candidateSliceOp);
|
rewriter.setInsertionPoint(candidateSlices.front());
|
||||||
clonedInsertSliceOp =
|
|
||||||
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
|
|
||||||
}
|
}
|
||||||
|
// 5.a. Clone all the candidate slices as equivalent insert slice ops.
|
||||||
|
SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
|
||||||
|
cloneAsInsertSlices(rewriter, candidateSlices);
|
||||||
|
|
||||||
// 5.a. Clone consumer op.
|
// 5.b. Clone consumer op.
|
||||||
auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
|
auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
|
||||||
|
SmallVector<unsigned> operandNumbers =
|
||||||
|
llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
|
||||||
|
return opOperand->getOperandNumber();
|
||||||
|
});
|
||||||
|
SmallVector<OpOperand *> clonedOpFusedOperandsList =
|
||||||
|
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
|
||||||
|
return &clonedConsumerOp->getOpOperand(operandNum);
|
||||||
|
});
|
||||||
|
|
||||||
// 5.b. Replace all uses of the loop result with the result of the cloned
|
// 5.c. Replace all uses of the loop result with the result of the cloned
|
||||||
// tensor.insert_slice.
|
// tensor.insert_slice.
|
||||||
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
|
|
||||||
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
|
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
|
||||||
operandToReplace.set(clonedInsertSliceOp.getResult());
|
for (auto [operandToReplace, clonedSliceOp] :
|
||||||
|
llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
|
||||||
|
operandToReplace->set(clonedSliceOp.getResult());
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// 6. Perform tiling of the cloned consumer and replace the operand at
|
// 6. Perform tiling of the cloned consumer and replace the operand at
|
||||||
// `operandNumber` with the source of the cloned tensor.insert_slice op.
|
// `operandNumber` with the source of the cloned tensor.insert_slice op.
|
||||||
auto ossSliceOp =
|
|
||||||
cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
|
|
||||||
FailureOr<TilingResult> tileAndFuseResult =
|
FailureOr<TilingResult> tileAndFuseResult =
|
||||||
tensor::replaceInsertSliceWithTiledConsumer(
|
tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
|
||||||
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
|
clonedOpFusedOperandsList);
|
||||||
if (failed(tileAndFuseResult)) {
|
if (failed(tileAndFuseResult)) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
|
auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
|
||||||
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
|
for (auto [operandNum, clonedSliceOp] :
|
||||||
clonedInsertSliceOp.getSource());
|
llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
|
||||||
|
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
|
||||||
|
clonedSliceOp.getSource());
|
||||||
|
}
|
||||||
|
|
||||||
// 7. Reconstruct [nested] loop with new inits.
|
// 7. Reconstruct [nested] loop with new inits.
|
||||||
YieldTiledValuesFn newYieldValuesFn =
|
YieldTiledValuesFn newYieldValuesFn =
|
||||||
@@ -2185,14 +2258,20 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
|||||||
// 8. Set inner insertPoint right before tiled consumer op.
|
// 8. Set inner insertPoint right before tiled consumer op.
|
||||||
innerRewriter.setInsertionPoint(tiledConsumerOp);
|
innerRewriter.setInsertionPoint(tiledConsumerOp);
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
|
SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
|
||||||
SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
|
for (auto candidateSliceOp : clonedInsertSlices) {
|
||||||
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
|
SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
|
||||||
|
SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
|
||||||
|
SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
|
||||||
|
|
||||||
// 9. Check all insert stride is 1.
|
// 9. Check all insert stride is 1.
|
||||||
if (!llvm::all_of(strides, isOneInteger)) {
|
if (!llvm::all_of(strides, isOneInteger)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
candidateSliceOp, "containingOp's result yield with stride");
|
candidateSliceOp, "containingOp's result yield with stride");
|
||||||
|
}
|
||||||
|
|
||||||
|
allOffsets.emplace_back(std::move(offsets));
|
||||||
|
allSizes.emplace_back(std::move(sizes));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 10. Try to get iter domain position from input position. Use
|
// 10. Try to get iter domain position from input position. Use
|
||||||
@@ -2202,8 +2281,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
|||||||
// tiledConsumerOp could lead to some chained unnecessary extra index
|
// tiledConsumerOp could lead to some chained unnecessary extra index
|
||||||
// computation.
|
// computation.
|
||||||
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
|
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
|
||||||
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
|
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
|
||||||
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
|
rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
|
||||||
iterDomainSizes))) {
|
iterDomainSizes))) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
clonedConsumerOp,
|
clonedConsumerOp,
|
||||||
@@ -2279,10 +2358,13 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
|||||||
// 16. Need to erase the old scf loop and the cloned consumer op.
|
// 16. Need to erase the old scf loop and the cloned consumer op.
|
||||||
rewriter.eraseOp(clonedConsumerOp);
|
rewriter.eraseOp(clonedConsumerOp);
|
||||||
|
|
||||||
|
SmallVector<OpOperand *> tiledAndFusedOpOperands =
|
||||||
|
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
|
||||||
|
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
|
||||||
|
});
|
||||||
return scf::SCFFuseConsumerOfSliceResult{
|
return scf::SCFFuseConsumerOfSliceResult{
|
||||||
consumerOpOperand,
|
std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
|
||||||
&(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
|
std::move(tileAndFuseResult->tiledOps)};
|
||||||
tileAndFuseResult->tiledOps};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|||||||
@@ -17,6 +17,9 @@
|
|||||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||||
#include "mlir/Interfaces/TilingInterface.h"
|
#include "mlir/Interfaces/TilingInterface.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "tensor-swap-slices"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -39,21 +42,55 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
|
|||||||
return *tiledResult;
|
return *tiledResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
|
FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
|
||||||
OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
|
OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
|
||||||
OpOperand &consumer) {
|
ArrayRef<OpOperand *> consumerOperands) {
|
||||||
auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
|
if (sliceOps.empty()) {
|
||||||
|
LLVM_DEBUG(
|
||||||
|
{ llvm::dbgs() << "expected candidate slices list to be non-empty"; });
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (sliceOps.size() != consumerOperands.size()) {
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs()
|
||||||
|
<< "expected as many operands as the number of slices passed";
|
||||||
|
});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto consumerOp =
|
||||||
|
dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
|
||||||
if (!consumerOp)
|
if (!consumerOp)
|
||||||
return failure();
|
return failure();
|
||||||
|
for (auto opOperand : consumerOperands.drop_front()) {
|
||||||
|
if (opOperand->getOwner() != consumerOp) {
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs()
|
||||||
|
<< "expected all consumer operands to be from the same operation";
|
||||||
|
});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// `TilingInterface` currently only supports strides being 1.
|
auto consumerOperandNums = llvm::map_to_vector(
|
||||||
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
|
consumerOperands, [](OpOperand *opOperand) -> unsigned {
|
||||||
return failure();
|
return opOperand->getOperandNumber();
|
||||||
|
});
|
||||||
|
SmallVector<SmallVector<OpFoldResult>> allOffsets;
|
||||||
|
SmallVector<SmallVector<OpFoldResult>> allSizes;
|
||||||
|
for (auto sliceOp : sliceOps) {
|
||||||
|
|
||||||
|
// `TilingInterface` currently only supports strides being 1.
|
||||||
|
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
|
||||||
|
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
|
||||||
|
allOffsets.emplace_back(std::move(offsets));
|
||||||
|
allSizes.emplace_back(std::move(sizes));
|
||||||
|
}
|
||||||
FailureOr<TilingResult> tiledResult =
|
FailureOr<TilingResult> tiledResult =
|
||||||
consumerOp.getTiledImplementationFromOperandTile(
|
consumerOp.getTiledImplementationFromOperandTiles(
|
||||||
builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
|
builder, consumerOperandNums, allOffsets, allSizes);
|
||||||
sliceOp.getMixedSizes());
|
|
||||||
if (failed(tiledResult))
|
if (failed(tiledResult))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
|||||||
@@ -653,6 +653,7 @@ module {
|
|||||||
%5 = affine.min #map2(%i)[%d0, %idx]
|
%5 = affine.min #map2(%i)[%d0, %idx]
|
||||||
%6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
|
%6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
|
||||||
|
// CHECK: linalg.generic
|
||||||
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
|
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
|
||||||
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
|
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
|
||||||
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
|
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
|
// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
|
||||||
|
|
||||||
#map = affine_map<(d0) -> (d0)>
|
#map = affine_map<(d0) -> (d0)>
|
||||||
module {
|
module {
|
||||||
@@ -620,3 +620,294 @@ module attributes {transform.with_named_sequence} {
|
|||||||
transform.yield
|
transform.yield
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||||
|
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
||||||
|
%loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
|
||||||
|
%tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
|
||||||
|
%arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||||
|
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
%generic:2 = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
|
||||||
|
iterator_types = ["parallel", "reduction"]}
|
||||||
|
ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||||
|
%0 = arith.mulf %b0, %b1 : f32
|
||||||
|
%1 = arith.addf %b0, %b2 : f32
|
||||||
|
linalg.yield %0, %1 : f32, f32
|
||||||
|
} -> (tensor<?xf32>, tensor<?xf32>)
|
||||||
|
scf.forall.in_parallel {
|
||||||
|
tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
|
||||||
|
tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
%empty = tensor.empty(%dim0) : tensor<?xf32>
|
||||||
|
%result = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
|
||||||
|
iterator_types = ["parallel"]}
|
||||||
|
ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||||
|
%0 = arith.addf %b0, %b1 : f32
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
} -> tensor<?xf32>
|
||||||
|
return %result : tensor<?xf32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @multi_slice_fusion1(
|
||||||
|
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
|
||||||
|
// CHECK: %[[C0:.+]] = arith.constant 0
|
||||||
|
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
|
||||||
|
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
|
||||||
|
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
|
||||||
|
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
|
||||||
|
// CHECK: %[[TILESIZE:.+]] = affine.min
|
||||||
|
// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic
|
||||||
|
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
|
||||||
|
// CHECK: %[[FUSED:.+]] = linalg.generic
|
||||||
|
// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
|
||||||
|
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
|
||||||
|
// CHECK: return %[[RESULT]]#2
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
|
||||||
|
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
|
||||||
|
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Check that when the given operand tiles are inconsistent, tiling fails.
|
||||||
|
|
||||||
|
func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||||
|
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
||||||
|
%loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
|
||||||
|
%tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
|
||||||
|
%arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||||
|
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
%generic0 = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
|
||||||
|
iterator_types = ["parallel", "reduction"]}
|
||||||
|
ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32):
|
||||||
|
%0 = arith.mulf %b0, %b1 : f32
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
} -> tensor<?xf32>
|
||||||
|
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
%generic1 = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
|
||||||
|
iterator_types = ["parallel", "reduction"]}
|
||||||
|
ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32):
|
||||||
|
%0 = arith.addf %b0, %b1 : f32
|
||||||
|
linalg.yield %0: f32
|
||||||
|
} -> tensor<?xf32>
|
||||||
|
scf.forall.in_parallel {
|
||||||
|
tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
|
||||||
|
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
%empty = tensor.empty(%dim0) : tensor<?xf32>
|
||||||
|
%result = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
|
||||||
|
iterator_types = ["parallel"]}
|
||||||
|
ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||||
|
%0 = arith.addf %b0, %b1 : f32
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
} -> tensor<?xf32>
|
||||||
|
return %result : tensor<?xf32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @multi_slice_fusion2(
|
||||||
|
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
|
||||||
|
// CHECK: %[[C0:.+]] = arith.constant 0
|
||||||
|
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
|
||||||
|
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
|
||||||
|
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
|
||||||
|
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
|
||||||
|
// CHECK: %[[TILESIZE:.+]] = affine.min
|
||||||
|
// CHECK: %[[GENERIC0:.+]] = linalg.generic
|
||||||
|
// CHECK: %[[GENERIC1:.+]] = linalg.generic
|
||||||
|
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
|
||||||
|
// CHECK: %[[FUSED:.+]] = linalg.generic
|
||||||
|
// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
|
||||||
|
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
|
||||||
|
// CHECK: return %[[RESULT]]#2
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
|
||||||
|
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
|
||||||
|
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
|
||||||
|
%arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c2 = arith.constant 2 : index
|
||||||
|
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
|
||||||
|
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
|
||||||
|
%dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
|
||||||
|
%loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
|
||||||
|
shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
|
||||||
|
%tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
|
||||||
|
%tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
|
||||||
|
%arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
|
||||||
|
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
|
||||||
|
%init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||||
|
: tensor<?x?xf32> to tensor<?x?xf32>
|
||||||
|
%generic0 = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel", "reduction"]}
|
||||||
|
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32):
|
||||||
|
%0 = arith.mulf %b0, %b1 : f32
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
} -> tensor<?x?xf32>
|
||||||
|
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
|
||||||
|
%generic1 = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
|
||||||
|
iterator_types = ["parallel", "reduction"]}
|
||||||
|
ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32):
|
||||||
|
%0 = arith.addf %b0, %b1 : f32
|
||||||
|
linalg.yield %0: f32
|
||||||
|
} -> tensor<?xf32>
|
||||||
|
scf.forall.in_parallel {
|
||||||
|
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||||
|
: tensor<?x?xf32> into tensor<?x?xf32>
|
||||||
|
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
|
||||||
|
%result = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel"]}
|
||||||
|
ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||||
|
%0 = arith.addf %b0, %b1 : f32
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
} -> tensor<?x?xf32>
|
||||||
|
return %result : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
|
||||||
|
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
|
||||||
|
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
|
||||||
|
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
|
||||||
|
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
|
||||||
|
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
|
||||||
|
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
|
||||||
|
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
|
||||||
|
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
|
||||||
|
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
|
||||||
|
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
|
||||||
|
// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
|
||||||
|
// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
|
||||||
|
// CHECK: %[[GENERIC0:.+]] = linalg.generic
|
||||||
|
// CHECK: %[[GENERIC1:.+]] = linalg.generic
|
||||||
|
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
|
||||||
|
// CHECK: %[[FUSED:.+]] = linalg.generic
|
||||||
|
// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
|
||||||
|
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
|
||||||
|
// CHECK: return %[[RESULT]]#2
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
|
||||||
|
%arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c2 = arith.constant 2 : index
|
||||||
|
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
|
||||||
|
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
|
||||||
|
%dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
|
||||||
|
%loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
|
||||||
|
shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||||
|
%tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
|
||||||
|
%tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
|
||||||
|
%arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
|
||||||
|
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
|
||||||
|
%init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||||
|
: tensor<?x?xf32> to tensor<?x?xf32>
|
||||||
|
%generic0 = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel", "reduction"]}
|
||||||
|
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32):
|
||||||
|
%0 = arith.mulf %b0, %b1 : f32
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
} -> tensor<?x?xf32>
|
||||||
|
%init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||||
|
: tensor<?x?xf32> to tensor<?x?xf32>
|
||||||
|
%generic1 = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel", "reduction"]}
|
||||||
|
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32):
|
||||||
|
%0 = arith.addf %b0, %b1 : f32
|
||||||
|
linalg.yield %0: f32
|
||||||
|
} -> tensor<?x?xf32>
|
||||||
|
scf.forall.in_parallel {
|
||||||
|
// expected-error @below {{failed to fuse consumer of slice}}
|
||||||
|
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||||
|
: tensor<?x?xf32> into tensor<?x?xf32>
|
||||||
|
tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||||
|
: tensor<?x?xf32> into tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
|
||||||
|
%result = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel"]}
|
||||||
|
ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
|
||||||
|
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||||
|
%0 = arith.addf %b0, %b1 : f32
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
} -> tensor<?x?xf32>
|
||||||
|
return %result : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
|
||||||
|
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
|
||||||
|
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,9 @@
|
|||||||
#include "mlir/IR/Dominance.h"
|
#include "mlir/IR/Dominance.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/Interfaces/TilingInterface.h"
|
#include "mlir/Interfaces/TilingInterface.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "test-tiling-interface"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "TestTilingInterfaceTransformOps.h.inc"
|
#include "TestTilingInterfaceTransformOps.h.inc"
|
||||||
@@ -168,29 +171,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
|
|||||||
|
|
||||||
/// Apply fusing of consumer transformation to all payload ops and store both
|
/// Apply fusing of consumer transformation to all payload ops and store both
|
||||||
/// the original consumer operation as well as the fused consumer operation.
|
/// the original consumer operation as well as the fused consumer operation.
|
||||||
template <typename Range>
|
|
||||||
static LogicalResult applyFuseConsumer(
|
static LogicalResult applyFuseConsumer(
|
||||||
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
|
RewriterBase &rewriter, Operation *transformOp,
|
||||||
MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
|
ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
|
||||||
TransformResults &transformResults) {
|
uint32_t numConsumerToFuse, TransformResults &transformResults) {
|
||||||
SmallVector<Operation *> originalConsumerOps;
|
SmallVector<Operation *> originalConsumerOps;
|
||||||
SmallVector<Operation *> fusedConsumerOps;
|
SmallVector<Operation *> fusedConsumerOps;
|
||||||
|
|
||||||
for (Operation *target : payloadOps) {
|
rewriter.setInsertionPoint(slices.front());
|
||||||
rewriter.setInsertionPoint(target);
|
|
||||||
|
|
||||||
while (numConsumerToFuse--) {
|
while (numConsumerToFuse--) {
|
||||||
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
|
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
|
||||||
scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
|
scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
|
||||||
|
|
||||||
if (failed(fuseConsumerResults))
|
if (failed(fuseConsumerResults))
|
||||||
return failure();
|
return slices.front()->emitOpError("failed to fuse consumer of slice");
|
||||||
|
|
||||||
// Report back the relevant handles to the transform op.
|
// Report back the relevant handles to the transform op.
|
||||||
originalConsumerOps.push_back(
|
for (OpOperand *origConsumerOperand :
|
||||||
fuseConsumerResults->origConsumerOperand->getOwner());
|
fuseConsumerResults->origConsumerOperands) {
|
||||||
fusedConsumerOps.push_back(
|
originalConsumerOps.push_back(origConsumerOperand->getOwner());
|
||||||
fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
|
}
|
||||||
|
for (OpOperand *tiledAndFusedConsumerOperand :
|
||||||
|
fuseConsumerResults->tiledAndFusedConsumerOperands) {
|
||||||
|
fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,6 +207,12 @@ DiagnosedSilenceableFailure
|
|||||||
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
|
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
|
||||||
TransformResults &transformResults,
|
TransformResults &transformResults,
|
||||||
TransformState &state) {
|
TransformState &state) {
|
||||||
|
SmallVector<Operation *> slices;
|
||||||
|
for (auto op : getTargets()) {
|
||||||
|
auto sliceOp = *state.getPayloadOps(op).begin();
|
||||||
|
slices.push_back(sliceOp);
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<LoopLikeOpInterface> loops;
|
SmallVector<LoopLikeOpInterface> loops;
|
||||||
for (auto op : llvm::reverse(getLoops())) {
|
for (auto op : llvm::reverse(getLoops())) {
|
||||||
auto loopLikeOp =
|
auto loopLikeOp =
|
||||||
@@ -212,16 +222,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
|
|||||||
}
|
}
|
||||||
loops.push_back(loopLikeOp);
|
loops.push_back(loopLikeOp);
|
||||||
}
|
}
|
||||||
LogicalResult result = applyFuseConsumer(
|
LogicalResult result =
|
||||||
rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
|
applyFuseConsumer(rewriter, getOperation(), slices, loops,
|
||||||
getNumConsumerToFuse(), transformResults);
|
getNumConsumerToFuse(), transformResults);
|
||||||
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
|
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
|
||||||
: DiagnosedSilenceableFailure::success();
|
: DiagnosedSilenceableFailure::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void transform::TestFuseConsumerOp::getEffects(
|
void transform::TestFuseConsumerOp::getEffects(
|
||||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||||
consumesHandle(getTargetMutable(), effects);
|
consumesHandle(getTargetsMutable(), effects);
|
||||||
consumesHandle(getLoopsMutable(), effects);
|
consumesHandle(getLoopsMutable(), effects);
|
||||||
producesHandle(getOperation()->getOpResults(), effects);
|
producesHandle(getOperation()->getOpResults(), effects);
|
||||||
modifiesPayload(effects);
|
modifiesPayload(effects);
|
||||||
|
|||||||
@@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
|
|||||||
}
|
}
|
||||||
|
|
||||||
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
|
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
|
||||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
[AttrSizedOperandSegments,
|
||||||
|
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||||
ReportTrackingListenerFailuresOpTrait]> {
|
ReportTrackingListenerFailuresOpTrait]> {
|
||||||
let description = [{
|
let description = [{
|
||||||
@@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TransformHandleTypeInterface:$target,
|
Variadic<TransformHandleTypeInterface>:$targets,
|
||||||
Variadic<TransformHandleTypeInterface>:$loops,
|
Variadic<TransformHandleTypeInterface>:$loops,
|
||||||
DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
|
DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
|
||||||
let results = (outs TransformHandleTypeInterface:$consumer,
|
let results = (outs TransformHandleTypeInterface:$consumer,
|
||||||
TransformHandleTypeInterface:$fused_consumer);
|
TransformHandleTypeInterface:$fused_consumer);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$target `in` `(` $loops `)`
|
$targets `in` `(` $loops `)`
|
||||||
(`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
|
(`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
|
||||||
attr-dict `:` functional-type(operands, results)
|
attr-dict `:` functional-type(operands, results)
|
||||||
}];
|
}];
|
||||||
|
|||||||
Reference in New Issue
Block a user