The `tileAndFuseLinalgOps` is a legacy approach for tiling + fusion of Linalg operations. Since it was also intended to work on operations with buffer operands, this method had fairly complex logic to make sure tile and fuse was correct even with side-effecting linalg ops. While complex, it still wasn't robust enough. This patch deprecates this method, thereby deprecating the tiling + fusion method for ops with buffer semantics. Note that the core transformation to do fusion of a producer with a tiled consumer still exists. The deprecation here only removes methods that auto-magically tried to tile and fuse correctly in the presence of side-effects. The `tileAndFuseLinalgOps` also works with operations with tensor semantics. There are at least two other ways the same functionality exists. 1) The `tileConsumerAndFuseProducers` method. This does a similar transformation, but using a slightly different logic to automatically figure out the legal tile + fuse code. Note that this is also to be deprecated soon. 2) The preferred way uses the `TilingInterface` for tile + fuse, and relies on the caller to set the tiling options correctly to ensure that the generated code is correct. As proof that (2) is equivalent to the functionality provided by `tileAndFuseLinalgOps`, relevant tests have been moved to use the interface, where the test driver sets the tile sizes appropriately to generate the expected code. Differential Revision: https://reviews.llvm.org/D129901
221 lines
9.0 KiB
C++
221 lines
9.0 KiB
C++
//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass for testing tiling operations using
|
|
// `TilingInterface`.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Interfaces/TilingInterface.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
|
|
/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
|
|
/// using a `filter` to avoid recursive application.
|
|
struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
|
|
TestTileUsingSCFForOpWithFilter(MLIRContext *context,
|
|
scf::SCFTilingOptions options,
|
|
linalg::LinalgTransformationFilter filter =
|
|
linalg::LinalgTransformationFilter(),
|
|
PatternBenefit benefit = 1)
|
|
: scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
|
|
|
|
/// Construct a generic pattern applied to `opName`.
|
|
TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
|
|
scf::SCFTilingOptions options,
|
|
linalg::LinalgTransformationFilter filter =
|
|
linalg::LinalgTransformationFilter(),
|
|
PatternBenefit benefit = 1)
|
|
: scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
|
|
|
|
LogicalResult matchAndRewrite(TilingInterface op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (failed(filter.checkAndNotify(rewriter, op)))
|
|
return failure();
|
|
|
|
auto tilingResult = returningMatchAndRewrite(op, rewriter);
|
|
if (failed(tilingResult)) {
|
|
return failure();
|
|
}
|
|
filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
linalg::LinalgTransformationFilter filter;
|
|
};
|
|
|
|
/// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern
|
|
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
|
|
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
|
|
/// application.
|
|
struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter
|
|
: public scf::TileConsumerAndFuseProducersUsingSCFForOp {
|
|
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
|
|
MLIRContext *context, scf::SCFTilingOptions options,
|
|
linalg::LinalgTransformationFilter filter =
|
|
linalg::LinalgTransformationFilter(),
|
|
PatternBenefit benefit = 1)
|
|
: scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
|
|
benefit),
|
|
filter(filter) {}
|
|
|
|
/// Construct a generic pattern applied to `opName`.
|
|
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
|
|
StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
|
|
linalg::LinalgTransformationFilter filter =
|
|
linalg::LinalgTransformationFilter(),
|
|
PatternBenefit benefit = 1)
|
|
: scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
|
|
benefit),
|
|
filter(filter) {}
|
|
|
|
LogicalResult matchAndRewrite(TilingInterface op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (failed(filter.checkAndNotify(rewriter, op)))
|
|
return failure();
|
|
|
|
auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter);
|
|
if (failed(tileAndFuseResult)) {
|
|
return failure();
|
|
}
|
|
filter.replaceLinalgTransformationFilter(
|
|
rewriter, tileAndFuseResult->tiledAndFusedOps.front());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
linalg::LinalgTransformationFilter filter;
|
|
};
|
|
|
|
/// Test pass for testing the use of `TilingInterface`.
|
|
struct TestTilingInterfacePass
|
|
: public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
|
|
|
|
TestTilingInterfacePass() = default;
|
|
TestTilingInterfacePass(const TestTilingInterfacePass &pass)
|
|
: PassWrapper(pass) {}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
|
|
tensor::TensorDialect>();
|
|
linalg::registerTilingInterfaceExternalModels(registry);
|
|
}
|
|
StringRef getArgument() const final { return "test-tiling-interface"; }
|
|
StringRef getDescription() const final {
|
|
return "Test tiling using TilingInterface";
|
|
}
|
|
|
|
Option<bool> testTiling{
|
|
*this, "tile-using-scf-for",
|
|
llvm::cl::desc(
|
|
"Test tiling using TilingInterface with scf.for operations"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> testTileConsumerAndFuseProducer{
|
|
*this, "tile-consumer-and-fuse-producer-using-scf-for",
|
|
llvm::cl::desc("Test tile and fuse transformation using TilingInterface "
|
|
"with scf.for operations"),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override;
|
|
|
|
private:
|
|
void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns);
|
|
};
|
|
} // namespace
|
|
|
|
template <class Pattern>
|
|
static void
|
|
addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns,
|
|
StringRef filterName, ArrayRef<int64_t> tileSizes,
|
|
ArrayRef<unsigned> interchange = {}) {
|
|
scf::SCFTilingOptions tilingOptions;
|
|
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
|
|
linalg::LinalgTransformationFilter filter(
|
|
StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
|
|
patterns.add<Pattern>(context, tilingOptions, filter);
|
|
}
|
|
|
|
/// Fills `patterns` with the test patterns selected by the pass options.
///
/// When `testTiling` is set, only the plain tiling patterns are added, even
/// if `testTileConsumerAndFuseProducer` is set as well. Otherwise the
/// tile-and-fuse patterns are added when `testTileConsumerAndFuseProducer` is
/// set. With neither option set, `patterns` is left untouched.
void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
                                              RewritePatternSet &patterns) {
  using TilePattern = TestTileUsingSCFForOpWithFilter;
  using FusePattern = TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter;

  if (testTiling) {
    // 1. Tiling M and N dims of `linalg.matmul` on tensors.
    addPatternForTiling<TilePattern>(context, patterns, "simple_gemm",
                                     {10, 20});
    // 2. Tiling M, N and K of `linalg.matmul` on buffers.
    addPatternForTiling<TilePattern>(context, patterns, "simple_gemm_memref",
                                     {10, 20, 30});
    // 3. Tiling 3D parallel generic op which implements a transpose.
    addPatternForTiling<TilePattern>(context, patterns,
                                     "parallel_generic_transpose", {10, 0, 20});
    // 4. Tiling 2D conv op.
    addPatternForTiling<TilePattern>(context, patterns, "simple_conv",
                                     {0, 0, 0, 0, 10, 20, 30});
    // 5. Tiling a simple op with `linalg.index` inside.
    addPatternForTiling<TilePattern>(context, patterns, "indexed_semantics",
                                     {10, 20});
    // 6. Tiling + interchange of an operation.
    addPatternForTiling<TilePattern>(context, patterns, "gemm_interchange",
                                     {10, 20, 30}, {1, 2, 0});
  } else if (testTileConsumerAndFuseProducer) {
    // 1. Tile and fuse of gemm with bias-add operation.
    addPatternForTiling<FusePattern>(context, patterns, "fusion", {10, 20});
    // Further tile-and-fuse cases, each selected by its marker name.
    addPatternForTiling<FusePattern>(context, patterns, "gemm_fusion", {10});
    addPatternForTiling<FusePattern>(context, patterns,
                                     "gemm_interchange_fusion", {10, 20},
                                     {1, 0});
    addPatternForTiling<FusePattern>(context, patterns,
                                     "gemm_plus_gemm_fusion", {10, 20});
    addPatternForTiling<FusePattern>(context, patterns, "gemm_sequence_fusion",
                                     {10});
  }
}
|
|
|
|
void TestTilingInterfacePass::runOnOperation() {
|
|
MLIRContext *context = &getContext();
|
|
|
|
RewritePatternSet tilingPatterns(context);
|
|
addTestPatterns(context, tilingPatterns);
|
|
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
|
std::move(tilingPatterns))))
|
|
return signalPassFailure();
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
/// Registration hook for `TestTilingInterfacePass`.
///
/// Constructs a temporary `PassRegistration` for the pass; presumably this
/// makes the pass available under its `test-tiling-interface` argument --
/// confirm against the `PassRegistration` documentation.
void registerTestTilingInterface() {
  PassRegistration<TestTilingInterfacePass>();
}
|
|
} // namespace test
|
|
} // namespace mlir
|