Files
clang-p2996/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Cullen Rhodes 564713c471 [mlir][ArmSME] Add basic lowering of vector.transfer_write to zero
This patch adds support for lowering a 'vector.transfer_write' of zeroes
with type 'vector<[16x16]xi8>' to the SME 'zero {za}' instruction [1],
which zeroes the entire accumulator, followed by writing the accumulator
out to memory with the 'str' instruction [2].

This contributes to supporting a path from 'linalg.fill' to SME.

[1] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
[2] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/STR--Store-vector-from-ZA-array-

Reviewed By: awarzynski, dcaballe, WanderAway

Differential Revision: https://reviews.llvm.org/D152508
2023-07-03 10:18:43 +00:00

83 lines
3.3 KiB
C++

//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
/// ops to enable the ZA storage array.
///
/// Only function definitions carrying the 'arm_za' attribute whose body does
/// not already begin with the enable intrinsic are rewritten. Checking this
/// here (instead of relying solely on the conversion target's legality rules)
/// keeps the pattern from dereferencing the empty body of a declaration and
/// from inserting duplicate intrinsics if it is driven by something other than
/// a dialect conversion.
struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp op,
                                PatternRewriter &rewriter) const final {
    // Declarations have no body, so 'op.front()' below would be invalid.
    if (op.isDeclaration())
      return rewriter.notifyMatchFailure(op, "function has no body");
    // Only functions that use ZA need it enabled.
    if (!op->hasAttr("arm_za"))
      return rewriter.notifyMatchFailure(op,
                                         "missing 'arm_za' attribute");
    // Don't insert a second enable if one already starts the body.
    if (isa<arm_sme::aarch64_sme_za_enable>(op.front().front()))
      return rewriter.notifyMatchFailure(op, "ZA already enabled");

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToStart(&op.front());
    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
    // Report the in-place modification of 'op' to the rewriter.
    rewriter.updateRootInPlace(op, [] {});
    return success();
  }
};

/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
/// disable the ZA storage array.
///
/// Only returns from functions carrying the 'arm_za' attribute that are not
/// already immediately preceded by the disable intrinsic are rewritten, so the
/// pattern never stacks duplicate disables in front of the same return.
struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::ReturnOp op,
                                PatternRewriter &rewriter) const final {
    // Only returns from functions that use ZA need it disabled.
    if (!op->getParentOp()->hasAttr("arm_za"))
      return rewriter.notifyMatchFailure(
          op, "parent function lacks 'arm_za' attribute");
    // The return may be the first op in its block, so the previous node can
    // be null; 'isa_and_nonnull' handles that case.
    if (isa_and_nonnull<arm_sme::aarch64_sme_za_disable>(op->getPrevNode()))
      return rewriter.notifyMatchFailure(op, "ZA already disabled");

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(op);
    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
    // Report the in-place modification of 'op' to the rewriter.
    rewriter.updateRootInPlace(op, [] {});
    return success();
  }
};
} // namespace
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Registers the ZA enable/disable patterns (enable first, then disable).
  // 'converter' is not used here.
  MLIRContext *context = patterns.getContext();
  patterns.add<EnableZAPattern>(context);
  patterns.add<DisableZAPattern>(context);
}
void mlir::configureArmSMELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::aarch64_sme_zero,
                    arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
                    arm_sme::aarch64_sme_za_disable>();

  // Note: the legality predicates below are retained by 'target' after this
  // function returns, so they capture nothing.

  // Mark 'func.func' ops as legal if either:
  //   1. no 'arm_za' function attribute is present (declarations, which have
  //      no body, are always legal).
  //   2. the 'arm_za' function attribute is present and the first op in the
  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp funcOp) {
    if (funcOp.isDeclaration() || !funcOp->hasAttr("arm_za"))
      return true;
    Operation &firstOp = funcOp.getBody().front().front();
    return isa<arm_sme::aarch64_sme_za_enable>(firstOp);
  });

  // Mark 'func.return' ops as legal if either:
  //   1. no 'arm_za' function attribute is present.
  //   2. the 'arm_za' function attribute is present and the op immediately
  //      preceding the return is an 'arm_sme::aarch64_sme_za_disable'
  //      intrinsic.
  // Checking the immediate predecessor (rather than searching the whole
  // function for any disable) ensures every return of a function with
  // multiple exits gets its own disable; otherwise the first inserted disable
  // would make all other returns look legal and leave ZA enabled on them.
  target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp returnOp) {
    Operation *funcOp = returnOp->getParentOp();
    if (!funcOp->hasAttr("arm_za"))
      return true;
    // The return may be the first op in its block, so the previous node can
    // be null.
    return isa_and_nonnull<arm_sme::aarch64_sme_za_disable>(
        returnOp->getPrevNode());
  });
}