Returning failure when all tile sizes are zero prevents the transformation marker from being updated, so the greedy pattern rewriter keeps re-applying the pattern until it hits its iteration limit. Instead, simply clone the operation (tiling with all-zero tile sizes is essentially a clone). The transformation filter then sees the marker on the clone and keeps the pattern rewriter from being invoked over and over.

Differential Revision: https://reviews.llvm.org/D113949
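
The same marker mechanism is visible later in this very file: PadTensorOpTilingPattern first bails out when the marker is already present, and sets it on the op it creates, so the greedy driver never revisits an already-processed op (insertTilingPatterns installs the analogous "tiled" marker filter for the LinalgTilingPattern instances):

    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
      return failure();
    ...
    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getUnitAttr());
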
//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

using LoopIndexToRangeIndexMap = DenseMap<int, int>;

// Creates a number of ranges equal to the number of non-zero entries in
// `tileSizes`: one for each loop of the LinalgOp that is tiled. The
// `tileSizes` argument has one entry per surrounding loop; a zero entry is
// the convention that the corresponding loop is not tiled. This convention
// simplifies implementations by avoiding affine map manipulations. The
// returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created. The function also
// returns a map from loop indices of the LinalgOp to the corresponding
// non-empty range indices of newly created loops.
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                    ValueRange allShapeSizes, ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<arith::ConstantIndexOp>(loc, 0),
                        shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

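// Worked example (editorial illustration, not part of the upstream source):
// given loop-order tile sizes [10, 0, 25], the zero entry for loop 1 is
// erased, two ranges are created (steps 10 and 25 for loops 0 and 2), and
// the returned map is {0 -> 0, 2 -> 1}; loop 1 stays untiled.
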
// All indices returned by IndexOp should be invariant with respect to tiling.
// Therefore, if an operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.generic` before tiling:
//
//   #id_2d = (i, j) -> (i, j)
//   #pointwise_2d_trait = {
//     indexing_maps = [#id_2d, #id_2d],
//     iterator_types = ["parallel", "parallel"]
//   }
//   linalg.generic #pointwise_2d_trait %operand, %result {
//   ^bb0(%operand_in: f32, %result_in: f32):
//     %i = linalg.index 0 : index
//     %j = linalg.index 1 : index
//     <some operations that use %i, %j>
//   }: memref<50x100xf32>, memref<50x100xf32>
//
// After the tiling pass with tile sizes 10 and 25:
//
//   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
//   %c1 = arith.constant 1 : index
//   %c0 = arith.constant 0 : index
//   %c25 = arith.constant 25 : index
//   %c10 = arith.constant 10 : index
//   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
//   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
//   scf.for %k = %c0 to operand_dim_0 step %c10 {
//     scf.for %l = %c0 to operand_dim_1 step %c25 {
//       %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       linalg.generic pointwise_2d_trait %4, %5 {
//       ^bb0(%operand_in: f32, %result_in: f32):
//         %i = linalg.index 0 : index
//         %j = linalg.index 1 : index
//         // Indices `k` and `l` are implicitly captured in the body.
//         %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k
//         %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l
//         // Every use of %i, %j is replaced with %transformed_i, %transformed_j
//         <some operations that use %transformed_i, %transformed_j>
//       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//     }
//   }
//
// TODO: Investigate whether mixing implicit and explicit indices
// can lead to losing information.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto &en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  addTileLoopIvsToIndexOpResults(b, op, allIvs);
}

// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as the size of the tile) is taken from
// a given ExtractSliceOp `sliceOp`.
static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
                                   tensor::ExtractSliceOp sliceOp, Value source,
                                   Value dest) {
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.source().getType(), source, dest, sliceOp.offsets(),
      sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
      sliceOp.static_sizes(), sliceOp.static_strides());
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero)) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

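  // Note: cloning above (rather than returning failure) lets marker-based
  // filters such as the LinalgTransformationFilter installed by
  // insertTilingPatterns below tag the clone as "tiled", so the greedy
  // rewriter does not re-apply the pattern until it hits its iteration limit.
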
  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (auto attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Otherwise, build the
  // permutation map.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expect the number of operands to match the op's inputs and outputs");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, op, valuesToTile, interchangedIvs, tileSizes, sizeBounds);

    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : op.getOutputTensorOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    res = op.clone(b, loc, resultTensorTypes, tiledOperands);

    // Insert an insert_slice for each output tensor.
    unsigned resultIdx = 0;
    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
      // TODO: use an interface/adaptor to avoid leaking position in
      // `tiledOperands`.
      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
        tensorResults.push_back(insertSliceIntoTensor(
            b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
      } else {
        tensorResults.push_back(res->getResult(resultIdx));
      }
      ++resultIdx;
    }
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

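// For illustration (editorial): if `op` has three loops and the callback
// returns a single tile size [8], `tileSizeVector` is padded to [8, 0, 0],
// so only the outermost loop is tiled and the remaining dimensions follow
// the tile-by-zero convention above.
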
FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  case LinalgTilingLoopType::TiledLoops:
    return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
  default:;
  }
  return failure();
}

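// Minimal usage sketch (editorial; assumes a rewriter positioned at a
// tensor-based `linalgOp` whose results should be replaced):
//
//   auto opts = LinalgTilingOptions()
//                   .setTileSizes({16, 32})
//                   .setLoopType(LinalgTilingLoopType::Loops);
//   FailureOr<TiledLinalgOp> tiled = tileLinalgOp(rewriter, linalgOp, opts);
//   if (failed(tiled))
//     return failure();
//   if (!tiled->tensorResults.empty())
//     rewriter.replaceOp(linalgOp, tiled->tensorResults);
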
/// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled)
/// PadTensorOp and the loop nest.
static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
                                     PadTensorOp &newPadOp, LoopNest &loopNest,
                                     const LinalgTilingOptions &options) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(op);

  // Clone PadTensorOp so that the existing op can be replaced more easily.
  newPadOp = cast<PadTensorOp>(builder.clone(*op.getOperation()));
  // Get rank and tile sizes.
  int64_t rank = op.getResultType().getRank();
  SmallVector<Value> tileSizes =
      options.tileSizeComputationFunction(builder, op);
  assert(static_cast<int64_t>(tileSizes.size()) == rank);
  // Compute lower and upper bounds of the loop nest.
  SmallVector<Range> ranges = op.getLoopBounds(builder);
  SmallVector<Value> lbs, dims, allDims, steps;
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
      steps.push_back(tileSizes[i]);
    }
  }
  // Generate loop nest: One loop per dimension.
  SmallVector<Value> destOperand = op.getDestinationOperands(builder);
  loopNest = mlir::scf::buildLoopNest(
      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange iterArgs) -> scf::ValueVector {
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes =
            computeTileSizes(b, loc, localIvs, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the PadTensorOp.
        // Note: The PadTensorOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput =
            makeTiledShape(b, loc, newPadOp->getResult(0), tileSizes, map,
                           offsets, allDims, sizes);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
        Value yieldValue =
            insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
        return scf::ValueVector({yieldValue});
      });
  return success();
}

namespace {
struct PadTensorOpTilingPattern : public OpRewritePattern<PadTensorOp> {
  PadTensorOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
      : OpRewritePattern<PadTensorOp>(ctx), options(opt) {}

  LogicalResult matchAndRewrite(PadTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
      return failure();
    PadTensorOp newPadOp;
    LoopNest loopNest;
    if (failed(tilePadTensorOp(rewriter, op, newPadOp, loopNest, options)))
      return failure();
    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getUnitAttr());
    // Replace all uses of the original PadTensorOp.
    rewriter.replaceOp(op, loopNest.getResults()[0]);
    return success();
  }

  LinalgTilingOptions options;
};
} // namespace

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};

/// Helper classes for type list expansion.
template <typename... OpTypes>
class RewritePatternList;

template <>
class RewritePatternList<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {}
};

template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {
    auto *ctx = patterns.getContext();
    patterns.add<LinalgTilingPattern<OpTy>>(
        ctx, options,
        LinalgTransformationFilter(ArrayRef<StringAttr>{},
                                   StringAttr::get(ctx, "tiled")));
    RewritePatternList<OpTypes...>::insert(patterns, options);
  }
};
} // namespace

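// Expansion example (editorial, with hypothetical op names): a call such as
// RewritePatternList<OpA, OpB>::insert(patterns, options) adds a
// LinalgTilingPattern<OpA> (with a filter that installs the "tiled" marker),
// then recurses to do the same for OpB, terminating at the empty
// specialization above.
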
RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);

  InitTensorOp::getCanonicalizationPatterns(patterns, ctx);
  PadTensorOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  RewritePatternList<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::insert(patterns, options);
  patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
}

static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType,
                   ArrayRef<StringRef> distributionTypes) {
    this->tileSizes = tileSizes;
    this->loopType = "";
    this->loopTypeEnum = loopType;
    this->distributionTypes = llvm::to_vector<2>(llvm::map_range(
        distributionTypes, [](StringRef ref) { return ref.str(); }));
  }

  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    LinalgTilingLoopType type =
        llvm::StringSwitch<LinalgTilingLoopType>(loopType)
            .Case("for", LinalgTilingLoopType::Loops)
            .Case("affine", LinalgTilingLoopType::AffineLoops)
            .Case("parallel", LinalgTilingLoopType::ParallelLoops)
            .Case("tiled_loop", LinalgTilingLoopType::TiledLoops)
            .Default(loopTypeEnum);
    auto distTypes = llvm::to_vector<2>(llvm::map_range(
        distributionTypes, [](std::string &str) { return StringRef(str); }));
    auto options = LinalgTilingOptions()
                       .setTileSizes(tileSizes)
                       .setLoopType(type)
                       .setDistributionTypes(distTypes);
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    insertTilingPatterns(patterns, options);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    (void)applyPatternsAndFoldGreedily(
        funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
    // Drop the marker.
    funcOp.walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });

    // Apply the swap pattern after generating the loop nest and running
    // canonicalizations.
    applyExtractSliceOfPadTensorSwapPattern(funcOp);
  }

  LinalgTilingLoopType loopTypeEnum;
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
                             linalg::LinalgTilingLoopType loopType,
                             ArrayRef<StringRef> distributionTypes) {
  return std::make_unique<LinalgTilingPass>(tileSizes, loopType,
                                            distributionTypes);
}
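
// Usage sketch (editorial; assumes a standard pass-manager setup where
// `ctx` and `module` exist in the calling driver):
//
//   PassManager pm(ctx);
//   pm.addNestedPass<FuncOp>(mlir::createLinalgTilingPass(
//       /*tileSizes=*/{2, 4}, linalg::LinalgTilingLoopType::ParallelLoops,
//       /*distributionTypes=*/{}));
//   if (failed(pm.run(module)))
//     llvm::errs() << "tiling pipeline failed\n";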