Files
clang-p2996/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
Benjamin Maxwell 042800a4dd [mlir][ArmSME] Add initial SME vector legalization pass (#79152)
This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, a [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which could then be lowered to ArmSME. These three ops have been picked
as supporting them alone allows lowering matmuls that use all ZA
accumulators to ArmSME.

For it to be possible to legalize a vector type it has to be a multiple
of an SME tile size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>`
can all be lowered to four `vector<[4]x[4]xf32>` operations.

In future, this pass will be extended with more SME-specific rewrites to
legalize unrolling the reduction dimension of matmuls (which is not
type-decomposition), which is why the pass has quite a general name.
2024-01-31 11:55:22 +00:00

120 lines
4.1 KiB
C++

//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the ArmSME dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
namespace mlir::arm_sme {
/// Returns the minimum number of elements of `type` held by one SME tile
/// slice, computed as `MinStreamingVectorLengthInBits` divided by the bit
/// width of `type`.
///
/// `type` must satisfy `isValidSMETileElementType` (checked by assertion).
unsigned getSMETileSliceMinNumElts(Type type) {
  assert(isValidSMETileElementType(type) && "invalid tile type!");
  const auto elementBitWidth = type.getIntOrFloatBitWidth();
  return MinStreamingVectorLengthInBits / elementBitWidth;
}
/// Returns true if `type` may be used as the element type of an SME tile.
///
/// Accepted types are the 8-, 16-, 32-, 64- and 128-bit integers, and the
/// f16, bf16, f32, f64 and f128 floating-point types.
bool isValidSMETileElementType(Type type) {
  // Integer element types.
  if (type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
      type.isInteger(64) || type.isInteger(128))
    return true;

  // Floating-point element types.
  return type.isF16() || type.isBF16() || type.isF32() || type.isF64() ||
         type.isF128();
}
/// Returns true if `vType` is exactly one SME tile.
///
/// That requires a rank-2 vector with all dimensions scalable, a valid SME
/// tile element type, and a shape of `[N]x[N]` where `N` is the result of
/// `getSMETileSliceMinNumElts` for the element type.
bool isValidSMETileVectorType(VectorType vType) {
  const bool isScalable2D = vType.getRank() == 2 && vType.allDimsScalable();
  if (!isScalable2D)
    return false;

  auto elementType = vType.getElementType();
  if (!isValidSMETileElementType(elementType))
    return false;

  // Both dimensions must equal the minimum slice length for the element type.
  const unsigned sliceElts = getSMETileSliceMinNumElts(elementType);
  return vType.getShape() == ArrayRef<int64_t>({sliceElts, sliceElts});
}
/// Maps an SME tile vector type to its `ArmSMETileType`.
///
/// Returns an empty optional when `type` is not a valid SME tile vector type.
/// Otherwise the tile kind is selected by element bit width:
/// 8 -> ZAB, 16 -> ZAH, 32 -> ZAS, 64 -> ZAD, 128 -> ZAQ.
std::optional<ArmSMETileType> getSMETileType(VectorType type) {
  if (isValidSMETileVectorType(type)) {
    const auto elementBits = type.getElementTypeBitWidth();
    switch (elementBits) {
    case 8:
      return ArmSMETileType::ZAB;
    case 16:
      return ArmSMETileType::ZAH;
    case 32:
      return ArmSMETileType::ZAS;
    case 64:
      return ArmSMETileType::ZAD;
    case 128:
      return ArmSMETileType::ZAQ;
    default:
      // Any other width is treated as unreachable for a validated tile type.
      llvm_unreachable("unknown SME tile type");
    }
  }
  return std::nullopt;
}
/// Verifies the tile ID attached to `op`, if there is one.
///
/// Succeeds when `op` does not implement `ArmSMETileOpInterface`, or when it
/// has no tile ID (yet). Emits an op error and fails when a tile ID is present
/// but its type is not a 32-bit signless integer.
LogicalResult verifyOperationHasValidTileId(Operation *op) {
  // Ops that are not tile ops have nothing to check.
  if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
    auto tileId = tileOp.getTileId();
    // A missing tile ID is acceptable; only a present one is type-checked.
    if (tileId && !tileId.getType().isSignlessInteger(32))
      return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
  }
  return success();
}
/// Creates an `scf.for` loop that iterates over the slices of a tile.
///
/// The loop runs from 0 up to `dim0 * vscale` in steps of 1, where `dim0` is
/// the size of dimension 0 of `initTile`'s vector type. `initTile` is the
/// loop's single iteration argument. `makeLoopBody` is called once, with the
/// insertion point at the start of the loop body, and receives the builder,
/// `loc`, the induction variable (tile slice index) and the loop-carried tile;
/// the value it returns is yielded as the next tile.
///
/// Returns the created loop. An insertion guard is held for the duration of
/// the call, so the insertion point moved into the loop body here is not left
/// in place for the caller.
scf::ForOp createLoopOverTileSlices(
    PatternRewriter &rewriter, Location loc, Value initTile,
    std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
  OpBuilder::InsertionGuard guard(rewriter);

  // NOTE: the ops below are created in this exact order on purpose, since it
  // determines the order in which they appear in the IR.
  auto loopStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  auto tileVectorType = llvm::cast<VectorType>(initTile.getType());
  auto minSliceCount = rewriter.create<arith::ConstantIndexOp>(
      loc, tileVectorType.getDimSize(0));
  auto vectorScale =
      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
  auto loopStart = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  // The runtime slice count is the minimum count scaled by vscale.
  auto sliceCount =
      rewriter.create<arith::MulIOp>(loc, minSliceCount, vectorScale);

  auto loop = rewriter.create<scf::ForOp>(loc, loopStart, sliceCount, loopStep,
                                          ValueRange{initTile});

  // Populate the loop body and yield the tile produced by the callback.
  rewriter.setInsertionPointToStart(loop.getBody());
  Value updatedTile =
      makeLoopBody(rewriter, loc, /*tileSliceIndex=*/loop.getInductionVar(),
                   /*currentTile=*/loop.getRegionIterArg(0));
  rewriter.create<scf::YieldOp>(loc, updatedTile);
  return loop;
}
/// Returns true if `vType` is larger than a single SME tile and made up of a
/// whole number of SME tiles.
///
/// The type must be a rank-2 vector with all dimensions scalable and a valid
/// SME tile element type. With `N` being the result of
/// `getSMETileSliceMinNumElts` for the element type, at least one dimension
/// must exceed `N` and both dimensions must be multiples of `N`.
bool isMultipleOfSMETileVectorType(VectorType vType) {
  const bool isScalable2D = vType.getRank() == 2 && vType.allDimsScalable();
  if (!isScalable2D)
    return false;

  auto elemType = vType.getElementType();
  if (!isValidSMETileElementType(elemType))
    return false;

  const unsigned tileDim = getSMETileSliceMinNumElts(elemType);
  const int64_t numRows = vType.getDimSize(0);
  const int64_t numCols = vType.getDimSize(1);

  // A type that is exactly one tile in size does not count.
  const bool exceedsOneTile = numRows > tileDim || numCols > tileDim;
  return exceedsOneTile && numRows % tileDim == 0 && numCols % tileDim == 0;
}
/// Builds the SME tile vector type for `elementType`.
///
/// The result is a 2-D vector of shape `[N]x[N]` with both dimensions marked
/// scalable, where `N` is the result of `getSMETileSliceMinNumElts` for
/// `elementType`.
VectorType getSMETileTypeForElement(Type elementType) {
  const unsigned tileDim = getSMETileSliceMinNumElts(elementType);
  const int64_t tileShape[] = {tileDim, tileDim};
  const bool scalableDims[] = {true, true};
  return VectorType::get(tileShape, elementType, scalableDims);
}
} // namespace mlir::arm_sme