[mlir][TilingInterface] Handle multi operand consumer fusion. (#145193)
For consumer fusion cases of this form
```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg0 = ...) {
tensor.parallel_insert_slice ... into %arg0
tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```
the current consumer fusion that handles one slice at a time cannot fuse
the consumer into the loop, since fusing along one slice will create and
SSA violation on the other use from the `scf.forall`. The solution is to
allow consumer fusion to allow considering multiple slices at once. This
PR changes the `TilingInterface` methods related to consumer fusion,
i.e.
- `getTiledImplementationFromOperandTile`
- `getIterationDomainFromOperandTile`
to allow fusion while considering multiple operands. It is upto the
`TilingInterface` implementation to return an error if a list of tiles
of the operands cannot result in a consistent implementation of the
tiled operation.
The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes and allow cases where operand
tiles that can result in a consistent tiling implementation are handled.
---------
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
This commit is contained in:
committed by
GitHub
parent
28f6f87061
commit
c873e5f87d
@@ -313,19 +313,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
|
||||
TilingInterface consumer,
|
||||
const SCFTileAndFuseOptions &options);
|
||||
|
||||
/// Fuse the consumer of the source of `candidateSliceOp` by computing the
|
||||
/// required slice of the consumer in-place. Note that the method
|
||||
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
|
||||
/// value but does not delete the slice operation.
|
||||
/// Fuse the consumer `candidateSlices` by computing the required slice of the
|
||||
/// consumer in-place. All the entries of `candidateSlices` are expected to map
|
||||
/// to the same consumer. The method returns an error if the consumer cannot be
|
||||
/// tiled in a manner that is consistent for all the passed slices. Note that
|
||||
/// the method replaces the uses of `candidateSlices` with the tiled and fused
|
||||
/// consumer value but does not delete the slice operations.
|
||||
struct SCFFuseConsumerOfSliceResult {
|
||||
OpOperand *origConsumerOperand; // Original untiled consumer's operand.
|
||||
OpOperand
|
||||
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
|
||||
// Original untiled consumer operands.
|
||||
SmallVector<OpOperand *> origConsumerOperands;
|
||||
// Tiled and fused consumer operands.
|
||||
SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
|
||||
SmallVector<Operation *> tiledOps;
|
||||
};
|
||||
FailureOr<scf::SCFFuseConsumerOfSliceResult>
|
||||
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
|
||||
MutableArrayRef<LoopLikeOpInterface> loops);
|
||||
tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
|
||||
ArrayRef<Operation *> candidateSlices,
|
||||
MutableArrayRef<LoopLikeOpInterface> loops);
|
||||
|
||||
/// Method to lower an `op` that implements the `TilingInterface` to
|
||||
/// loops/scalars.
|
||||
|
||||
@@ -31,12 +31,16 @@ namespace tensor {
|
||||
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
|
||||
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
|
||||
|
||||
/// Method to swap an `tensor.insert_slice` with its consumer when the
|
||||
/// consumer implements the `TilingInterface`.
|
||||
/// Method to swap `tensor.insert_slice`s with their consumers when the
|
||||
/// consumer implements the `TilingInterface`. The size of `sliceOps` and
|
||||
/// `consumerOperands` is expected to be the same. Every entry in
|
||||
/// `consumerOperands` represents a use of the the corresponding
|
||||
/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` is
|
||||
/// expected to be uses in the same consumer.
|
||||
FailureOr<TilingResult>
|
||||
replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
|
||||
OffsetSizeAndStrideOpInterface sliceOp,
|
||||
OpOperand &consumerOp);
|
||||
replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
|
||||
ArrayRef<tensor::InsertSliceOp> sliceOps,
|
||||
ArrayRef<OpOperand *> consumerOperands);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Populate functions.
|
||||
|
||||
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
|
||||
using PointerUnion<Attribute, Value>::PointerUnion;
|
||||
|
||||
public:
|
||||
void dump() const { llvm::errs() << *this << "\n"; }
|
||||
LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
|
||||
|
||||
MLIRContext *getContext() const {
|
||||
PointerUnion pu = *this;
|
||||
|
||||
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Method to generate the tiled implementation of an operation that uses
|
||||
exactly a tile of the given operand.
|
||||
the exact tiles of the given operands.
|
||||
|
||||
This method is required to allow operations to be "tiled and fused"
|
||||
with an (already tiled) producer. Given a tile of the producer, this
|
||||
method generates the tile of the consumer that uses exactly this
|
||||
produced tile. In some sense it is the "reverse" of
|
||||
with an (already tiled) producer. Given tiles of the producer, this
|
||||
method generates the tile of the consumer that uses exactly these
|
||||
produced tiles. In some sense it is the "reverse" of
|
||||
`generateResultTileValue`.
|
||||
- `operandNumber` is the result of the producer used by the consumer.
|
||||
- `offsets` is the offset of the slice of the producer result used by
|
||||
the tiled implementation of the consumer.
|
||||
- `sizes` is the size of the slice of the producer result used by the
|
||||
- `operandNumbers` is the list of operands whose tiles are "producers".
|
||||
- `allOffsets` is the offset of the slice of the producer used by the
|
||||
tiled implementation of the consumer.
|
||||
- `allSizes` is the size of the slice of the producer used by the
|
||||
consumer.
|
||||
If it is illegal to fuse with a producer along the given operand for
|
||||
If it is illegal to fuse with a producer along the given operand tiles for
|
||||
an operation, the implementation should return a failure.
|
||||
}],
|
||||
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
|
||||
/*methodName=*/"getTiledImplementationFromOperandTile",
|
||||
/*methodName=*/"getTiledImplementationFromOperandTiles",
|
||||
/*args=*/(ins
|
||||
"::mlir::OpBuilder &":$b,
|
||||
"unsigned":$operandNumber,
|
||||
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
|
||||
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
|
||||
"::mlir::ArrayRef<unsigned>":$operandNumbers,
|
||||
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
|
||||
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return failure();
|
||||
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
||||
tile of the operand.
|
||||
|
||||
This method is required to allow operations to be "tiled and fused"
|
||||
with an (already tiled) producer. Given a tile of an operand,
|
||||
returns the tile of the iteration space that uses this tile.
|
||||
- `operandNumber` is the result of the producer used by the consumer.
|
||||
- `offsets` is the offset of the slice of the producer result used by
|
||||
with an (already tiled) producer. Given tiles of operands,
|
||||
returns the tile of the iteration space that uses these tiles.
|
||||
- `operandNumbers` is the list of operands whose tiles are "produced"
|
||||
by the producer(s).
|
||||
- `allOffsets` is the offset of the slice of the producers used by
|
||||
the tiled implementation of the consumer.
|
||||
- `sizes` is the size of the slice of the producer result used by the
|
||||
- `allSizes` is the size of the slice of the producers used by the
|
||||
consumer.
|
||||
If it is illegal to fuse with a producer along the given operand for
|
||||
an operation, or if this mapping cannot be computed, the
|
||||
implementation should return a failure.
|
||||
If it is illegal to fuse with the producer slices for an operation,
|
||||
or if this mapping cannot be computed, the implementation should
|
||||
return a failure.
|
||||
|
||||
Note that unlike the "tile consumer and fuse producer" case, the
|
||||
"tile producer and fuse consumer" requires an additional method to get
|
||||
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
||||
transformation. It does not provide guarantees on whether such a
|
||||
transformation is profitable.
|
||||
|
||||
For most cases `getTiledImplementationFromOperandTile` could be a
|
||||
implemented using `getIterationDomainTileFromOperandTile` +
|
||||
For most cases `getTiledImplementationFromOperandTiles` could be a
|
||||
implemented using `getIterationDomainTileFromOperandTiles` +
|
||||
`getTiledImplementation` methods.
|
||||
}],
|
||||
/*retType=*/"::llvm::LogicalResult",
|
||||
/*methodName=*/"getIterationDomainTileFromOperandTile",
|
||||
/*methodName=*/"getIterationDomainTileFromOperandTiles",
|
||||
/*args=*/(ins
|
||||
"::mlir::OpBuilder &":$b,
|
||||
"unsigned":$operandNumber,
|
||||
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
|
||||
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
|
||||
"::mlir::ArrayRef<unsigned>":$operandNumbers,
|
||||
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
|
||||
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
|
||||
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
|
||||
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
|
||||
/*methodBody=*/"",
|
||||
|
||||
@@ -22,8 +22,11 @@
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Interfaces/TilingInterface.h"
|
||||
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <optional>
|
||||
|
||||
#define DEBUG_TYPE "linalg-tiling-interface-impl"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
@@ -148,55 +151,82 @@ struct LinalgOpTilingInterface
|
||||
/// Utility to fetch the offsets and sizes when applied as per the indexing
|
||||
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
|
||||
/// a given slice op.
|
||||
void
|
||||
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
SmallVectorImpl<OpFoldResult> &mappedOffsets,
|
||||
SmallVectorImpl<OpFoldResult> &mappedSizes) const {
|
||||
unsigned numLoops = linalgOp.getNumLoops();
|
||||
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
|
||||
mappedOffsets.resize(numLoops);
|
||||
mappedSizes.resize(numLoops);
|
||||
if (!indexingMap.isPermutation()) {
|
||||
SmallVector<Range> iterationDomain =
|
||||
tilingInterfaceOp.getIterationDomain(b);
|
||||
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
|
||||
mappedOffsets[index] = value.offset;
|
||||
mappedSizes[index] = value.size;
|
||||
static LogicalResult
|
||||
getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
|
||||
ArrayRef<AffineMap> indexingMaps,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allSizes,
|
||||
SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
|
||||
SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
|
||||
DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
|
||||
|
||||
for (auto [indexingMap, offsets, sizes] :
|
||||
llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
|
||||
for (auto [resultExpr, offset, size] :
|
||||
llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
|
||||
auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
|
||||
if (!dimExpr)
|
||||
continue;
|
||||
unsigned position = dimExpr.getPosition();
|
||||
auto it = mappedOffsets.find(position);
|
||||
if (it != mappedOffsets.end()) {
|
||||
OpFoldResult seenOffset = it->second;
|
||||
OpFoldResult seenSize = mappedSizes.lookup(position);
|
||||
if (seenOffset != offset || seenSize != size) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "inconsistent iteration space mapping from "
|
||||
"offsets/sizes of operands/results";
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
mappedOffsets[position] = offset;
|
||||
mappedSizes[position] = size;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto &&[index, value] :
|
||||
llvm::enumerate(indexingMap.getResults())) {
|
||||
unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
|
||||
mappedOffsets[dimPosition] = offsets[index];
|
||||
mappedSizes[dimPosition] = sizes[index];
|
||||
|
||||
// Aggregate from the given operand offsets and sizes, or default to
|
||||
// iteration space values.
|
||||
SmallVector<Range> iterationDomain =
|
||||
cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
|
||||
mappedOffsetsVec.resize(iterationDomain.size());
|
||||
mappedSizesVec.resize(iterationDomain.size());
|
||||
for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
|
||||
auto it = mappedOffsets.find(index);
|
||||
if (it != mappedOffsets.end()) {
|
||||
mappedOffsetsVec[index] = it->second;
|
||||
mappedSizesVec[index] = mappedSizes.lookup(index);
|
||||
continue;
|
||||
}
|
||||
mappedOffsetsVec[index] = domain.offset;
|
||||
mappedSizesVec[index] = domain.size;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Method to return the position of the result tile computed by the tiled
|
||||
/// operation.
|
||||
LogicalResult getIterationDomainTileFromOperandTile(
|
||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
||||
LogicalResult getIterationDomainTileFromOperandTiles(
|
||||
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allSizes,
|
||||
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
|
||||
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
|
||||
auto linalgOp = cast<LinalgOp>(op);
|
||||
|
||||
// Check that the indexing map used for the operand is a projected
|
||||
// permutation. This could be relaxed with a more general approach that can
|
||||
// map the offsets and sizes from the operand to iteration space tiles
|
||||
// (filling in full extent for dimensions not used to access the result).
|
||||
AffineMap indexingMap =
|
||||
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
|
||||
if (!indexingMap.isProjectedPermutation()) {
|
||||
return op->emitError()
|
||||
<< "unhandled get iter domain position when operand is not "
|
||||
"accessed using a permuted projection";
|
||||
std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
|
||||
iterationSpaceSizes;
|
||||
SmallVector<AffineMap> indexingMaps =
|
||||
llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
|
||||
OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
|
||||
return linalgOp.getMatchingIndexingMap(&opOperand);
|
||||
});
|
||||
if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
|
||||
allSizes, iterDomainOffsets,
|
||||
iterDomainSizes))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
|
||||
iterDomainOffsets, iterDomainSizes);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -247,8 +277,13 @@ struct LinalgOpTilingInterface
|
||||
"accessed using a permuted projection");
|
||||
}
|
||||
|
||||
getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
|
||||
iterDomainOffsets, iterDomainSizes);
|
||||
SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
|
||||
SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
|
||||
auto status =
|
||||
getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
|
||||
{allSizes}, iterDomainOffsets, iterDomainSizes);
|
||||
(void)status;
|
||||
assert(succeeded(status) && "unexpected error in offset calculation");
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -279,12 +314,13 @@ struct LinalgOpTilingInterface
|
||||
|
||||
/// Method to generate the tiled implementation of an operation from the tile
|
||||
/// of the operand.
|
||||
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
|
||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
|
||||
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
|
||||
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
|
||||
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
|
||||
if (failed(getIterationDomainTileFromOperandTile(
|
||||
op, b, operandNumber, offsets, sizes, mappedOffsets,
|
||||
if (failed(getIterationDomainTileFromOperandTiles(
|
||||
op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
|
||||
mappedSizes))) {
|
||||
return failure();
|
||||
}
|
||||
@@ -837,13 +873,20 @@ struct PackOpTiling
|
||||
/// Method to return the position of iteration domain tile computed by the
|
||||
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
|
||||
/// `resultSizes` only cover outer dimensions.
|
||||
LogicalResult getIterationDomainTileFromOperandTile(
|
||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
||||
LogicalResult getIterationDomainTileFromOperandTiles(
|
||||
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allSizes,
|
||||
SmallVectorImpl<OpFoldResult> &resultOffsets,
|
||||
SmallVectorImpl<OpFoldResult> &resultSizes) const {
|
||||
if (operandNumber != 0)
|
||||
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
|
||||
LLVM_DEBUG(
|
||||
{ llvm::dbgs() << "unsupported operands for consumer fusion"; });
|
||||
return failure();
|
||||
}
|
||||
|
||||
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
|
||||
ArrayRef<OpFoldResult> sizes(allSizes[0]);
|
||||
|
||||
auto packOp = cast<PackOp>(op);
|
||||
// It is not trivial to infer dest tile from source tile if `packOp` has
|
||||
@@ -904,11 +947,18 @@ struct PackOpTiling
|
||||
}
|
||||
|
||||
/// Method to return the tiled implementation of tensor.pack as a consumer.
|
||||
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
|
||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
|
||||
if (operandNumber != 0)
|
||||
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
|
||||
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
|
||||
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
|
||||
LLVM_DEBUG(
|
||||
{ llvm ::dbgs() << "unhandled operands for consumer fusion"; });
|
||||
return failure();
|
||||
}
|
||||
|
||||
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
|
||||
ArrayRef<OpFoldResult> sizes(allSizes[0]);
|
||||
|
||||
auto packOp = cast<PackOp>(op);
|
||||
Location loc = packOp.getLoc();
|
||||
@@ -923,8 +973,8 @@ struct PackOpTiling
|
||||
tiledOperands.push_back(sourceSlice);
|
||||
|
||||
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
|
||||
if (failed(getIterationDomainTileFromOperandTile(
|
||||
op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
|
||||
if (failed(getIterationDomainTileFromOperandTiles(
|
||||
op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
|
||||
outerDimSizes)))
|
||||
return failure();
|
||||
|
||||
@@ -1182,12 +1232,21 @@ struct UnPackOpTiling
|
||||
|
||||
/// Method to return the position of iteration domain tile computed by the
|
||||
/// tiled operation.
|
||||
LogicalResult getIterationDomainTileFromOperandTile(
|
||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
||||
LogicalResult getIterationDomainTileFromOperandTiles(
|
||||
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allSizes,
|
||||
SmallVectorImpl<OpFoldResult> &resultOffsets,
|
||||
SmallVectorImpl<OpFoldResult> &resultSizes) const {
|
||||
if (operandNumbers.size() != 1) {
|
||||
LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
|
||||
return failure();
|
||||
}
|
||||
auto unPackOp = cast<UnPackOp>(op);
|
||||
unsigned operandNumber = operandNumbers[0];
|
||||
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
|
||||
ArrayRef<OpFoldResult> sizes(allSizes[0]);
|
||||
|
||||
// If the operand tile is the dest, then no adjustment is needed.
|
||||
if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
|
||||
resultOffsets = llvm::to_vector(offsets);
|
||||
@@ -1241,10 +1300,18 @@ struct UnPackOpTiling
|
||||
}
|
||||
|
||||
/// Method to return the tiled implementation of tensor.unpack as a consumer.
|
||||
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
|
||||
Operation *op, OpBuilder &b, unsigned operandNumber,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
|
||||
FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
|
||||
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
|
||||
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
|
||||
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
|
||||
LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
|
||||
return failure();
|
||||
}
|
||||
auto unPackOp = cast<UnPackOp>(op);
|
||||
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
|
||||
ArrayRef<OpFoldResult> sizes(allSizes[0]);
|
||||
|
||||
// tensor.unpack op is fusible (as a consumer) only if inner dims are not
|
||||
// tiled.
|
||||
int64_t numTiles = unPackOp.getInnerDimsPos().size();
|
||||
@@ -1259,8 +1326,8 @@ struct UnPackOpTiling
|
||||
// Fetch offset/size for creating the slice of the dest operand of
|
||||
// unpack op.
|
||||
SmallVector<OpFoldResult> outputOffsets, outputSizes;
|
||||
if (failed(getIterationDomainTileFromOperandTile(
|
||||
op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
|
||||
if (failed(getIterationDomainTileFromOperandTiles(
|
||||
op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
|
||||
outputSizes)))
|
||||
return failure();
|
||||
|
||||
|
||||
@@ -2047,53 +2047,119 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
|
||||
|
||||
/// A utility to fetch an untiled consumer of
|
||||
/// tensor.insert_slice/tensor.parallel_insert_slice.
|
||||
static FailureOr<OpOperand *>
|
||||
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
|
||||
MutableArrayRef<LoopLikeOpInterface> loops) {
|
||||
static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
|
||||
RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
|
||||
MutableArrayRef<LoopLikeOpInterface> loops) {
|
||||
assert(!loops.empty() && "unexpected empty loops");
|
||||
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
|
||||
return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
|
||||
} else if (auto parallelInsertSlice =
|
||||
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
|
||||
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
|
||||
} else {
|
||||
return failure();
|
||||
assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
|
||||
SmallVector<OpOperand *> fusedOperands;
|
||||
for (auto sliceOp : sliceOps) {
|
||||
FailureOr<OpOperand *> fusedOperand =
|
||||
TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
|
||||
.Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
|
||||
[&](auto op) {
|
||||
return getUntiledConsumerFromSlice(rewriter, op, loops);
|
||||
})
|
||||
.Default([&](Operation *op) {
|
||||
return rewriter.notifyMatchFailure(op, "unhandled slice type");
|
||||
});
|
||||
if (failed(fusedOperand)) {
|
||||
return failure();
|
||||
}
|
||||
if (!fusedOperands.empty() &&
|
||||
fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
fusedOperand.value()->getOwner(),
|
||||
"all candidate slices must be to the same consumer");
|
||||
}
|
||||
fusedOperands.push_back(fusedOperand.value());
|
||||
}
|
||||
return fusedOperands;
|
||||
}
|
||||
|
||||
template <typename InsertSliceOpTy>
|
||||
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
|
||||
InsertSliceOpTy sliceOp);
|
||||
|
||||
template <>
|
||||
tensor::InsertSliceOp
|
||||
cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
|
||||
tensor::InsertSliceOp insertSliceOp) {
|
||||
return cast<tensor::InsertSliceOp>(
|
||||
rewriter.clone(*insertSliceOp.getOperation()));
|
||||
}
|
||||
|
||||
template <>
|
||||
tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
|
||||
RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
|
||||
return rewriter.create<tensor::InsertSliceOp>(
|
||||
insertSliceOp->getLoc(), insertSliceOp.getSource(),
|
||||
insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
|
||||
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
|
||||
}
|
||||
|
||||
static SmallVector<tensor::InsertSliceOp>
|
||||
cloneAsInsertSlices(RewriterBase &rewriter,
|
||||
ArrayRef<Operation *> candidateSlices) {
|
||||
assert(!candidateSlices.empty() &&
|
||||
"unexpected empty list of slices to clone");
|
||||
SmallVector<tensor::InsertSliceOp> clonedSlices;
|
||||
for (auto sliceOp : candidateSlices) {
|
||||
TypeSwitch<Operation *>(sliceOp)
|
||||
.Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
|
||||
[&](auto op) {
|
||||
auto clonedOp = cloneAsInsertSlice(rewriter, op);
|
||||
clonedSlices.push_back(clonedOp);
|
||||
})
|
||||
.Default([&](Operation *op) {
|
||||
// Assert here assuming this has already been checked.
|
||||
assert(0 && "unexpected slice type while cloning as insert slice");
|
||||
});
|
||||
}
|
||||
return clonedSlices;
|
||||
}
|
||||
|
||||
/// Implementation of fusing consumer of a single slice by computing the
|
||||
/// slice of the consumer in-place for scf loop.
|
||||
FailureOr<scf::SCFFuseConsumerOfSliceResult>
|
||||
mlir::scf::tileAndFuseConsumerOfSlice(
|
||||
RewriterBase &rewriter, Operation *candidateSliceOp,
|
||||
mlir::scf::tileAndFuseConsumerOfSlices(
|
||||
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
|
||||
MutableArrayRef<LoopLikeOpInterface> loops) {
|
||||
if (candidateSlices.empty()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
rewriter.getUnknownLoc(),
|
||||
"no candidate slices provided for consumer fusion");
|
||||
}
|
||||
// Return if `loops` is empty, return an error for now. Caller is expected
|
||||
// to handle this case.
|
||||
if (loops.empty()) {
|
||||
return candidateSliceOp->emitOpError(
|
||||
return rewriter.notifyMatchFailure(
|
||||
candidateSlices.front(),
|
||||
"cannot call tile and fuse consumer with an empty loop nest");
|
||||
}
|
||||
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
|
||||
candidateSliceOp))
|
||||
return failure();
|
||||
|
||||
if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
|
||||
llvm::all_of(candidateSlices,
|
||||
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
candidateSlices.front(),
|
||||
"candidates slices need to be all `tensor.extract_slice`s or "
|
||||
"`tensor.parallel_insert_slice`s");
|
||||
}
|
||||
|
||||
// 1. Get the consumer of scf.for for the result yielded by
|
||||
// tensor.insert_slice/parallel_insert_slice.
|
||||
FailureOr<OpOperand *> maybeConsumerOpOperand =
|
||||
getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
|
||||
if (failed(maybeConsumerOpOperand)) {
|
||||
return rewriter.notifyMatchFailure(candidateSliceOp,
|
||||
"could not fetch consumer to fuse");
|
||||
}
|
||||
OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
|
||||
Operation *consumerOp = consumerOpOperand->getOwner();
|
||||
unsigned operandNumber = consumerOpOperand->getOperandNumber();
|
||||
unsigned resultNumber = 0;
|
||||
if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
|
||||
resultNumber = producerResult.getResultNumber();
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
|
||||
SmallVector<OpOperand *> consumerOpOperands;
|
||||
Operation *consumerOp;
|
||||
{
|
||||
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
|
||||
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
|
||||
if (failed(maybeConsumerOpOperand)) {
|
||||
return rewriter.notifyMatchFailure(candidateSlices.front(),
|
||||
"could not fetch consumer to fuse");
|
||||
}
|
||||
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
|
||||
consumerOp = consumerOpOperands.front()->getOwner();
|
||||
}
|
||||
|
||||
LoopLikeOpInterface outerMostLoop = loops.front();
|
||||
@@ -2113,16 +2179,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
||||
if (!dstOp)
|
||||
return rewriter.notifyMatchFailure(consumerOp,
|
||||
"consumer op is not DPS operation");
|
||||
SmallVector<Value> dpsInits =
|
||||
llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
|
||||
if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
|
||||
if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
|
||||
return dstOp.isDpsInit(opOperand);
|
||||
})) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
consumerOp,
|
||||
"consumer op taking the result of scf.for as init is not supported");
|
||||
}
|
||||
SmallVector<Value> newInits = dpsInits;
|
||||
|
||||
Location loc = outerMostLoop->getLoc();
|
||||
SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
|
||||
|
||||
// 3. Move the whole loop structure right before firstUserOfLoop, the
|
||||
// dominance should be already ensured by `checkAssumptionForLoop`.
|
||||
@@ -2137,43 +2201,52 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
||||
// tensor.insert_slice. In the scf.for case this is a clone of the
|
||||
// candidateSliceOp whereas in the scf.forall case this is created from the
|
||||
// operands of tensor.parallel_insert_slice.
|
||||
tensor::InsertSliceOp clonedInsertSliceOp;
|
||||
if (auto sliceOp =
|
||||
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
|
||||
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
|
||||
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
|
||||
rewriter.setInsertionPoint(newForallOp.getTerminator());
|
||||
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
|
||||
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
|
||||
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
|
||||
} else {
|
||||
rewriter.setInsertionPoint(candidateSliceOp);
|
||||
clonedInsertSliceOp =
|
||||
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
|
||||
rewriter.setInsertionPoint(candidateSlices.front());
|
||||
}
|
||||
// 5.a. Clone all the candidate slices as equivalent insert slice ops.
|
||||
SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
|
||||
cloneAsInsertSlices(rewriter, candidateSlices);
|
||||
|
||||
// 5.a. Clone consumer op.
|
||||
// 5.b. Clone consumer op.
|
||||
auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
|
||||
SmallVector<unsigned> operandNumbers =
|
||||
llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
|
||||
return opOperand->getOperandNumber();
|
||||
});
|
||||
SmallVector<OpOperand *> clonedOpFusedOperandsList =
|
||||
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
|
||||
return &clonedConsumerOp->getOpOperand(operandNum);
|
||||
});
|
||||
|
||||
// 5.b. Replace all uses of the loop result with the result of the cloned
|
||||
// 5.c. Replace all uses of the loop result with the result of the cloned
|
||||
// tensor.insert_slice.
|
||||
OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
|
||||
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
|
||||
operandToReplace.set(clonedInsertSliceOp.getResult());
|
||||
for (auto [operandToReplace, clonedSliceOp] :
|
||||
llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
|
||||
operandToReplace->set(clonedSliceOp.getResult());
|
||||
}
|
||||
});
|
||||
|
||||
// 6. Perform tiling of the cloned consumer and replace the operand at
|
||||
// `operandNumber` with the source of the cloned tensor.insert_slice op.
|
||||
auto ossSliceOp =
|
||||
cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
|
||||
FailureOr<TilingResult> tileAndFuseResult =
|
||||
tensor::replaceInsertSliceWithTiledConsumer(
|
||||
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
|
||||
tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
|
||||
clonedOpFusedOperandsList);
|
||||
if (failed(tileAndFuseResult)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
|
||||
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
|
||||
clonedInsertSliceOp.getSource());
|
||||
for (auto [operandNum, clonedSliceOp] :
|
||||
llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
|
||||
rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
|
||||
clonedSliceOp.getSource());
|
||||
}
|
||||
|
||||
// 7. Reconstruct [nested] loop with new inits.
|
||||
YieldTiledValuesFn newYieldValuesFn =
|
||||
@@ -2185,14 +2258,20 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
||||
// 8. Set inner insertPoint right before tiled consumer op.
|
||||
innerRewriter.setInsertionPoint(tiledConsumerOp);
|
||||
|
||||
SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
|
||||
SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
|
||||
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
|
||||
SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
|
||||
for (auto candidateSliceOp : clonedInsertSlices) {
|
||||
SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
|
||||
SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
|
||||
SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
|
||||
|
||||
// 9. Check all insert stride is 1.
|
||||
if (!llvm::all_of(strides, isOneInteger)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
candidateSliceOp, "containingOp's result yield with stride");
|
||||
// 9. Check all insert stride is 1.
|
||||
if (!llvm::all_of(strides, isOneInteger)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
candidateSliceOp, "containingOp's result yield with stride");
|
||||
}
|
||||
|
||||
allOffsets.emplace_back(std::move(offsets));
|
||||
allSizes.emplace_back(std::move(sizes));
|
||||
}
|
||||
|
||||
// 10. Try to get iter domain position from input position. Use
|
||||
@@ -2202,8 +2281,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
||||
// tiledConsumerOp could lead to some chained unnecessary extra index
|
||||
// computation.
|
||||
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
|
||||
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
|
||||
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
|
||||
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
|
||||
rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
|
||||
iterDomainSizes))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
clonedConsumerOp,
|
||||
@@ -2279,10 +2358,13 @@ mlir::scf::tileAndFuseConsumerOfSlice(
|
||||
// 16. Need to erase the old scf loop and the cloned consumer op.
|
||||
rewriter.eraseOp(clonedConsumerOp);
|
||||
|
||||
SmallVector<OpOperand *> tiledAndFusedOpOperands =
|
||||
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
|
||||
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
|
||||
});
|
||||
return scf::SCFFuseConsumerOfSliceResult{
|
||||
consumerOpOperand,
|
||||
&(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
|
||||
tileAndFuseResult->tiledOps};
|
||||
std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
|
||||
std::move(tileAndFuseResult->tiledOps)};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/Interfaces/TilingInterface.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "tensor-swap-slices"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -39,21 +42,55 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
|
||||
return *tiledResult;
|
||||
}
|
||||
|
||||
FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
|
||||
OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
|
||||
OpOperand &consumer) {
|
||||
auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
|
||||
FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
|
||||
OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
|
||||
ArrayRef<OpOperand *> consumerOperands) {
|
||||
if (sliceOps.empty()) {
|
||||
LLVM_DEBUG(
|
||||
{ llvm::dbgs() << "expected candidate slices list to be non-empty"; });
|
||||
return failure();
|
||||
}
|
||||
if (sliceOps.size() != consumerOperands.size()) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs()
|
||||
<< "expected as many operands as the number of slices passed";
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
auto consumerOp =
|
||||
dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
|
||||
if (!consumerOp)
|
||||
return failure();
|
||||
for (auto opOperand : consumerOperands.drop_front()) {
|
||||
if (opOperand->getOwner() != consumerOp) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs()
|
||||
<< "expected all consumer operands to be from the same operation";
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// `TilingInterface` currently only supports strides being 1.
|
||||
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
|
||||
return failure();
|
||||
auto consumerOperandNums = llvm::map_to_vector(
|
||||
consumerOperands, [](OpOperand *opOperand) -> unsigned {
|
||||
return opOperand->getOperandNumber();
|
||||
});
|
||||
SmallVector<SmallVector<OpFoldResult>> allOffsets;
|
||||
SmallVector<SmallVector<OpFoldResult>> allSizes;
|
||||
for (auto sliceOp : sliceOps) {
|
||||
|
||||
// `TilingInterface` currently only supports strides being 1.
|
||||
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
|
||||
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
|
||||
allOffsets.emplace_back(std::move(offsets));
|
||||
allSizes.emplace_back(std::move(sizes));
|
||||
}
|
||||
FailureOr<TilingResult> tiledResult =
|
||||
consumerOp.getTiledImplementationFromOperandTile(
|
||||
builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
|
||||
sliceOp.getMixedSizes());
|
||||
consumerOp.getTiledImplementationFromOperandTiles(
|
||||
builder, consumerOperandNums, allOffsets, allSizes);
|
||||
if (failed(tiledResult))
|
||||
return failure();
|
||||
|
||||
|
||||
@@ -653,6 +653,7 @@ module {
|
||||
%5 = affine.min #map2(%i)[%d0, %idx]
|
||||
%6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
|
||||
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
|
||||
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
|
||||
// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
|
||||
|
||||
#map = affine_map<(d0) -> (d0)>
|
||||
module {
|
||||
@@ -620,3 +620,294 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
||||
%loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
|
||||
%tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
|
||||
%arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
%generic:2 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>],
|
||||
iterator_types = ["parallel", "reduction"]}
|
||||
ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||
%0 = arith.mulf %b0, %b1 : f32
|
||||
%1 = arith.addf %b0, %b2 : f32
|
||||
linalg.yield %0, %1 : f32, f32
|
||||
} -> (tensor<?xf32>, tensor<?xf32>)
|
||||
scf.forall.in_parallel {
|
||||
tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
|
||||
tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
|
||||
}
|
||||
}
|
||||
%empty = tensor.empty(%dim0) : tensor<?xf32>
|
||||
%result = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||
%0 = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<?xf32>
|
||||
return %result : tensor<?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @multi_slice_fusion1(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
|
||||
// CHECK: %[[C0:.+]] = arith.constant 0
|
||||
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
|
||||
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
|
||||
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
|
||||
// CHECK: %[[TILESIZE:.+]] = affine.min
|
||||
// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic
|
||||
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
|
||||
// CHECK: %[[FUSED:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 :
|
||||
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
|
||||
// CHECK: return %[[RESULT]]#2
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
|
||||
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
|
||||
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that when the given operand tiles are inconsistent, tiling fails.
|
||||
|
||||
func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
||||
%loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
|
||||
%tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
|
||||
%arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
%init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
%generic0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
|
||||
iterator_types = ["parallel", "reduction"]}
|
||||
ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32):
|
||||
%0 = arith.mulf %b0, %b1 : f32
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<?xf32>
|
||||
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
%generic1 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
|
||||
iterator_types = ["parallel", "reduction"]}
|
||||
ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32):
|
||||
%0 = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %0: f32
|
||||
} -> tensor<?xf32>
|
||||
scf.forall.in_parallel {
|
||||
tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
|
||||
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
|
||||
}
|
||||
}
|
||||
%empty = tensor.empty(%dim0) : tensor<?xf32>
|
||||
%result = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||
%0 = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<?xf32>
|
||||
return %result : tensor<?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @multi_slice_fusion2(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
|
||||
// CHECK: %[[C0:.+]] = arith.constant 0
|
||||
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]])
|
||||
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) =
|
||||
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
|
||||
// CHECK: %[[TILESIZE:.+]] = affine.min
|
||||
// CHECK: %[[GENERIC0:.+]] = linalg.generic
|
||||
// CHECK: %[[GENERIC1:.+]] = linalg.generic
|
||||
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
|
||||
// CHECK: %[[FUSED:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
|
||||
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]]
|
||||
// CHECK: return %[[RESULT]]#2
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
|
||||
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
|
||||
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>,
|
||||
%arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
|
||||
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
|
||||
%dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
|
||||
%loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
|
||||
shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) {
|
||||
%tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
|
||||
%tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
|
||||
%arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
|
||||
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
|
||||
%init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||
: tensor<?x?xf32> to tensor<?x?xf32>
|
||||
%generic0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel", "reduction"]}
|
||||
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32):
|
||||
%0 = arith.mulf %b0, %b1 : f32
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
%init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32>
|
||||
%generic1 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
|
||||
iterator_types = ["parallel", "reduction"]}
|
||||
ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32):
|
||||
%0 = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %0: f32
|
||||
} -> tensor<?xf32>
|
||||
scf.forall.in_parallel {
|
||||
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||
: tensor<?x?xf32> into tensor<?x?xf32>
|
||||
tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32>
|
||||
}
|
||||
}
|
||||
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
|
||||
%result = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||
%0 = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %result : tensor<?x?xf32>
|
||||
}
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
|
||||
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
|
||||
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
// CHECK-LABEL: func @multi_slice_fusion_with_broadcast(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
|
||||
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
|
||||
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]])
|
||||
// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
|
||||
// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]])
|
||||
// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]])
|
||||
// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]])
|
||||
// CHECK: %[[GENERIC0:.+]] = linalg.generic
|
||||
// CHECK: %[[GENERIC1:.+]] = linalg.generic
|
||||
// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
|
||||
// CHECK: %[[FUSED:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
|
||||
// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]]
|
||||
// CHECK: return %[[RESULT]]#2
|
||||
|
||||
// -----
|
||||
|
||||
func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
|
||||
%arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
|
||||
%dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
|
||||
%dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
|
||||
%loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4)
|
||||
shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
|
||||
%tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3]
|
||||
%tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4]
|
||||
%arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1]
|
||||
: tensor<?x?x?xf32> to tensor<?x?x?xf32>
|
||||
%init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||
: tensor<?x?xf32> to tensor<?x?xf32>
|
||||
%generic0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel", "reduction"]}
|
||||
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32):
|
||||
%0 = arith.mulf %b0, %b1 : f32
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
%init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||
: tensor<?x?xf32> to tensor<?x?xf32>
|
||||
%generic1 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel", "reduction"]}
|
||||
ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32):
|
||||
%0 = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %0: f32
|
||||
} -> tensor<?x?xf32>
|
||||
scf.forall.in_parallel {
|
||||
// expected-error @below {{failed to fuse consumer of slice}}
|
||||
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||
: tensor<?x?xf32> into tensor<?x?xf32>
|
||||
tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
|
||||
: tensor<?x?xf32> into tensor<?x?xf32>
|
||||
}
|
||||
}
|
||||
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
|
||||
%result = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||
%0 = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %result : tensor<?x?xf32>
|
||||
}
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
|
||||
%loop = transform.structured.match ops{["scf.forall"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
%a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop)
|
||||
: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,9 @@
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/TilingInterface.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "test-tiling-interface"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "TestTilingInterfaceTransformOps.h.inc"
|
||||
@@ -168,29 +171,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
|
||||
|
||||
/// Apply fusing of consumer transformation to all payload ops and store both
|
||||
/// the original consumer operation as well as the fused consumer operation.
|
||||
template <typename Range>
|
||||
static LogicalResult applyFuseConsumer(
|
||||
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
|
||||
MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
|
||||
TransformResults &transformResults) {
|
||||
RewriterBase &rewriter, Operation *transformOp,
|
||||
ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
|
||||
uint32_t numConsumerToFuse, TransformResults &transformResults) {
|
||||
SmallVector<Operation *> originalConsumerOps;
|
||||
SmallVector<Operation *> fusedConsumerOps;
|
||||
|
||||
for (Operation *target : payloadOps) {
|
||||
rewriter.setInsertionPoint(target);
|
||||
rewriter.setInsertionPoint(slices.front());
|
||||
|
||||
while (numConsumerToFuse--) {
|
||||
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
|
||||
scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
|
||||
while (numConsumerToFuse--) {
|
||||
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
|
||||
scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
|
||||
|
||||
if (failed(fuseConsumerResults))
|
||||
return failure();
|
||||
if (failed(fuseConsumerResults))
|
||||
return slices.front()->emitOpError("failed to fuse consumer of slice");
|
||||
|
||||
// Report back the relevant handles to the transform op.
|
||||
originalConsumerOps.push_back(
|
||||
fuseConsumerResults->origConsumerOperand->getOwner());
|
||||
fusedConsumerOps.push_back(
|
||||
fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
|
||||
// Report back the relevant handles to the transform op.
|
||||
for (OpOperand *origConsumerOperand :
|
||||
fuseConsumerResults->origConsumerOperands) {
|
||||
originalConsumerOps.push_back(origConsumerOperand->getOwner());
|
||||
}
|
||||
for (OpOperand *tiledAndFusedConsumerOperand :
|
||||
fuseConsumerResults->tiledAndFusedConsumerOperands) {
|
||||
fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,6 +207,12 @@ DiagnosedSilenceableFailure
|
||||
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
|
||||
TransformResults &transformResults,
|
||||
TransformState &state) {
|
||||
SmallVector<Operation *> slices;
|
||||
for (auto op : getTargets()) {
|
||||
auto sliceOp = *state.getPayloadOps(op).begin();
|
||||
slices.push_back(sliceOp);
|
||||
}
|
||||
|
||||
SmallVector<LoopLikeOpInterface> loops;
|
||||
for (auto op : llvm::reverse(getLoops())) {
|
||||
auto loopLikeOp =
|
||||
@@ -212,16 +222,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
|
||||
}
|
||||
loops.push_back(loopLikeOp);
|
||||
}
|
||||
LogicalResult result = applyFuseConsumer(
|
||||
rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
|
||||
getNumConsumerToFuse(), transformResults);
|
||||
LogicalResult result =
|
||||
applyFuseConsumer(rewriter, getOperation(), slices, loops,
|
||||
getNumConsumerToFuse(), transformResults);
|
||||
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
|
||||
: DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void transform::TestFuseConsumerOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
consumesHandle(getTargetMutable(), effects);
|
||||
consumesHandle(getTargetsMutable(), effects);
|
||||
consumesHandle(getLoopsMutable(), effects);
|
||||
producesHandle(getOperation()->getOpResults(), effects);
|
||||
modifiesPayload(effects);
|
||||
|
||||
@@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
|
||||
}
|
||||
|
||||
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
ReportTrackingListenerFailuresOpTrait]> {
|
||||
let description = [{
|
||||
@@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TransformHandleTypeInterface:$target,
|
||||
Variadic<TransformHandleTypeInterface>:$targets,
|
||||
Variadic<TransformHandleTypeInterface>:$loops,
|
||||
DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
|
||||
let results = (outs TransformHandleTypeInterface:$consumer,
|
||||
TransformHandleTypeInterface:$fused_consumer);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$target `in` `(` $loops `)`
|
||||
$targets `in` `(` $loops `)`
|
||||
(`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
Reference in New Issue
Block a user