This patch adds support for lowering a 'vector.transfer_write' of zeroes
with type 'vector<[16x16]xi8>' to the SME 'zero {za}' instruction [1],
which zeroes the entire accumulator, followed by writing the accumulator
out to memory with the 'str' instruction [2].
This contributes to supporting a path from 'linalg.fill' to SME.
[1] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
[2] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/STR--Store-vector-from-ZA-array-
Reviewed By: awarzynski, dcaballe, WanderAway
Differential Revision: https://reviews.llvm.org/D152508
83 lines
3.3 KiB
C++
//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
|
|
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::arm_sme;
|
|
|
|
namespace {
|
|
/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
|
|
/// ops to enable the ZA storage array.
|
|
struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(func::FuncOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPointToStart(&op.front());
|
|
rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
|
|
rewriter.updateRootInPlace(op, [] {});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
|
|
/// disable the ZA storage array.
|
|
struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(func::ReturnOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(op);
|
|
rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
|
|
rewriter.updateRootInPlace(op, [] {});
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers the patterns that bracket ZA usage in 'arm_za' functions: one
/// that enables ZA at function entry and one that disables it before each
/// return. The type converter is currently unused by these patterns.
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<EnableZAPattern>(context);
  patterns.add<DisableZAPattern>(context);
}
|
|
|
|
/// Configures 'target' for legalizing ArmSME ops ahead of LLVM translation.
///
/// Marks the SCF loop ops and the SME intrinsics produced by this
/// legalization as legal, and adds dynamic legality rules for 'func.func' and
/// 'func.return' so that functions carrying the 'arm_za' attribute are only
/// legal once ZA is enabled at entry and disabled before every return.
void mlir::configureArmSMELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::aarch64_sme_zero,
                    arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
                    arm_sme::aarch64_sme_za_disable>();

  // Note: the callbacks below capture nothing. They are stored in 'target'
  // and outlive this function, so they must not refer to its locals.

  // Mark 'func.func' ops as legal if either:
  //   1. no 'arm_za' function attribute is present, or
  //   2. the 'arm_za' function attribute is present and the first op in the
  //      entry block is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
  // Declarations have no body and are always legal.
  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp funcOp) {
    if (funcOp.isDeclaration() || !funcOp->hasAttr("arm_za"))
      return true;
    Block &entryBlock = funcOp.getBody().front();
    // Guard against dereferencing the first op of an empty block.
    return !entryBlock.empty() &&
           isa<arm_sme::aarch64_sme_za_enable>(entryBlock.front());
  });

  // Mark 'func.return' ops as legal if either:
  //   1. no 'arm_za' attribute is present on the parent function, or
  //   2. the 'arm_za' attribute is present and the op immediately preceding
  //      this return is an 'arm_sme::aarch64_sme_za_disable' intrinsic.
  // Checking the immediate predecessor (rather than searching the whole
  // function for any za_disable) ensures that each return of a function with
  // several exit blocks gets its own za_disable, and avoids walking the whole
  // function on every legality query.
  target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp returnOp) {
    Operation *funcOp = returnOp->getParentOp();
    if (!funcOp->hasAttr("arm_za"))
      return true;
    return isa_and_nonnull<arm_sme::aarch64_sme_za_disable>(
        returnOp->getPrevNode());
  });
}
|