clang-p2996/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Hsiangkai Wang 27ee33d136 [mlir][linalg] Decompose winograd operators (#96183)
Convert the Linalg winograd_filter_transform, winograd_input_transform,
and winograd_output_transform ops into nested loops that perform matrix
multiplications with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3), and F(2, 5). These configurations show that the implementation
can handle different kernel sizes (3 and 5) and different output sizes
(2 and 4). Besides the symmetric 3x3 and 5x5 kernels, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.
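
In this notation, F(m, r) computes m outputs of an r-tap filter per
tile; the corresponding input tile has m + r - 1 elements, so F(2, 3)
reads 4-wide tiles while F(4, 3) and F(2, 5) both read 6-wide tiles
(4 + 3 - 1 = 2 + 5 - 1 = 6).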

The implementation is based on the paper Fast Algorithms for
Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).

Reviewers: ftynse, Max191, GeorgeARM, nicolasvasilache, MaheshRavishankar, dcaballe, rengolin

Reviewed By: ftynse, Max191

Pull Request: https://github.com/llvm/llvm-project/pull/96183
2024-07-18 06:04:53 +01:00

//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)

  TestLinalgTransforms() = default;
  TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<affine::AffineDialect,
                    bufferization::BufferizationDialect,
                    memref::MemRefDialect,
                    scf::SCFDialect,
                    linalg::LinalgDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }
  StringRef getArgument() const final {
    return "test-linalg-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg transformation patterns by applying them greedily.";
  }
  void runOnOperation() override;

  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards memref.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizePadTensor{
      *this, "test-generalize-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizeTensorPackOp{
      *this, "test-generalize-tensor-pack",
      llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
                     "of tensor and Linalg ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizeTensorUnPackOp{
      *this, "test-generalize-tensor-unpack",
      llvm::cl::desc(
          "Test transform that generalizes unpack ops into a sequence "
          "of tensor and Linalg ops"),
      llvm::cl::init(false)};
  Option<bool> testSwapSubTensorPadTensor{
      *this, "test-swap-subtensor-padtensor",
      llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
                     "tensor.pad(subtensor)"),
      llvm::cl::init(false)};
  ListOption<int64_t> peeledLoops{
      *this, "peeled-loops",
      llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
  ListOption<int64_t> tileSizes{
      *this, "tile-sizes",
      llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
  Option<bool> skipPartial{
      *this, "skip-partial",
      llvm::cl::desc("Skip loops inside partial iterations during peeling"),
      llvm::cl::init(false)};
  Option<std::string> loopType{
      *this, "loop-type",
      llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                     "tiled_loop"),
      llvm::cl::init("for")};
  Option<bool> testBubbleUpExtractSliceOpPattern{
      *this, "test-bubble-up-extract-slice-op-pattern",
      llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
                     "extract_slice + linalgOp"),
      llvm::cl::init(false)};
  Option<bool> testSwapExtractSliceWithFill{
      *this, "test-swap-extract-slice-with-fill-pattern",
      llvm::cl::desc(
          "Test patterns to swap tensor.extract_slice(linalg.fill())"),
      llvm::cl::init(false)};
  Option<bool> testEraseUnusedOperandsAndResults{
      *this, "test-erase-unused-operands-and-results",
      llvm::cl::desc("Test patterns to erase unused operands and results"),
      llvm::cl::init(false)};
  Option<bool> testEraseUnnecessaryInputs{
      *this, "test-erase-unnecessary-inputs",
      llvm::cl::desc("Test patterns to erase unnecessary inputs"),
      llvm::cl::init(false)};
  Option<bool> testWinogradConv2D{
      *this, "test-winograd-conv2d",
      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
      llvm::cl::init(false)};
  Option<bool> testDecomposeWinogradOps{
      *this, "test-decompose-winograd-ops",
      llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
};
} // namespace
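
// Each boolean option above selects one pattern set; a typical lit RUN line
// (a sketch; the sub-option spellings are the Option names registered in the
// struct) looks like:
//   mlir-opt %s -test-linalg-transform-patterns=test-patterns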

static void applyPatterns(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);

  //===--------------------------------------------------------------------===//
  // Linalg distribution patterns.
  //===--------------------------------------------------------------------===//
  LinalgLoopDistributionOptions distributionOptions;

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<CopyVectorizationPattern>(ctx);

  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
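
// Forwarding patterns: rewrite vector.transfer_read/vector.transfer_write ops
// that go through a memref.copy so that the transfers use the copy's source
// (resp. destination) buffer directly, matching the option description above.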

static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
  RewritePatternSet forwardPattern(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}

static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  auto *ctx = funcOp.getContext();
  patterns.add<CopyVectorizationPattern>(ctx);
  populatePadOpVectorizationPatterns(patterns);
  populateConvolutionVectorizationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
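
// A sketch of pad generalization (an assumption based on the option text
// "copying with generic ops"): conceptually, the padded result is
// materialized, filled with the padding value, and the source tensor is
// then copied into its interior.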

static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyGeneralizeTensorUnPackPatterns(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<GeneralizeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
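
// Bubble-up sketch, per the option description: tensor.extract_slice(linalgOp(%x))
// becomes linalgOp(tensor.extract_slice(%x)), so the computation only runs on
// the slice that is actually consumed.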

static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  populateBubbleUpExtractSliceOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  populateSwapExtractSliceWithFillPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  populateEraseUnusedOperandsAndResultsPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  populateEraseUnnecessaryInputsPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
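
// Both configurations registered below share one input-tile width: a Winograd
// F(m, r) tile reads m + r - 1 inputs per output strip, and
// 4 + 3 - 1 = 2 + 5 - 1 = 6.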

static void applyWinogradConv2D(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
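
// Decomposition converts winograd_filter_transform, winograd_input_transform,
// and winograd_output_transform into nested loops that perform matrix
// multiplications with constant transform matrices.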

static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  populateDecomposeWinogradOpsPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
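
// A plausible end-to-end use (a sketch; option spellings follow the Option
// declarations above, and each pass instance honors only the first enabled
// option, so the two rewrites run as separate invocations):
//   mlir-opt %s -test-linalg-transform-patterns=test-winograd-conv2d |
//   mlir-opt -test-linalg-transform-patterns=test-decompose-winograd-ops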

/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
  if (testPatterns)
    return applyPatterns(getOperation());
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getOperation());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getOperation());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getOperation());
  if (testGeneralizeTensorPackOp)
    return applyGeneralizeTensorPackPatterns(getOperation());
  if (testGeneralizeTensorUnPackOp)
    return applyGeneralizeTensorUnPackPatterns(getOperation());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getOperation());
  if (testBubbleUpExtractSliceOpPattern)
    return applyBubbleUpExtractSliceOpPattern(getOperation());
  if (testSwapExtractSliceWithFill)
    return applySwapExtractSliceWithFillPattern(getOperation());
  if (testEraseUnusedOperandsAndResults)
    return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
  if (testEraseUnnecessaryInputs)
    return applyEraseUnnecessaryInputs(getOperation());
  if (testWinogradConv2D)
    return applyWinogradConv2D(getOperation());
  if (testDecomposeWinogradOps)
    return applyDecomposeWinogradOps(getOperation());
}
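
// Registration entry point. Presumably called from the test-pass registration
// of the hosting tool (e.g. a test build of mlir-opt); that wiring is an
// assumption about the surrounding build, not something this file enforces.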

namespace mlir {
namespace test {
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir