//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" using namespace mlir; using namespace mlir::arm_sme; namespace { /// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func' /// ops to enable the ZA storage array. struct EnableZAPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(func::FuncOp op, PatternRewriter &rewriter) const final { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointToStart(&op.front()); rewriter.create(op->getLoc()); rewriter.updateRootInPlace(op, [] {}); return success(); } }; /// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to /// disable the ZA storage array. struct DisableZAPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(func::ReturnOp op, PatternRewriter &rewriter) const final { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); rewriter.create(op->getLoc()); rewriter.updateRootInPlace(op, [] {}); return success(); } }; } // namespace void mlir::populateArmSMELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } void mlir::configureArmSMELegalizeForExportTarget( LLVMConversionTarget &target) { target.addLegalOp(); // Mark 'func.func' ops as legal if either: // 1. no 'arm_za' function attribute is present. // 2. the 'arm_za' function attribute is present and the first op in the // function is an 'arm_sme::aarch64_sme_za_enable' intrinsic. target.addDynamicallyLegalOp([&](func::FuncOp funcOp) { if (funcOp.isDeclaration()) return true; auto firstOp = funcOp.getBody().front().begin(); return !funcOp->hasAttr("arm_za") || isa(firstOp); }); // Mark 'func.return' ops as legal if either: // 1. no 'arm_za' function attribute is present. // 2. the 'arm_za' function attribute is present and there's a preceding // 'arm_sme::aarch64_sme_za_disable' intrinsic. target.addDynamicallyLegalOp([&](func::ReturnOp returnOp) { bool hasDisableZA = false; auto funcOp = returnOp->getParentOp(); funcOp->walk( [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; }); return !funcOp->hasAttr("arm_za") || hasDisableZA; }); }