The implementation of these methods are legacy and they are removed in favor of using the `scf::tileUsingSCF` methods as replacements. To get the latter on par with requirements of the deprecated methods, the tiling allows one to specify the maximum number of tiles to use instead of specifying the tile sizes. When tiling to `scf.forall` this specification is used to generate the `num_threads` version of the operation. A slight deviation from previous implementation is that the deprecated method always generated the `num_threads` variant of the `scf.forall` operation. Instead now this is driven by the tiling options specified. This reduces the indexing math generated when the tile sizes are specified. **Moving from `linalg::tileToForallOp` to `scf::tileUsingSCF`** ``` OpBuilder b; TilingInterface op; ArrayRef<OpFoldResult> numThreads; ArrayAttr mapping; FailureOr<ForallTilingResult> result =linalg::tileToForallOp(b, op, numThreads, mapping); ``` can be replaced by ``` scf::SCFTilingOptions options; options.setNumThreads(numThreads); options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); options.setMapping(mapping.getValue()); /*note the difference that setMapping takes an ArrayRef<Attribute> */ FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options); ``` This generates the `numThreads` version of the `scf.forall` for the inter-tile loops, i.e. ``` ... = scf.forall (%arg0, %arg1) in (%nt0, %nt1) shared_outs(...) 
``` **Moving from `linalg::tileToForallOpUsingTileSizes` to `scf::tileUsingSCF`** ``` OpBuilder b; TilingInterface op; ArrayRef<OpFoldResult> tileSizes; ArrayAttr mapping; FailureOr<ForallTilingResult> result =linalg::tileToForallOpUsingTileSizes(b, op, tileSizes, mapping); ``` can be replaced by ``` scf::SCFTilingOptions options; options.setTileSizes(tileSizes); options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); options.setMapping(mapping.getValue()); /*note the difference that setMapping takes an ArrayRef<Attribute> */ FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options); ``` Also note that `linalg::tileToForallOpUsingTileSizes` would effectively call the `linalg::tileToForallOp` by computing the `numThreads` from the `op` and `tileSizes` and generate the `numThreads` version of the `scf.forall`. That is not the case anymore. Instead this will directly generate the `tileSizes` version of the `scf.forall` op ``` ... = scf.forall(%arg0, %arg1) = (%lb0, %lb1) to (%ub0, %ub1) step(%step0, %step1) shared_outs(...) ``` If you actually want to use the `numThreads` version, it is upto the caller to compute the `numThreads` and set `options.setNumThreads` instead of `options.setTileSizes`. Note that there is a slight difference in the num threads version and tile size version. The former requires an additional `affine.max` on the tile size to ensure non-negative tile sizes. When lowering to `numThreads` version this `affine.max` is not needed since by construction the tile sizes are non-negative. In previous implementations, the `numThreads` version generated when using the `linalg::tileToForallOpUsingTileSizes` method would avoid generating the `affine.max` operation. To get the same state, downstream users will have to additionally normalize the `scf.forall` operation. 
**Changes to `transform.structured.tile_using_forall`** The transform dialect op that called into `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes` has been modified to call `scf::tileUsingSCF`. The transform dialect op always generates the `numThreads` version of the `scf.forall` op. So when `tile_sizes` are specified for the transform dialect op, first the `tile_sizes` version of the `scf.forall` is generated by the `scf::tileUsingSCF` method, which is then further normalized to get back to the same state. So there is no functional change to `transform.structured.tile_using_forall`. It always generates the `numThreads` version of the `scf.forall` op (as it did before this change). --------- Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
1873 lines
80 KiB
C++
1873 lines
80 KiB
C++
//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the tiling using TilingInterface.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/IR/Dominance.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
|
#include "mlir/Interfaces/TilingInterface.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include <optional>
|
|
|
|
#define DEBUG_TYPE "tile-using-interface"
|
|
|
|
using namespace mlir;
|
|
|
|
/// Record a static list of tile sizes to use, by wrapping them in a
/// computation callback that ignores the op and returns the captured list.
scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction =
      [sizes = llvm::to_vector(ts)](OpBuilder &b, Operation *op) {
        return sizes;
      };
  return *this;
}
|
|
|
|
/// Record a static list of the number of threads to use, by wrapping them in
/// a computation callback that ignores the op and returns the captured list.
scf::SCFTilingOptions &
scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
  assert(!numThreadsComputationFunction && "num tiles already set");
  numThreadsComputationFunction =
      [threads = llvm::to_vector(nt)](OpBuilder &b, Operation *op) {
        return threads;
      };
  return *this;
}
|
|
|
|
/// Helper method to adjust the interchange vector to match the iteration
|
|
/// domain.
|
|
static SmallVector<int64_t>
|
|
fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
|
|
size_t iterationDomainSize) {
|
|
SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
|
|
if (filledVector.size() < iterationDomainSize) {
|
|
auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
|
|
filledVector.append(range.begin(), range.end());
|
|
}
|
|
if (filledVector.size() > iterationDomainSize)
|
|
filledVector.resize(iterationDomainSize);
|
|
return filledVector;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// tileUsingSCF implementation.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verify the tile size options are set in a consistent manner.
|
|
static LogicalResult
|
|
verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
|
|
const scf::SCFTilingOptions &options) {
|
|
// Specifying number of threads is only supported on `scf.forall` op.
|
|
if (options.numThreadsComputationFunction &&
|
|
options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
|
|
return rewriter.notifyMatchFailure(
|
|
loc, "number of threads can only by specified when loop type is "
|
|
"set to use `scf.forall`");
|
|
}
|
|
|
|
// If specified, check that the interchange vector is a permutation.
|
|
if (!options.interchangeVector.empty()) {
|
|
if (!isPermutationVector(options.interchangeVector)) {
|
|
return rewriter.notifyMatchFailure(
|
|
loc, "invalid interchange vector, not a permutation of the entire "
|
|
"iteration space");
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Method to instantiate the tile sizes and/or number of threads specified
|
|
/// by the user.
|
|
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
|
|
getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
|
|
ArrayRef<Range> iterationDomain,
|
|
const scf::SCFTilingOptions &options) {
|
|
OpFoldResult zero = rewriter.getIndexAttr(0);
|
|
SmallVector<OpFoldResult> tileSizes, numThreads;
|
|
size_t numLoops = iterationDomain.size();
|
|
|
|
// Check whether the number of tiles to use is specified.
|
|
if (options.numThreadsComputationFunction) {
|
|
numThreads = options.numThreadsComputationFunction(rewriter, op);
|
|
numThreads.resize(numLoops, zero);
|
|
|
|
// If the number of tiles is also specified, use that.
|
|
if (options.tileSizeComputationFunction) {
|
|
tileSizes = options.tileSizeComputationFunction(rewriter, op);
|
|
tileSizes.resize(numLoops, zero);
|
|
return {tileSizes, numThreads};
|
|
}
|
|
|
|
// Compute the tile sizes from the iteration domain and number
|
|
// of tiles as follows
|
|
// - niters = ceilDiv(ub - lb, step)
|
|
// - tileSize = ceilDiv(niters, numThreads)
|
|
AffineExpr s0, s1, s2;
|
|
bindSymbols(rewriter.getContext(), s0, s1, s2);
|
|
// TODO: The step here is assumed to be 1.
|
|
AffineExpr numItersExpr = (s1 - s0);
|
|
AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
|
|
tileSizes.resize(numLoops, zero);
|
|
for (auto [index, range, nt] :
|
|
llvm::enumerate(iterationDomain, numThreads)) {
|
|
if (isConstantIntValue(nt, 0))
|
|
continue;
|
|
|
|
tileSizes[index] = affine::makeComposedFoldedAffineApply(
|
|
rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
|
|
}
|
|
tileSizes.resize(numLoops, zero);
|
|
return {tileSizes, numThreads};
|
|
}
|
|
|
|
// Enforce the convention that "tiling by zero"
|
|
// skips tiling a particular dimension. This convention is significantly
|
|
// simpler to handle instead of adjusting affine maps to account for missing
|
|
// dimensions.
|
|
assert(options.tileSizeComputationFunction &&
|
|
"expected tile sizes to be specified");
|
|
tileSizes = options.tileSizeComputationFunction(rewriter, op);
|
|
tileSizes.resize(numLoops, zero);
|
|
|
|
return {tileSizes, numThreads};
|
|
}
|
|
|
|
/// Checks if any of the tiled loops are not parallel.
///
/// This only emits a warning per offending loop; it never fails. A loop is
/// considered "tiled" when its constant num-threads is > 1 (if num-threads
/// are given) or its constant tile size is > 0 (otherwise). Dynamic
/// (non-constant) values cannot be checked statically and are skipped.
static void checkSafeToTileToForall(TilingInterface op,
                                    ArrayRef<OpFoldResult> tileSizes,
                                    ArrayRef<OpFoldResult> numThreads) {
  auto iterators = op.getLoopIteratorTypes();
  assert(iterators.size() == tileSizes.size() &&
         "expected as many tile size values as number of loops");
  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
         "when specified, expected number of threads to use for each loop");

  for (auto [index, iterator, tileSize] :
       llvm::enumerate(iterators, tileSizes)) {
    // If num threads is specified, check that it is greater than one only for
    // parallel dimensions.
    if (!numThreads.empty()) {
      if (std::optional<int64_t> constNumThreads =
              getConstantIntValue(numThreads[index])) {
        if (constNumThreads.value() > 1 &&
            iterator != utils::IteratorType::parallel) {
          op.emitWarning() << "tiling is not thread safe at axis #" << index;
        }
      }
      // When num-threads are present they take precedence; do not also
      // inspect the tile size for this loop.
      continue;
    }

    if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
      if (constTileSize.value() > 0 &&
          iterator != utils::IteratorType::parallel) {
        op.emitWarning() << "tiling is not thread safe at axis #" << index;
      }
    }
  }
}
|
|
|
|
/// Check if `stride` evenly divides the trip count `size - offset`.
|
|
static bool tileDividesIterationDomain(Range loopRange) {
|
|
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
|
|
if (!offsetAsInt)
|
|
return false;
|
|
std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
|
|
if (!sizeAsInt)
|
|
return false;
|
|
std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
|
|
if (!strideAsInt)
|
|
return false;
|
|
return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
|
|
}
|
|
|
|
/// Returns the bounded tile size given the current `offset`, `loopRange` and
/// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                       Range loopRange, OpFoldResult offset,
                                       OpFoldResult tileSize) {
  std::optional<int64_t> ts = getConstantIntValue(tileSize);
  // A unit tile size can never overrun the iteration domain; no clamp needed.
  if (ts && ts.value() == 1)
    return tileSize;

  // If the tile size statically divides the constant trip count, every tile
  // is full sized, so the static tile size can be used directly.
  if (tileDividesIterationDomain(
          Range{loopRange.offset, loopRange.size, tileSize}))
    return tileSize;

  // The tile size to use (to avoid out of bounds access) is minimum of
  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
  // loop.
  AffineExpr s0, s1, d0;
  bindDims(b.getContext(), d0);
  bindSymbols(b.getContext(), s0, s1);
  // minMap: (iv)[ub, ts] -> (ub - iv, ts); taking the min over both results
  // yields the clamped size.
  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
  // Operand order matches (d0, s0, s1) = (offset, size, tileSize).
  return affine::makeComposedFoldedAffineMin(
      b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
}
|
|
|
|
/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
|
|
/// than `iterationSize`.
|
|
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
|
|
OpFoldResult numThreads,
|
|
OpFoldResult iterationSize) {
|
|
std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
|
|
std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
|
|
std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
|
|
if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
|
|
return false;
|
|
return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
|
|
}
|
|
|
|
/// Compute the `OpFoldResult`s that represents the multi-dimensional
/// `offset`s and `size`s of the tile of the iteration space that the
/// innermost loop body of the generated tiled loops corresponds to.
///
/// - `ivs` are the materialized induction variables, one per *tiled* loop
///   (untiled loops — zero tile size / zero num-threads — have no iv).
/// - When `numThreads` is non-empty the loop nest is the `num_threads`
///   style `scf.forall`, and offsets are `lb + iv * tileSize`; otherwise
///   the iv itself is the offset.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
                      ArrayRef<Range> iterationDomain,
                      ArrayRef<OpFoldResult> tileSizes,
                      ArrayRef<OpFoldResult> numThreads) {
  SmallVector<OpFoldResult> offsets, sizes;
  // Index into `ivs`; only incremented for tiled loops.
  int materializedLoopNum = 0;

  if (!numThreads.empty()) {
    AffineExpr d0, d1, s0, s1;
    AffineExpr offsetExpr, residualTileSizeExpr;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0, s1);
    // offset = lb + iv * tileSize, with (d0, d1, s0) = (lb, iv, tileSize).
    offsetExpr = d0 + d1 * s0;
    // residual = ub - (lb + nt * tileSize); zero means the domain is covered
    // exactly and the tile size needs no clamping.
    residualTileSizeExpr = s1 - (d0 + d1 * s0);

    for (auto [nt, tileSize, loopRange] :
         llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {

      // Non-tiled cases, set the offset and size to the
      // `loopRange.offset/size`.
      if (isConstantIntValue(nt, 0)) {
        offsets.push_back(loopRange.offset);
        sizes.push_back(loopRange.size);
        continue;
      }

      Value iv = ivs[materializedLoopNum++];
      OpFoldResult offset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, offsetExpr,
          ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
      OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
          rewriter, loc, residualTileSizeExpr,
          {loopRange.offset, nt, tileSize, loopRange.size});

      OpFoldResult size = tileSize;
      if (!isConstantIntValue(residualTileSize, 0)) {
        // Clamp the last (partial) tile: size = min(ub - offset, tileSize).
        OpFoldResult sizeMinusOffsetPerThread =
            affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
                                                  {offset, loopRange.size});
        size = affine::makeComposedFoldedAffineMin(
            rewriter, loc,
            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
            {sizeMinusOffsetPerThread, tileSize});
      }

      // Consider the case where the original loop was `[0, 100)`.
      // If number of threads are `7`, the tile size would be computed as
      // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
      // - `offset = 0 + 6 * 15 = 105`
      // - `tileSize = min(15, 100 - 105) = -5`
      // To avoid negative tile sizes, we need to do a further
      // `nonNegativeTileSize = affine.max(0, tileSize)`.
      // This `max` can be avoided if
      // `offset + tileSize * (numThreads - 1) < (ub - lb)`
      if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
        AffineMap maxMap =
            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
        size = affine::makeComposedFoldedAffineMax(
            rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
      }

      offsets.push_back(offset);
      sizes.push_back(size);
    }
    return {offsets, sizes};
  } else {
    // Tile-sizes version: the induction variable already is the offset; only
    // the size may need clamping for the last partial tile.
    for (auto [tileSize, loopRange] :
         llvm::zip_equal(tileSizes, iterationDomain)) {

      // Non-tiled cases, set the offset and size to the
      // `loopRange.offset/size`.
      if (isConstantIntValue(tileSize, 0)) {
        offsets.push_back(loopRange.offset);
        sizes.push_back(loopRange.size);
        continue;
      }

      Value iv = ivs[materializedLoopNum++];
      OpFoldResult offset = getAsOpFoldResult(iv);
      offsets.push_back(offset);
      OpFoldResult size =
          getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
      sizes.push_back(size);
    }
    return {offsets, sizes};
  }
}
|
|
|
|
/// Function to return the bounds of the loops to be generated.
|
|
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
|
|
SmallVector<OpFoldResult>>
|
|
getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
|
|
ArrayRef<OpFoldResult> tileSizes) {
|
|
SmallVector<OpFoldResult> lbs, ubs, steps;
|
|
for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
|
|
// No loop if the tile size is 0.
|
|
if (isConstantIntValue(tileSize, 0))
|
|
continue;
|
|
lbs.push_back(loopRange.offset);
|
|
ubs.push_back(loopRange.size);
|
|
steps.push_back(tileSize);
|
|
}
|
|
return {lbs, ubs, steps};
|
|
}
|
|
|
|
/// A function that allows returning additional yielded values during
|
|
/// `yieldTiledValuesAndReplace`.
|
|
/// - `ivs` induction variable for the loop.
|
|
/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
|
|
/// - `tiledValues` the tiled values to return. Must be of same size as
|
|
/// `newbbArgs`, each element of this array is inserted into the corresponding
|
|
/// element in `newbbArgs`.
|
|
/// - `resultOffsets` is of the same size as `tiledValues` and represents
|
|
/// the offsets to use when inserting corresponding element from `tiledValues`
|
|
/// into the element from `newBbArgs`.
|
|
/// - `resultSizes` is of the same size as `tiledValues` and represents
|
|
/// the size of the corresponding element from `tiledValues` inserted into
|
|
/// the element from `newBbArgs`.
|
|
/// In case the method needs to return `failure()` the method is expected
|
|
/// to clean up any inserted operations.
|
|
using YieldTiledValuesFn = std::function<LogicalResult(
|
|
RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
|
|
SmallVector<Value> &tiledValues,
|
|
SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
|
|
SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
|
|
|
|
/// Clones the operation and updates the destination if the operation
|
|
/// implements the `DestinationStyleOpInterface`.
|
|
static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
|
|
Operation *op,
|
|
ValueRange newDestArgs) {
|
|
Operation *clonedOp = rewriter.clone(*op);
|
|
if (newDestArgs.empty())
|
|
return clonedOp;
|
|
if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
|
|
destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
|
|
return clonedOp;
|
|
}
|
|
|
|
/// Generate the tile-loop nest using `scf.for` operation.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
/// - `destinationTensors` are the init values to use for the outer most loop.
/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
/// most
///    loop.
/// - `loops` is an in-out parameter into which the generated loops are
///    populated.
static LogicalResult generateLoopNestUsingForOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
    YieldTiledValuesFn yieldTiledValuesFn,
    SmallVector<LoopLikeOpInterface> &loops) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);

  // Materialize the bounds (zero tile sizes were already pruned).
  SmallVector<OpFoldResult> lbs, ubs, steps;
  std::tie(lbs, ubs, steps) =
      getLoopBounds(rewriter, loc, loopRanges, tileSizes);
  SmallVector<Value> lbVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
  SmallVector<Value> ubVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
  SmallVector<Value> stepVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, steps);

  // Create the nest outside-in; each new loop is created inside the body of
  // the previous one, threading the iter_args through as destinations.
  SmallVector<Value> ivs;
  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
    auto loop =
        rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
                                    [](OpBuilder &bodyBuilder, Location bodyLoc,
                                       Value iv, ValueRange /*iterArgs*/) {});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPointToEnd(loop.getBody());
    destinationTensors = loop.getRegionIterArgs();
  }

  // Generate the innermost body via the callback.
  SmallVector<Value> tiledResults;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
                                tiledResults, resultOffsets, resultSizes))) {
    return rewriter.notifyMatchFailure(
        loc, "failed to generate inner tile loop body");
  }
  if (loops.empty())
    return success();

  assert(tiledResults.size() == destinationTensors.size() &&
         "Number of results of body should be equal to number of iter args");

  // 6. Yield all the results of the tiled operation.
  SmallVector<Value> yieldedValues;
  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
        loc, tiledValue, destinationTensor, resultOffset, resultSize,
        resultStride);
    yieldedValues.push_back(insertSlice);
  }
  rewriter.create<scf::YieldOp>(loc, yieldedValues);

  // Add the scf.yield operations for all the outer loops.
  for (auto [outerLoop, innerLoop] :
       llvm::zip_equal(MutableArrayRef(loops).drop_back(),
                       MutableArrayRef(loops).drop_front())) {
    rewriter.setInsertionPointToEnd(
        cast<scf::ForOp>(outerLoop.getOperation()).getBody());
    rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
  }
  return success();
}
|
|
|
|
/// Generate the tile-loop nest using `scf.forall` operation.
|
|
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
|
|
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
|
|
/// - `destinationTensors` are the init values to use for the outer most loop.
|
|
/// - `mappingVector` is the mapping attributes to use for loop construction.
|
|
/// Can be empty.
|
|
/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
|
|
/// most
|
|
/// loop.
|
|
/// - `loops` is an in-out parameter into which the generated loops are
|
|
/// populated.
|
|
static LogicalResult generateLoopNestUsingForallOp(
|
|
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
|
|
ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
|
|
ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
|
|
YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
|
|
assert(!loopRanges.empty() && "unexpected empty loop ranges");
|
|
assert(loopRanges.size() == tileSizes.size() &&
|
|
"expected as many tile sizes as loop ranges");
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
SmallVector<OpFoldResult> offsets(loopRanges.size()),
|
|
sizes(loopRanges.size());
|
|
|
|
std::optional<ArrayAttr> mappingAttr;
|
|
if (!mappingVector.empty())
|
|
mappingAttr = rewriter.getArrayAttr(mappingVector);
|
|
|
|
scf::ForallOp forallOp;
|
|
bool useNumThreads = !numThreads.empty();
|
|
|
|
if (useNumThreads) {
|
|
// Prune the zero numthreads.
|
|
SmallVector<OpFoldResult> nonZeroNumThreads;
|
|
for (auto nt : numThreads) {
|
|
if (isConstantIntValue(nt, 0))
|
|
continue;
|
|
nonZeroNumThreads.push_back(nt);
|
|
}
|
|
forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
|
|
destinationTensors, mappingAttr);
|
|
} else {
|
|
SmallVector<OpFoldResult> lbs, ubs, steps;
|
|
std::tie(lbs, ubs, steps) =
|
|
getLoopBounds(rewriter, loc, loopRanges, tileSizes);
|
|
forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
|
|
destinationTensors, mappingAttr);
|
|
}
|
|
loops.push_back(forallOp);
|
|
|
|
rewriter.setInsertionPoint(forallOp.getTerminator());
|
|
destinationTensors = forallOp.getRegionOutArgs();
|
|
|
|
SmallVector<Value> tiledResults;
|
|
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
|
|
if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
|
|
destinationTensors, tiledResults, resultOffsets,
|
|
resultSizes)))
|
|
return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
|
|
|
|
rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
|
|
for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
|
|
llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
|
|
resultSizes)) {
|
|
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
|
|
rewriter.getIndexAttr(1));
|
|
|
|
rewriter.create<tensor::ParallelInsertSliceOp>(
|
|
loc, tiledValue, destinationTensor, resultOffset, resultSize,
|
|
resultStride);
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Generate the tile-loop nest using the loop construct specified in
/// `options`.
/// - `options`: Tiling options specified.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
/// - `numThreads` is forwarded to the `scf.forall` generator (ignored for
///   `scf.for`).
/// - `destinationTensors` are the init values to use for the outer most loop.
/// - `tiledBodyFn` is called to generated the loop body of the inner most
///    loop.
/// - `loops` is an in-out parameter into which the generated loops are
///    populated.
static LogicalResult generateLoopNest(
    RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
    ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
    ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
  // If the tile sizes are all zero, no loops are generated. Just call the
  // callback function to handle untiled case.
  if (llvm::all_of(tileSizes, isZeroIndex)) {
    SmallVector<Value> tiledResults;
    SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
    return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
                       tiledResults, resultOffsets, resultSizes);
  }
  // Dispatch on the requested loop construct.
  if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
    return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
                                      destinationTensors, tiledBodyFn, loops);
  }
  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
    return generateLoopNestUsingForallOp(
        rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
        destinationTensors, tiledBodyFn, loops);
  }
  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
|
|
|
|
/// Append the specified additional `newInitOperands` operands to the
/// loops existing `init` operands (or similar), and replace `loopOp` with
/// the new loop that has the additional init operands. The loop body of
/// this loop is moved over to the new loop. `yieldTiledValuesFn`
/// is called to get the new tiled values returned, and the offset
/// and sizes at which the tiled value is inserted into the
/// new region iter_args that correspond to the newly added init operands.
///
/// This primary template always fails; the supported loop kinds are handled
/// by the explicit specializations for `scf::ForOp` and `scf::ForallOp`
/// below.
template <typename LoopType>
FailureOr<LoopLikeOpInterface>
yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
                               ValueRange newInitOperands,
                               YieldTiledValuesFn yieldTiledValuesFn) {
  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
}
|
|
|
|
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
/// Creates a new `scf.for` with the extra inits appended, moves the old body
/// over, appends insert_slice + yield for the new values, and replaces the
/// old loop. On callback failure the new loop is erased and the original is
/// left untouched.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);

  // New loop: same bounds/step, existing inits plus the new ones appended.
  auto inits = llvm::to_vector(loopOp.getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForOp>(
      loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
      inits, [](OpBuilder &, Location, Value, ValueRange) {});

  // Move the loop body to the new op.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(yieldOp);

  // The trailing region iter_args correspond to the newly appended inits.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange newRegionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
                                newRegionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
  }

  // Extend the existing yield with insert_slices of the new tiled values.
  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    Value insert = rewriter.create<tensor::InsertSliceOp>(
        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
        resultStride);
    newYieldValues.push_back(insert);
  }

  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
  // Only the original results are RAUW'd; the new results are for the caller.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
|
|
|
|
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`.
/// Mirrors the `scf.for` specialization: a new `scf.forall` is built with the
/// extra shared outs appended, the body is moved over, and the new values are
/// written via `tensor.parallel_insert_slice` in the terminator.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  // New loop: same bounds/step/mapping, existing outs plus the new ones.
  auto inits = llvm::to_vector(loopOp.getOutputs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForallOp>(
      loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
      loopOp.getMixedStep(), inits, loopOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  // Move the region of the current block to the newly created op.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(terminator);
  // The trailing region iter_args correspond to the newly appended outs.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange regionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
                                regionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // Clean up on failure; the original loop is untouched.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp,
                                       "failed to get yielded tiled values");
  }

  // Update the terminator.
  rewriter.setInsertionPointToEnd(terminator.getBody());

  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
        resultStride);
  }

  // Only the original results are RAUW'd; the new results are for the caller.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
|
|
|
|
/// Implementation of `yieldTiledValuesAndReplaceLoop` for
|
|
/// `LoopLikeOpInterface`, that just dispatches to the implementation for each
|
|
/// supported loop type.
|
|
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
|
|
LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
|
|
ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
|
|
return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
|
|
loopLikeOp.getOperation())
|
|
.Case<scf::ForOp, scf::ForallOp>(
|
|
[&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
|
|
return yieldTiledValuesAndReplaceLoop(
|
|
loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
|
|
})
|
|
.Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
|
|
return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
|
|
});
|
|
}
|
|
|
|
/// Method to add new init values to a loop nest. Updates `loops` in-place with
|
|
/// new loops that use the `newInitValues`.
|
|
/// The outer-loops are updated to yield the new result values of the inner
|
|
/// loop. For the innermost loop, the call back `getNewYields` is invoked to get
|
|
/// the additional values to yield form the innermost loop.
|
|
static LogicalResult addInitOperandsToLoopNest(
|
|
RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
|
|
ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
|
|
SmallVector<scf::ForOp> newLoops;
|
|
if (loops.empty())
|
|
return success();
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(loops.front());
|
|
|
|
SmallVector<Value> ivs;
|
|
for (auto &loop : loops.drop_back()) {
|
|
rewriter.setInsertionPoint(loop);
|
|
|
|
// if loops.size() > 1 we assume that scf.for is used for the loops.
|
|
auto forLoop = cast<scf::ForOp>(loop.getOperation());
|
|
|
|
// Create a new loop with the new init values for this loop.
|
|
SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
|
|
newInits.append(newInitValues.begin(), newInitValues.end());
|
|
auto newLoop = rewriter.create<scf::ForOp>(
|
|
forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
|
|
forLoop.getStep(), newInits,
|
|
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
|
|
|
|
// Merge the body of the new loop with the body of the old loops.
|
|
SmallVector<Value> sourceBlockArgs;
|
|
sourceBlockArgs.push_back(newLoop.getInductionVar());
|
|
auto newRegionIterArgs = newLoop.getRegionIterArgs();
|
|
sourceBlockArgs.append(
|
|
newRegionIterArgs.begin(),
|
|
std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
|
|
rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
|
|
rewriter.replaceOp(
|
|
forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
|
|
loop = newLoop;
|
|
ivs.push_back(newLoop.getInductionVar());
|
|
newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
|
|
}
|
|
|
|
// Update the loop body of the innermost loop to get new yield values.
|
|
LoopLikeOpInterface innerMostLoop = loops.back();
|
|
FailureOr<LoopLikeOpInterface> newInnerMostLoop =
|
|
yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
|
|
getNewTiledYieldsFn);
|
|
|
|
if (failed(newInnerMostLoop))
|
|
return innerMostLoop.emitOpError("failed to return additional yields");
|
|
loops.back() = newInnerMostLoop.value();
|
|
|
|
// Make all other loops except the innermost loops yield the values returned
|
|
// by the inner loop.
|
|
for (auto [outerLoop, innerLoop] :
|
|
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
|
|
// Again assume that all the outer loops are scf.for operations.
|
|
auto outerForLoop = cast<scf::ForOp>(outerLoop);
|
|
auto outerLoopYield =
|
|
cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
|
|
SmallVector<Value> newYields =
|
|
llvm::to_vector(outerLoopYield.getOperands());
|
|
ValueRange additionalYields =
|
|
innerLoop->getResults().take_back(newInitValues.size());
|
|
newYields.append(additionalYields.begin(), additionalYields.end());
|
|
rewriter.setInsertionPoint(outerLoopYield);
|
|
rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Implementation of tiling transformation of `op` that implements the
|
|
/// `TilingInterface` using `scf.for` to iterate over the tiles.
|
|
FailureOr<scf::SCFTilingResult>
|
|
mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
|
|
const scf::SCFTilingOptions &options) {
|
|
if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
|
|
return failure();
|
|
}
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointAfter(op);
|
|
|
|
// 1. Get the range of the loops that are represented by the operation.
|
|
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
|
|
|
|
// 2. Materialize the tile sizes and/or number of threads;
|
|
SmallVector<OpFoldResult> tileSizes, numThreads;
|
|
std::tie(tileSizes, numThreads) =
|
|
getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
|
|
|
|
// Check if it is safe to tile. This is hold over from previous iterations
|
|
// of tile to for-all. Consider dropping it.
|
|
if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
|
|
checkSafeToTileToForall(op, tileSizes, numThreads);
|
|
}
|
|
|
|
// 3. If there is an interchange specified, permute the iteration domain and
|
|
// the tile sizes.
|
|
SmallVector<int64_t> interchangeVector;
|
|
if (!options.interchangeVector.empty()) {
|
|
interchangeVector = fillInterchangeVector(options.interchangeVector,
|
|
iterationDomain.size());
|
|
assert(isPermutationVector(interchangeVector) &&
|
|
"expected interchange vector to be a permutation");
|
|
|
|
applyPermutationToVector(iterationDomain, interchangeVector);
|
|
applyPermutationToVector(tileSizes, interchangeVector);
|
|
if (!numThreads.empty())
|
|
applyPermutationToVector(numThreads, interchangeVector);
|
|
}
|
|
|
|
FailureOr<TilingResult> tilingResult;
|
|
// 4. Define the lambda function used later to generate the body of the
|
|
// innermost tiled loop.
|
|
YieldTiledValuesFn innerYieldTiledValuesFn =
|
|
[&](RewriterBase &rewriter, Location loc, ValueRange ivs,
|
|
ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
|
|
SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
|
|
SmallVector<SmallVector<OpFoldResult>> &resultSizes)
|
|
-> LogicalResult {
|
|
// 4a. Compute the `offsets` and `sizes` to use for tiling.
|
|
SmallVector<OpFoldResult> offsets, sizes;
|
|
std::tie(offsets, sizes) = getTileOffsetAndSizes(
|
|
rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
|
|
|
|
// 4b. If interchange was provided, apply inverse of the interchange
|
|
// to get back the offsets/sizes in the order to be specified.
|
|
if (!interchangeVector.empty()) {
|
|
auto inversePermutation = invertPermutationVector(interchangeVector);
|
|
applyPermutationToVector(offsets, inversePermutation);
|
|
applyPermutationToVector(sizes, inversePermutation);
|
|
}
|
|
|
|
// 5. Generate the tiled implementation within the inner most loop.
|
|
|
|
// 5a. Clone the operation within the loop body.
|
|
auto clonedOp = cast<TilingInterface>(
|
|
cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
|
|
|
|
// 5b. Early return cloned op if tiling is not happening. We can not return
|
|
// the original op because it could lead to
|
|
// `rewriter.replaceOp(op, op->getResults())` and users would get crash.
|
|
if (llvm::all_of(tileSizes, isZeroIndex)) {
|
|
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
|
|
tilingResult =
|
|
TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
|
|
return success();
|
|
}
|
|
|
|
// 5c. Tile the cloned operation.
|
|
tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
|
|
if (failed(tilingResult)) {
|
|
rewriter.eraseOp(clonedOp);
|
|
return op.emitOpError("faild to tile operation");
|
|
}
|
|
|
|
// 5d. Delete the cloned operation.
|
|
rewriter.eraseOp(clonedOp);
|
|
|
|
// 5e. Compute the offsets at which the result values are to be inserted
|
|
// back into its destinations.
|
|
for (auto [index, tiledValue] :
|
|
llvm::enumerate(tilingResult->tiledValues)) {
|
|
tiledResults.push_back(tiledValue);
|
|
SmallVector<OpFoldResult> resultOffset, resultSize;
|
|
if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
|
|
resultOffset, resultSize))) {
|
|
for (auto op : tilingResult->tiledOps) {
|
|
rewriter.eraseOp(op);
|
|
}
|
|
return rewriter.notifyMatchFailure(
|
|
op, "failed to get slice of result produced");
|
|
}
|
|
resultOffsets.emplace_back(std::move(resultOffset));
|
|
resultSizes.emplace_back(std::move(resultSize));
|
|
}
|
|
|
|
return success();
|
|
};
|
|
|
|
// 6. Find the destination tensors to use for the operation.
|
|
SmallVector<Value> destinationTensors;
|
|
if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
|
|
destinationTensors))) {
|
|
return rewriter.notifyMatchFailure(op,
|
|
"unable to create destination tensors");
|
|
}
|
|
|
|
// 7. Generate the tiled loops nest using the callback defined above.
|
|
SmallVector<LoopLikeOpInterface> loops;
|
|
if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
|
|
tileSizes, numThreads, destinationTensors,
|
|
innerYieldTiledValuesFn, loops)))
|
|
return op.emitOpError("failed to generate tiling loops");
|
|
assert(succeeded(tilingResult) &&
|
|
"expected tiling result to be computed after loop generation");
|
|
|
|
// If loops are empty, the tiled op is used as the replacement for the untiled
|
|
// op.
|
|
if (loops.empty()) {
|
|
return scf::SCFTilingResult{tilingResult->tiledOps, loops,
|
|
tilingResult->tiledValues};
|
|
}
|
|
|
|
SmallVector<Value> replacements = llvm::map_to_vector(
|
|
loops.front()->getResults(), [](OpResult r) -> Value { return r; });
|
|
return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
|
|
}
|
|
|
|
/// Tile a reduction op (implementing `PartialReductionOpInterface`) by first
/// computing partial reductions over the tiled iteration space and then
/// merging the partial results.
///
/// Steps:
///   1. Pad `tileSizes` with zeros up to the iteration-domain rank and collect
///      the reduction dimensions.
///   2. Create the initial tensor(s) holding the partial-reduction identity.
///   3. Tile to partial reductions inside an `scf.for` loop nest.
///   4. Merge the partial results and replace `op` with the merged values.
FailureOr<scf::SCFReductionTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                 PartialReductionOpInterface op,
                                 ArrayRef<OpFoldResult> tileSizes) {
  Location loc = op.getLoc();
  // Ops implementing PartialReductionOpInterface are expected to implement
  // TilingInterface.
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
  // Pad the tile sizes with zero (i.e. "do not tile") up to the rank of the
  // iteration domain.
  auto tileSizesVector = llvm::to_vector(tileSizes);
  if (tileSizesVector.size() < iterationDomain.size()) {
    auto zero = b.getIndexAttr(0);
    tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
                           zero);
  }
  SmallVector<utils::IteratorType> iterators =
      tilingInterfaceOp.getLoopIteratorTypes();

  // 1. Collect the reduction dimensions. Reuse the iterator types computed
  // above instead of querying the interface a second time.
  SmallVector<int> reductionDims;
  for (auto [idx, iteratorType] : llvm::enumerate(iterators)) {
    if (iteratorType == utils::IteratorType::reduction)
      reductionDims.push_back(idx);
  }

  // 2. create the initial tensor value.
  FailureOr<SmallVector<Value>> maybeInitTensors =
      op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
                                                  reductionDims);
  if (failed(maybeInitTensors)) {
    return b.notifyMatchFailure(op, "Failed to create initial tensors.");
  }
  SmallVector<Value> &initTensors = maybeInitTensors.value();

  // 3. Define the callback to use for generating the inner most tile loop body.
  SmallVector<Operation *> parallelTiledOps;
  auto innerYieldTiledValuesFn =
      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
          ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
      -> LogicalResult {
    // 3a. Compute offsets/sizes: untiled dims (tile size 0) span the whole
    // range; tiled dims take a bounded tile starting at the induction var.
    SmallVector<OpFoldResult> offsets, sizes;
    {
      int materializedLoopNum = 0;
      for (auto [tileSize, loopRange] :
           llvm::zip_equal(tileSizesVector, iterationDomain)) {
        if (isConstantIntValue(tileSize, 0)) {
          offsets.push_back(loopRange.offset);
          sizes.push_back(loopRange.size);
          continue;
        }
        Value iv = ivs[materializedLoopNum++];
        offsets.push_back(iv);
        sizes.push_back(
            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
      }
    }

    // 4a. Clone the operation.
    {
      auto clonedOp = cast<PartialReductionOpInterface>(
          cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));

      // 4b. Tile the cloned operation.
      FailureOr<TilingResult> partialTilingResult =
          clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
                                          sizes, reductionDims);
      if (failed(partialTilingResult)) {
        return failure();
      }
      std::swap(parallelTiledOps, partialTilingResult->tiledOps);
      std::swap(tiledResult, partialTilingResult->tiledValues);

      // 4c. Delete the cloned operation.
      b.eraseOp(clonedOp);
    }

    // 4d. Compute the offsets and sizes needed to insert the result of the
    // tiled value back into destination before yielding the destination.
    for (auto result : tiledResult) {
      // Partial results are inserted at offset zero along every dimension.
      SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
      resultOffsets.emplace_back(std::move(outOffsets));

      SmallVector<OpFoldResult> outSizes;
      for (size_t i = 0; i < offsets.size(); i++) {
        outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
      }
      resultSizes.emplace_back(std::move(outSizes));
    }
    return success();
  };

  // 5. Generate the tiled implementation using the destination tensors.
  SmallVector<LoopLikeOpInterface> loops;
  scf::SCFTilingOptions options;
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
                              /*numThreads=*/ArrayRef<OpFoldResult>{},
                              initTensors, innerYieldTiledValuesFn, loops)))
    return b.notifyMatchFailure(op, "failed to tile for parallel reduction");

  SmallVector<Value> replacements = llvm::map_to_vector(
      loops.front()->getResults(), [](OpResult r) -> Value { return r; });

  // 6. Apply the merge reduction to combine all the partial values.
  b.setInsertionPointAfter(*loops.begin());
  FailureOr<MergeResult> mergeResult =
      op.mergeReductions(b, loc, replacements, reductionDims);
  if (failed(mergeResult)) {
    return failure();
  }
  b.replaceOp(op, mergeResult->replacements);

  SCFReductionTilingResult reductionTilingResult;
  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
  std::swap(reductionTilingResult.initialValues, initTensors);
  std::swap(reductionTilingResult.loops, loops);
  std::swap(reductionTilingResult.replacements, mergeResult->replacements);

  return reductionTilingResult;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// tileConsumerAndFuseProducersUsingSCF implementation.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return the untiled producer whose slice is used in a tiled consumer. The
|
|
/// method traverses the tile loop nest (`loops`) if needed, and returns the
|
|
/// `iter_args` of the outer most that is encountered. Traversing the iter_args
|
|
/// indicates that this is a destination operand of the consumer. If there was
|
|
/// no loop traversal needed, the second value of the returned tuple is empty.
|
|
static std::tuple<OpResult, std::optional<OpOperand *>>
|
|
getUntiledProducerFromSliceSource(OpOperand *source,
|
|
ArrayRef<LoopLikeOpInterface> loops) {
|
|
std::optional<OpOperand *> destinationIterArg;
|
|
auto loopIt = loops.rbegin();
|
|
while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
|
|
auto loop = *loopIt;
|
|
if (iterArg.getOwner()->getParentOp() != loop)
|
|
break;
|
|
source = loop.getTiedLoopInit(iterArg);
|
|
loopIt++;
|
|
}
|
|
if (loopIt == loops.rend())
|
|
destinationIterArg = source;
|
|
return {dyn_cast<OpResult>(source->get()), destinationIterArg};
|
|
}
|
|
|
|
/// Implementation of fusing producer of a single slice by computing the
/// slice of the producer in-place.
///
/// Given `candidateSliceOp`, a `tensor.extract_slice` whose source is (or
/// traces back through loop iter_args to) the result of an op implementing
/// `TilingInterface`, clone that producer, tile it to the slice, and replace
/// all uses of the slice with the tiled value. Returns the original producer
/// result, the tiled value, and the tiled ops; `std::nullopt` on failure.
/// `loops` may be updated in-place to fix up destination-style init operands.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(
    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
    MutableArrayRef<LoopLikeOpInterface> loops) {
  // 1. Get the producer of the source (potentially walking through
  // `iter_args` of nested `scf.for`)
  auto [fusableProducer, destinationInitArg] =
      getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                        loops);
  if (!fusableProducer)
    return std::nullopt;
  unsigned resultNumber = fusableProducer.getResultNumber();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(candidateSliceOp);

  // 2. Clone the fused producer
  // 2a. Compute the destination operands to use for the cloned operation.
  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
  Operation *fusableProducerOp = fusableProducer.getOwner();
  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
      failed(tensor::getOrCreateDestinations(
          rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
          origDestinationTensors)))
    return std::nullopt;

  clonedOpDestinationTensors = origDestinationTensors;
  if (destinationInitArg &&
      isa<DestinationStyleOpInterface>(fusableProducerOp)) {
    // 2b. If the producer is also destination style, then to maintain the
    // destination passing style, update the destination of the producer to be
    // the source of the slice.
    clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
  }
  // 2c. Clone the fused producer.
  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
      rewriter, fusableProducerOp, clonedOpDestinationTensors);
  // 2d. Update the source of the candidateSlice to be the cloned producer.
  // Easier to just clone the slice with a different source, since replacements
  // and DCE of cloned ops become easier.
  SmallVector<Value> candidateSliceOpOperands =
      llvm::to_vector(candidateSliceOp->getOperands());
  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
  tensor::ExtractSliceOp clonedCandidateSliceOp =
      mlir::clone(rewriter, candidateSliceOp,
                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);

  // 3. Generate the tiled implementation of the producer of the source
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceExtractSliceWithTiledProducer(
          rewriter, clonedCandidateSliceOp,
          clonedProducerOp->getResult(resultNumber));
  if (failed(tileAndFuseResult))
    return std::nullopt;
  // Note: Do not delete the candidateSliceOp, since it's passed in from the
  // caller.
  rewriter.replaceAllUsesWith(candidateSliceOp,
                              tileAndFuseResult->tiledValues[0]);
  // The cloned slice and cloned producer were only scaffolding; remove them.
  rewriter.eraseOp(clonedCandidateSliceOp);
  rewriter.eraseOp(clonedProducerOp);

  // 3. If the slice is for a destination operand, for example,
  //
  // ```mlir
  // %0 = linalg.init
  // %1 = linalg.fill .. outs(%0 : )
  // %2 = scf.for .. iter_args(%arg0 = %1) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %arg1 [..]
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  //
  // the IR is currently
  //
  // ```
  // %0 = linalg.init
  // %1 = linalg.fill
  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %arg1[..]
  //     %5 = linalg.fill .. outs(%4 : )
  //     .. = linalg.matmul .. outs(%5 : )
  //   }
  // }
  // ```
  //
  // The untiled `linalg.fill` is still used as the `init_value` since it
  // was originally a destination operand of the untiled `linalg.matmul`.
  // When fusing an operand that is a destination operand, the iter_arg of
  // the outer most loop should be changed to use the destination of the
  // fused operation. With this the IR will be.
  //
  // ```
  // %0 = linalg.init
  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
  //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %3 = tensor.extract_slice %arg1[..]
  //     %4 = linalg.fill .. outs(%3 : )
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  if (destinationInitArg &&
      isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
    loops.front()
        ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
        .set(origDestinationTensors[resultNumber]);
  }
  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
                                           tileAndFuseResult->tiledValues[0],
                                           tileAndFuseResult->tiledOps};
}
|
|
|
|
/// Reconstruct the fused producer from within the tiled-and-fused code.
///
/// After a producer has been fused into the tiled loop nest, its original
/// (untiled) results may still have other users outside the loops. This
/// method appends new init operands / yields to `loops` so that the fused
/// producer's results are also reconstructed and returned by the loop nest.
/// `yieldResultNumber` optionally restricts which results of the original
/// producer are yielded; when empty, all results are yielded.
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
    MutableArrayRef<LoopLikeOpInterface> loops,
    ArrayRef<unsigned> yieldResultNumber) {
  // Nothing to reconstruct if there is no loop nest.
  if (loops.empty())
    return success();

  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
            *tiledOwner = fusedProducerInfo.tiledOps[0];

  Location loc = originalOwner->getLoc();
  // a. collect all init Values to be appended
  SmallVector<unsigned> initNumberList =
      yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
                                      0, originalOwner->getNumResults()))
                                : llvm::to_vector(yieldResultNumber);
  SmallVector<Value> initValueList;
  for (const auto &resultNumber : initNumberList) {
    FailureOr<Value> initValue = tensor::getOrCreateDestination(
        rewriter, loc, originalOwner->getResult(resultNumber));
    if (succeeded(initValue)) {
      initValueList.push_back(initValue.value());
    } else {
      return failure();
    }
  }

  // Callback invoked by `addInitOperandsToLoopNest` inside the innermost
  // loop; it computes the tiled values and their insertion offsets/sizes.
  YieldTiledValuesFn newYieldValuesFn =
      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
    OpBuilder::InsertionGuard g(innerRewriter);

    // get sliceOp tile information
    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
                              sliceSizes = sliceOp.getMixedSizes();

    // expect all strides of sliceOp to be 1
    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
          return !isConstantIntValue(ofr, 1);
        }))
      return failure();

    unsigned sliceResultNumber =
        fusedProducerInfo.origProducer.getResultNumber();

    auto tilableOp = cast<TilingInterface>(originalOwner);
    // b. get iterDomain Offset and Sizes based on sliceOp tile
    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
    // skip tensor.pack/unpack/pad, which expects single opResult
    if (tilableOp->getNumResults() > 1 &&
        failed(tilableOp.getIterationDomainTileFromResultTile(
            rewriter, sliceResultNumber, sliceOffset, sliceSizes,
            iterDomainOffset, iterDomainSizes))) {
      // In theory, it is unnecessary to raise an error here. Although it
      // fails to reconstruct the result tensor, it should not break the
      // current fusion anyway. The reason why we must return failure currently
      // is that the callback function `newYieldValuesFn` will be called after
      // new init operand(s) have already been appended. It will take more
      // refactoring to make sure the init operands are added consistently in
      // the future. For more details, please refer to:
      // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
      return failure();
    }

    // c. calculate offsets and sizes info of all OpResults respectively based
    // on the iteration Domain Tile
    SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
    for (const auto &resultNumber : initNumberList) {
      if (resultNumber == sliceResultNumber) {
        // The sliced result reuses the slice's own offsets/sizes.
        offsetList.push_back(sliceOffset);
        sizesList.push_back(sliceSizes);
      } else {
        assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
        // infer result tile according to the iteration domain tile
        SmallVector<OpFoldResult> offset, sizes;
        if (failed(tilableOp.getResultTilePosition(
                rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
                offset, sizes))) {
          return failure();
        }
        offsetList.push_back(offset);
        sizesList.push_back(sizes);
      }
    }

    // d. create `extract_slice` for `iter_args` for DPS operation if necessary
    if (auto tiledDestStyleOp =
            dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
      rewriter.setInsertionPoint(tiledDestStyleOp);
      for (const auto &&[index, newRegionArg] :
           llvm::enumerate(newRegionIterArgs)) {
        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
            loc, newRegionArg, offsetList[index], sizesList[index],
            SmallVector<OpFoldResult>(offsetList[index].size(),
                                      rewriter.getIndexAttr(1)));
        unsigned resultNumber = initNumberList[index];
        // Redirect the tiled op's init to the slice of the new iter_arg.
        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
        });
      }
    }

    // e. prepare tiled offset and sizes for later `insert_slice` creation by
    // the caller
    Block *block = rewriter.getInsertionPoint()->getBlock();
    rewriter.setInsertionPoint(block->getTerminator());
    for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
      tiledResult.push_back(tiledOwner->getResult(resultNumber));
      tiledOffset.emplace_back(offsetList[index]);
      tiledSizes.emplace_back(sizesList[index]);
    }
    return success();
  };

  return addInitOperandsToLoopNest(rewriter, loops, initValueList,
                                   newYieldValuesFn);
}
|
|
|
|
/// Implementation of tile consumer and fuse producer greedily.
|
|
FailureOr<scf::SCFTileAndFuseResult>
|
|
mlir::scf::tileConsumerAndFuseProducersUsingSCF(
|
|
RewriterBase &rewriter, TilingInterface consumer,
|
|
const scf::SCFTileAndFuseOptions &options) {
|
|
// This transformation is only valid for ops that return values (i.e. not
|
|
// valid to use with operations that have memref operands).
|
|
if (!consumer->getNumResults()) {
|
|
return rewriter.notifyMatchFailure(
|
|
consumer, "invalid pattern for op with no results");
|
|
}
|
|
|
|
// 1. First tile the consumer.
|
|
SetVector<Operation *> fusedProducers, tiledAndFusedOps;
|
|
llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
|
|
|
|
FailureOr<scf::SCFTilingResult> tilingResult =
|
|
tileUsingSCF(rewriter, consumer, options.tilingOptions);
|
|
|
|
if (failed(tilingResult))
|
|
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
|
|
for (auto *tiledOp : tilingResult->tiledOps)
|
|
tiledAndFusedOps.insert(tiledOp);
|
|
|
|
// If there are no loops generated, fusion is immaterial.
|
|
auto &loops = tilingResult->loops;
|
|
if (loops.empty()) {
|
|
DenseMap<Value, Value> replacements;
|
|
for (auto [origVal, replacement] :
|
|
llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
|
|
replacements[origVal] = replacement;
|
|
}
|
|
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
|
|
replacements};
|
|
}
|
|
|
|
// To keep track of replacements for now just record the map from the original
|
|
// untiled value to the result number of the for loop. Since the loop gets
|
|
// potentially replaced during fusion, keeping the value directly wont work.
|
|
DenseMap<Value, size_t> origValToResultNumber;
|
|
for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
|
|
origValToResultNumber[result] = index;
|
|
}
|
|
|
|
// 2. Typically, the operands of the tiled operation are slices of the
|
|
// operands of the untiled operation. These are expressed in IR using
|
|
// `tensor.extract_slice` operations with source being the operands of the
|
|
// untiled operation. Create a worklist of these `tensor.extract_slice`
|
|
// operations. If the producers of the source of the `tensor.extract_slice`
|
|
// can be tiled such that the tiled value is generated in-place, that
|
|
// effectively tiles + fuses the operations.
|
|
auto addCandidateSlices = [](Operation *fusedOp,
|
|
std::deque<tensor::ExtractSliceOp> &candidates) {
|
|
for (Value operand : fusedOp->getOperands())
|
|
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
|
|
candidates.push_back(sliceOp);
|
|
};
|
|
|
|
std::deque<tensor::ExtractSliceOp> candidates;
|
|
addCandidateSlices(tiledAndFusedOps.back(), candidates);
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
while (!candidates.empty()) {
|
|
// Traverse the slices in BFS fashion.
|
|
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
|
|
candidates.pop_front();
|
|
|
|
// Find the original producer of the slice.
|
|
auto [fusableProducer, destinationInitArg] =
|
|
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
|
|
loops);
|
|
if (!fusableProducer)
|
|
continue;
|
|
|
|
auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
|
|
candidateSliceOp, fusableProducer, destinationInitArg.has_value());
|
|
if (!fuseSlice)
|
|
continue;
|
|
|
|
// The operands of the fused producer might themselved be slices of
|
|
// values produced by operations that implement the `TilingInterface`.
|
|
// Add these operations to the worklist.
|
|
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
|
|
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
|
|
if (!fusedResult)
|
|
continue;
|
|
|
|
if (yieldReplacement) {
|
|
// Reconstruct and yield all opResult of fusableProducerOp by default. The
|
|
// caller can specific which one to yield by designating optional argument
|
|
// named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
|
|
Operation *fusableProducerOp = fusableProducer.getOwner();
|
|
if (failed(yieldReplacementForFusedProducer(
|
|
rewriter, candidateSliceOp, fusedResult.value(), loops))) {
|
|
return rewriter.notifyMatchFailure(
|
|
fusableProducerOp, "failed to replacement value for this "
|
|
"operation from within the tiled loop");
|
|
}
|
|
for (auto [index, result] :
|
|
llvm::enumerate(fusableProducerOp->getResults())) {
|
|
origValToResultNumber[result] = loops.front()->getNumResults() -
|
|
fusableProducerOp->getNumResults() +
|
|
index;
|
|
}
|
|
}
|
|
|
|
if (Operation *tiledAndFusedOp =
|
|
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
|
|
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
|
|
tiledAndFusedOps.insert(tiledAndFusedOp);
|
|
addCandidateSlices(tiledAndFusedOp, candidates);
|
|
}
|
|
}
|
|
|
|
DenseMap<Value, Value> replacements;
|
|
for (auto [origVal, resultNumber] : origValToResultNumber) {
|
|
replacements[origVal] = loops.front()->getResult(resultNumber);
|
|
}
|
|
|
|
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
|
|
replacements};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// tileAndFuseConsumerUsingSCF implementation.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// A utility function that checks whether the only use of the result of a
|
|
/// tensor.insert_slice op is in a scf.yield op.
|
|
static LogicalResult
|
|
checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
|
|
Value result = candidateSliceOp.getResult();
|
|
Value::use_range uses = result.getUses();
|
|
if (!llvm::hasSingleElement(uses)) {
|
|
LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
|
|
return failure();
|
|
}
|
|
OpOperand &operandUse = (*uses.begin());
|
|
Operation *userOp = operandUse.getOwner();
|
|
if (!isa<scf::YieldOp>(userOp)) {
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "Expected scf.yield to be the only user, but got -> "
|
|
<< (*userOp));
|
|
return failure();
|
|
}
|
|
if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
|
|
LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
|
|
"be in the same block\n");
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Fetches the OpOperand of the only user (and use) of the value `val` which
|
|
/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
|
|
/// failure otherwise.
|
|
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
|
|
Block *containingOpBlock) {
|
|
// Step 1. Check that the value has exactly one use.
|
|
if (!llvm::hasSingleElement(val.getUses()))
|
|
return failure();
|
|
// Step 2. Get uses.
|
|
OpOperand &operand = (*val.getUses().begin());
|
|
Operation *consumerOp = operand.getOwner();
|
|
// TODO: We have to init result of consumer before scf.for, use
|
|
// DestinationStyleOpInterface to get result shape from init for now.
|
|
// Add support for other op such as op has InferTypeOpInterface.
|
|
if (!isa<TilingInterface>(consumerOp) ||
|
|
!isa<DestinationStyleOpInterface>(consumerOp))
|
|
return failure();
|
|
if (containingOpBlock != consumerOp->getBlock())
|
|
return failure();
|
|
return &operand;
|
|
}
|
|
|
|
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
|
|
/// tensor.insert_slice. This function makes the following assumptions :
|
|
/// 1. tensor.insert_slice has scf.yield as its only user.
|
|
/// 2. scf.for's corresponding result has only one use.
|
|
static FailureOr<OpOperand *>
|
|
getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
|
|
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
|
|
return failure();
|
|
Value sliceResult = candidateSliceOp.getResult();
|
|
// Step 1. Fetch the corresponding output.
|
|
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
|
|
unsigned resultNumber = yieldOpOperand.getOperandNumber();
|
|
// Step 2. Check containing op is scf.for.
|
|
Operation *containingOp = candidateSliceOp->getParentOp();
|
|
auto forOp = dyn_cast<scf::ForOp>(containingOp);
|
|
if (!forOp)
|
|
return failure();
|
|
Value resultingValue = forOp->getResult(resultNumber);
|
|
|
|
return getConsumerFromUses(resultingValue, containingOp->getBlock());
|
|
}
|
|
|
|
/// Fetch the first untiled consumer of a scf.forall's result which is yielded
|
|
/// by a tensor.parallel_insert_slice.
|
|
static FailureOr<OpOperand *>
|
|
getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
|
|
// Step 1. Fetch the corresponding output
|
|
Value sliceDest = candidateSliceOp.getDest();
|
|
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
|
|
if (!iterArg)
|
|
return failure();
|
|
Operation *containingOp = iterArg.getOwner()->getParentOp();
|
|
if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
|
|
return failure();
|
|
// Step 2. Check that the containing op is scf.forall.
|
|
auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
|
|
if (!forallOp)
|
|
return failure();
|
|
Value resultingValue =
|
|
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
|
|
|
|
return getConsumerFromUses(resultingValue, containingOp->getBlock());
|
|
}
|
|
|
|
/// This utility currently checks whether the loop either :-
|
|
/// 1. Yields exactly one result.
|
|
/// 2. Has consumer op as its first user and other users to be in the same
|
|
/// containing block as that of consumer op's. Currently we clone the loop op
|
|
/// right before the consumer op in order to maintain a valid def-use chain.
|
|
/// This utility thus helps ensuring that no invalid IR is formed due to the
|
|
/// same.
|
|
static LogicalResult checkAssumptionForLoop(Operation *loopOp,
|
|
Operation *consumerOp) {
|
|
// Check if the loop op yields one result.
|
|
if (loopOp->getNumResults() == 1)
|
|
return success();
|
|
// Check if the consumerOp is the first user of the loopOp and if other users
|
|
// are in the same containing block as that of consumer op's.
|
|
Block *parentBlock = consumerOp->getBlock();
|
|
for (Operation *userOp : loopOp->getUsers()) {
|
|
if (userOp == consumerOp)
|
|
continue;
|
|
if (parentBlock != userOp->getBlock() ||
|
|
!consumerOp->isBeforeInBlock(userOp))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// A utility to fetch an untiled consumer of
|
|
/// tensor.insert_slice/tensor.parallel_insert_slice.
|
|
static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
|
|
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
|
|
return getUntiledConsumerFromSlice(insertSlice);
|
|
} else if (auto parallelInsertSlice =
|
|
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
|
|
return getUntiledConsumerFromSlice(parallelInsertSlice);
|
|
} else {
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
/// After fusing consumer into scf.for we want to modify the scf.yield operation
|
|
/// to reflect the same by returning the values yielded by the tiled consumer.
|
|
static void
|
|
fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
|
|
TilingResult &tilingResult,
|
|
ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
|
|
ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
|
|
ArrayRef<BlockArgument> bbArgs) {
|
|
scf::YieldOp oldTerminatorOp =
|
|
cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
|
|
unsigned totalOldResults = oldTerminatorOp->getNumResults();
|
|
unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
|
|
SmallVector<Value> newYieldOperands;
|
|
newYieldOperands.reserve(totalOldResults + totalTiledResults);
|
|
for (auto oldResult : oldTerminatorOp.getResults()) {
|
|
newYieldOperands.push_back(oldResult);
|
|
}
|
|
rewriter.setInsertionPointAfter(oldTerminatorOp);
|
|
Location loc = newForOp.getLoc();
|
|
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
|
|
llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
|
|
resultOffsets, resultSizes)) {
|
|
SmallVector<OpFoldResult> strides(resultOffset.size(),
|
|
rewriter.getIndexAttr(1));
|
|
Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
|
|
loc, tiledResult, bbArg, resultOffset, resultSize, strides);
|
|
newYieldOperands.push_back(newInsertSliceOp);
|
|
}
|
|
rewriter.create<scf::YieldOp>(loc, newYieldOperands);
|
|
rewriter.eraseOp(oldTerminatorOp);
|
|
}
|
|
|
|
/// After fusing consumer into scf.forall we want to yield each of the resulting
|
|
/// values by the tiled consumer within scf.forall.in_parallel region.
|
|
static void
|
|
fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
|
|
SmallVector<Value> tiledResults,
|
|
ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
|
|
ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
|
|
ArrayRef<BlockArgument> bbArgs) {
|
|
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
|
|
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
|
|
Location firstYieldOpLoc =
|
|
(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
|
|
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
|
|
llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
|
|
SmallVector<OpFoldResult> strides(resultOffset.size(),
|
|
rewriter.getIndexAttr(1));
|
|
rewriter.create<tensor::ParallelInsertSliceOp>(
|
|
firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
|
|
}
|
|
}
|
|
|
|
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
///
/// Given a tensor.insert_slice (inside an scf.for) or a
/// tensor.parallel_insert_slice (inside an scf.forall), this finds the unique
/// untiled consumer of the corresponding loop result, recreates the loop with
/// additional iter_args/shared_outs for the consumer's DPS inits, tiles the
/// consumer inside the new loop, rewires the terminator to also yield the
/// tiled-consumer results, and finally erases the old loop and the cloned
/// (untiled) consumer. Returns the fused-consumer description on success, or
/// a match-failure diagnostic otherwise.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
                                      Operation *candidateSliceOp) {
  // Only the two supported slice-yielding ops are handled.
  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
          candidateSliceOp))
    return failure();

  // Distinguishes the scf.for path (insert_slice) from the scf.forall path
  // (parallel_insert_slice) throughout the rest of the function.
  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);

  // 1. Get the consumer of scf.for for the result yielded by
  // tensor.insert_slice/parallel_insert_slice.
  FailureOr<OpOperand *> maybeConsumerOpOperand =
      getUntiledConsumerFromSlice(candidateSliceOp);
  if (failed(maybeConsumerOpOperand)) {
    return rewriter.notifyMatchFailure(candidateSliceOp,
                                       "could not fetch consumer to fuse");
  }
  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
  Operation *consumerOp = consumerOpOperand->getOwner();
  unsigned operandNumber = consumerOpOperand->getOperandNumber();
  // `resultNumber` identifies which loop result feeds the consumer; it is
  // used in step 2 to reject consumers that use that result as a DPS init.
  unsigned resultNumber = 0;
  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
    resultNumber = producerResult.getResultNumber();
  } else {
    return rewriter.notifyMatchFailure(
        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
  }

  Operation *oldLoopOp = nullptr;
  SmallVector<Value> newOuts;
  Block *oldLoopBody = nullptr;
  unsigned initSize = 0;
  // Number of induction variables: 1 for scf.for; overwritten with the loop
  // rank for scf.forall. Used to skip IVs when indexing block arguments.
  unsigned rank = 1;
  if (isInsertSliceOp) {
    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
    oldLoopOp = forOp;
    llvm::append_range(newOuts, forOp.getInits());
    oldLoopBody = forOp.getBody();
    initSize = forOp.getInits().size();
  } else {
    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
    oldLoopOp = forallOp;
    llvm::append_range(newOuts, forallOp.getOutputs());
    oldLoopBody = forallOp.getBody();
    initSize = forallOp.getOutputs().size();
    rank = forallOp.getRank();
  }

  // The loop is cloned immediately before the consumer (step 3); ensure no
  // other user of the loop would observe the stale value.
  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
    return rewriter.notifyMatchFailure(
        oldLoopOp, "containing loop op should either yield just one value or "
                   "have the consumer op as its first user");
  }

  OpBuilder::InsertionGuard g(rewriter);

  // 2. Check consumer is not using scf loop's output as init.
  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
  SmallVector<Value> dpsInits =
      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
    return rewriter.notifyMatchFailure(
        consumerOp,
        "consumer op taking the result of scf.for as init is not supported");
  }
  // The consumer's inits become extra iter_args/shared_outs of the new loop.
  newOuts.append(dpsInits);

  Location loc = oldLoopOp->getLoc();

  // 3. Create new scf loop op.
  rewriter.setInsertionPoint(consumerOp);
  Operation *newLoopOp = nullptr;
  Block *newLoopBody = nullptr;
  if (isInsertSliceOp) {
    auto forOp = cast<scf::ForOp>(oldLoopOp);
    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                forOp.getUpperBound(),
                                                forOp.getStep(), newOuts);
    newLoopOp = newForOp;
    newLoopBody = newForOp.getBody();
  } else {
    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
    auto newForallOp = rewriter.create<scf::ForallOp>(
        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
    newLoopOp = newForallOp;
    // The auto-created terminator is dropped: the old loop's terminator is
    // carried over by the block merge below and fixed up in the final step.
    rewriter.eraseOp(newForallOp.getTerminator());
    newLoopBody = newForallOp.getBody();
  }

  // 4. Move the loop body to the new op.
  // The first `oldNumArguments` block args (IVs + old iter_args) map 1:1;
  // the trailing args added for the consumer's inits are wired in step 6.a.
  unsigned oldNumArguments = oldLoopBody->getNumArguments();
  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
                       newLoopBody->getArguments().take_front(oldNumArguments));

  // 5. Set insertion point before terminator op of the loop and create a new
  // tensor.insert_slice. In the scf.for case this is a clone of the
  // candidateSliceOp whereas in the scf.forall case this is created from the
  // operands of tensor.parallel_insert_slice.
  tensor::InsertSliceOp clonedInsertSliceOp;
  if (auto sliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
    rewriter.setInsertionPoint(newForallOp.getTerminator());
    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
  } else {
    rewriter.setInsertionPoint(candidateSliceOp);
    clonedInsertSliceOp =
        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
  }

  // 6.a. Clone consumer op.
  // Its destinations are redirected to the new block args created for the
  // consumer's inits so the clone is valid inside the loop body.
  auto newForOpBlockArgsForConsumerDest =
      newLoopBody->getArguments().drop_front(oldNumArguments);
  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));

  // 6.b. Replace all uses of the loop result with the result of the cloned
  // tensor.insert_slice.
  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
    operandToReplace.set(clonedInsertSliceOp.getResult());
  });

  // 7 - Perform tiling of the cloned consumer and replace the operand at
  // `operandNumber` with the source of the cloned tensor.insert_slice op.
  auto ossSliceOp =
      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceInsertSliceWithTiledConsumer(
          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
  if (failed(tileAndFuseResult)) {
    return failure();
  }
  // The tiled consumer now reads the tile directly from the slice's source,
  // bypassing the insert_slice.
  rewriter.replaceAllUsesWith(
      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
      clonedInsertSliceOp.getSource());

  // 8 - Extract offset/sizes/strides required to create the
  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();

  // 9. Check all insert stride is 1.
  // Non-unit strides are not supported by the terminator fix-up below.
  if (llvm::any_of(strides, [](OpFoldResult stride) {
        return !isConstantIntValue(stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp, "containingOp's result yield with stride");
  }

  // 10. Try to get iter domain position from input position.
  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
          iterDomainSizes))) {
    return rewriter.notifyMatchFailure(
        clonedConsumerOp, "can't get iter domain position from input position");
  }

  // 11. Try to fetch the offset and size for all results of the cloned
  // consumer. This would then be used to form the corresponding
  // tensor.insert_slice/parallel_insert_slice later.
  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
      totalNumResultsOfConsumer);
  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
    if (failed(clonedConsumerOp.getResultTilePosition(
            rewriter, idx, iterDomainOffsets, iterDomainSizes,
            resultOffsets[idx], resultSizes[idx]))) {
      return rewriter.notifyMatchFailure(
          clonedConsumerOp,
          "can't get result domain position from iter domain position");
    }
  }

  // Fix up the terminator so the tiled-consumer results are yielded. The
  // drop_front skips the induction variable(s) plus the original iter_args,
  // leaving exactly the block args added for the consumer's inits.
  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
  if (isInsertSliceOp) {
    auto newForOp = cast<scf::ForOp>(newLoopOp);
    fixTerminatorSCFYield(
        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
        newForOp.getBody()->getArguments().drop_front(1 + initSize));
  } else {
    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
    fixTerminatorSCFInParallel(
        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
        arrayRefOffsets, arrayRefSizes,
        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
  }

  // 12. Replace the result of scf loop and consumer op with new loop's results.
  // zip_first: the new loop has extra results (for the consumer), which have
  // no counterpart in the old loop.
  for (auto &&[oldResult, newResult] :
       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  // The consumer's results correspond to the new loop results past the
  // original `initSize` ones.
  for (auto &&[oldResult, newResult] :
       llvm::zip(consumerOp->getResults(),
                 newLoopOp->getResults().drop_front(initSize))) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  // 13. Need to erase the old scf loop and the cloned consumer op.
  // (The original untiled consumer is left for the caller / DCE once its
  // results have been replaced above.)
  rewriter.eraseOp(oldLoopOp);
  rewriter.eraseOp(clonedConsumerOp);

  return scf::SCFFuseConsumerOfSliceResult{
      consumerOpOperand,
      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
      tileAndFuseResult->tiledOps};
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// lowerToLoopsUsingSCFForOp implementation.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Lower `op` to an explicit nest of scf.for loops, one per dimension of its
/// iteration domain, with the body produced by the op's
/// `generateScalarImplementation` given the induction variables.
FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  Location loc = op.getLoc();
  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
  SmallVector<scf::ForOp> loopNest;
  SmallVector<Value> inductionVars;
  // Build one loop per domain dimension, nesting each new loop inside the
  // previous one by moving the insertion point before its terminator.
  for (const Range &dim : iterationDomain) {
    Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, dim.offset);
    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, dim.size);
    Value step = getValueOrCreateConstantIndexOp(rewriter, loc, dim.stride);
    auto forOp =
        rewriter.create<scf::ForOp>(op.getLoc(), lb, ub, step, ValueRange{});
    loopNest.push_back(forOp);
    inductionVars.push_back(forOp.getInductionVar());
    rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
  }
  // Emit the scalar body at the innermost insertion point.
  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(),
                                             inductionVars)))
    return failure();
  return loopNest;
}
|