This adds a new option, `-enable-arm-streaming=if-contains-scalable-vectors`, which applies the selected streaming/ZA modes only if the function contains scalable vector types. As an NFC (no functional change) cleanup, this patch also removes the `only-` prefix from the `if-required-by-ops` mode.
103 lines
3.6 KiB
C++
103 lines
3.6 KiB
C++
//===- TestLowerToArmSME.cpp - Test lowering to ArmSME as a sink pass -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass for testing the lowering to ArmSME as a
|
|
// generally usable sink pass.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
|
|
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
|
|
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
|
|
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
|
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
|
|
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
|
|
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
|
|
#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/IR/DialectRegistry.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Pass/PassOptions.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
struct TestLowerToArmSMEOptions
|
|
: public PassPipelineOptions<TestLowerToArmSMEOptions> {
|
|
PassOptions::Option<bool> fuseOuterProducts{
|
|
*this, "fuse-outer-products",
|
|
llvm::cl::desc("Fuse outer product operations via "
|
|
"'-arm-sme-outer-product-fusion' pass"),
|
|
llvm::cl::init(true)};
|
|
PassOptions::Option<bool> dumpTileLiveRanges{
|
|
*this, "dump-tile-live-ranges",
|
|
llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"),
|
|
llvm::cl::init(false)};
|
|
};
|
|
|
|
void buildTestLowerToArmSME(OpPassManager &pm,
|
|
const TestLowerToArmSMEOptions &options) {
|
|
// Legalize vector operations so they can be converted to ArmSME.
|
|
pm.addPass(arm_sme::createVectorLegalizationPass());
|
|
|
|
// Sprinkle some cleanups.
|
|
pm.addPass(createCanonicalizerPass());
|
|
pm.addPass(createCSEPass());
|
|
|
|
// Passes that convert operations on vectors to ArmSME operations.
|
|
|
|
// Convert Arith to ArmSME.
|
|
pm.addPass(createArithToArmSMEConversionPass());
|
|
// Convert Vector to ArmSME.
|
|
pm.addPass(createConvertVectorToArmSMEPass());
|
|
|
|
// Fuse outer products.
|
|
if (options.fuseOuterProducts)
|
|
pm.addPass(arm_sme::createOuterProductFusionPass());
|
|
|
|
// Convert operations on high-level vectors to loops.
|
|
|
|
// Convert ArmSME to SCF.
|
|
pm.addPass(createConvertArmSMEToSCFPass());
|
|
|
|
// Convert Vector to SCF (with full unroll enabled).
|
|
pm.addPass(createConvertVectorToSCFPass(
|
|
VectorTransferToSCFOptions().enableFullUnroll()));
|
|
|
|
// Enable streaming-mode and ZA.
|
|
pm.addPass(arm_sme::createEnableArmStreamingPass(
|
|
arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
|
|
/*ifRequiredByOps=*/true));
|
|
|
|
// Convert SCF to CF (required for ArmSME tile allocation).
|
|
pm.addPass(createConvertSCFToCFPass());
|
|
|
|
// Convert ArmSME to LLVM.
|
|
pm.addNestedPass<func::FuncOp>(
|
|
createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));
|
|
|
|
// Sprinkle some cleanups.
|
|
pm.addPass(createCanonicalizerPass());
|
|
pm.addPass(createCSEPass());
|
|
}
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestLowerToArmSME() {
|
|
PassPipelineRegistration<TestLowerToArmSMEOptions>(
|
|
"test-lower-to-arm-sme",
|
|
"An example pipeline to lower operations on vectors (arith, vector) to "
|
|
"LLVM via ArmSME.",
|
|
buildTestLowerToArmSME);
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|