Files
clang-p2996/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
Benjamin Maxwell d319fc41d0 [mlir][ArmSME] Add option to only enable streaming mode for scalable code (#94759)
This adds a new option
`-enable-arm-streaming=if-contains-scalable-vectors`, which only applies
the selected streaming/ZA modes if the function contains scalable vector
types.

As an NFC change, this patch also removes the `only-` prefix from the
`if-required-by-ops` mode.
2024-06-10 12:02:16 +01:00

103 lines
3.6 KiB
C++

//===- TestLowerToArmSME.cpp - Test lowering to ArmSME as a sink pass -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing the lowering to ArmSME as a
// generally usable sink pass.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace {
/// Options accepted by the `test-lower-to-arm-sme` pass pipeline.
struct TestLowerToArmSMEOptions
    : public PassPipelineOptions<TestLowerToArmSMEOptions> {
  /// Enabled by default. When set, the pipeline builder adds
  /// `arm_sme::createOuterProductFusionPass()` after the vector-to-ArmSME
  /// conversion.
  PassOptions::Option<bool> fuseOuterProducts{
      *this, "fuse-outer-products", llvm::cl::init(true),
      llvm::cl::desc("Fuse outer product operations via "
                     "'-arm-sme-outer-product-fusion' pass")};

  /// Disabled by default. The value is passed straight through to
  /// `createConvertArmSMEToLLVMPass`.
  PassOptions::Option<bool> dumpTileLiveRanges{
      *this, "dump-tile-live-ranges", llvm::cl::init(false),
      llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)")};
};
/// Fills `pm` with the `test-lower-to-arm-sme` pipeline, which lowers
/// operations on vectors (arith, vector) to LLVM through the ArmSME dialect.
///
/// \param pm       Pass manager that receives the passes, in order.
/// \param options  Pipeline options. `fuseOuterProducts` toggles the
///                 outer-product fusion pass. `dumpTileLiveRanges` is
///                 forwarded to the ArmSME-to-LLVM conversion.
void buildTestLowerToArmSME(OpPassManager &pm,
                            const TestLowerToArmSMEOptions &options) {
  // A canonicalize + CSE pair, used both mid-pipeline and at the very end.
  auto appendCleanups = [&pm]() {
    pm.addPass(createCanonicalizerPass());
    pm.addPass(createCSEPass());
  };

  // Stage 1: legalize vector ops so they can be converted to ArmSME, then
  // tidy up.
  pm.addPass(arm_sme::createVectorLegalizationPass());
  appendCleanups();

  // Stage 2: rewrite arith and vector ops as ArmSME ops. Optionally fuse
  // outer products afterwards.
  pm.addPass(createArithToArmSMEConversionPass());
  pm.addPass(createConvertVectorToArmSMEPass());
  if (options.fuseOuterProducts)
    pm.addPass(arm_sme::createOuterProductFusionPass());

  // Stage 3: convert operations on high-level vectors to loops. ArmSME goes to
  // SCF first, then Vector goes to SCF with full unrolling enabled.
  pm.addPass(createConvertArmSMEToSCFPass());
  const auto fullyUnrolled = VectorTransferToSCFOptions().enableFullUnroll();
  pm.addPass(createConvertVectorToSCFPass(fullyUnrolled));

  // Stage 4: enable locally-streaming mode with new ZA state. The
  // `ifRequiredByOps` flag selects the "if-required-by-ops" mode.
  pm.addPass(arm_sme::createEnableArmStreamingPass(
      arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
      /*ifRequiredByOps=*/true));

  // Stage 5: SCF -> CF (required for ArmSME tile allocation), then ArmSME ->
  // LLVM on each func.func.
  pm.addPass(createConvertSCFToCFPass());
  pm.addNestedPass<func::FuncOp>(
      createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));

  // Final tidy-up.
  appendCleanups();
}
} // namespace
namespace mlir {
namespace test {
void registerTestLowerToArmSME() {
PassPipelineRegistration<TestLowerToArmSMEOptions>(
"test-lower-to-arm-sme",
"An example pipeline to lower operations on vectors (arith, vector) to "
"LLVM via ArmSME.",
buildTestLowerToArmSME);
}
} // namespace test
} // namespace mlir