This adds a new pass (`-arm-sme-vector-legalization`) which legalizes vector operations so that they can be lowered to ArmSME. This initial patch adds decomposition for `vector.outerproduct`, `vector.transfer_read`, and `vector.transfer_write` when they operate on vector types larger than a single SME tile. For example, a [8]x[8]xf32 outer product would be decomposed into four [4]x[4]xf32 outer products, which could then be lowered to ArmSME. These three ops were chosen because supporting them alone is enough to lower matmuls that use all ZA accumulators to ArmSME. A vector type can only be legalized if each of its dimensions is a multiple of the SME tile size; other than that, any shape can be used. E.g. `vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>` can all be lowered to four `vector<[4]x[4]xf32>` operations. In future, this pass will be extended with more SME-specific rewrites to legalize unrolling the reduction dimension of matmuls (which is not type-decomposition), which is why the pass has quite a general name.
120 lines
4.1 KiB
C++
120 lines
4.1 KiB
C++
//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements utilities for the ArmSME dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
|
|
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
|
|
|
|
namespace mlir::arm_sme {
|
|
|
|
/// Returns `MinStreamingVectorLengthInBits` divided by the bit width of
/// `type`, i.e. the minimum number of `type` elements in one tile slice.
/// Asserts that `type` is a valid SME tile element type.
unsigned getSMETileSliceMinNumElts(Type type) {
  assert(isValidSMETileElementType(type) && "invalid tile type!");
  const auto elementBitWidth = type.getIntOrFloatBitWidth();
  return MinStreamingVectorLengthInBits / elementBitWidth;
}
|
|
|
|
/// Returns true if `type` is accepted as an SME tile element type: an integer
/// of width 8, 16, 32, 64 or 128, or one of f16, bf16, f32, f64, f128.
bool isValidSMETileElementType(Type type) {
  // Integer element types.
  const bool isSupportedInteger = type.isInteger(8) || type.isInteger(16) ||
                                  type.isInteger(32) || type.isInteger(64) ||
                                  type.isInteger(128);
  if (isSupportedInteger)
    return true;

  // Floating-point element types.
  return type.isF16() || type.isBF16() || type.isF32() || type.isF64() ||
         type.isF128();
}
|
|
|
|
/// Returns true if `vType` is exactly one SME tile: a rank-2 vector with all
/// dimensions scalable, a valid SME tile element type, and both dimensions
/// equal to `getSMETileSliceMinNumElts` of that element type.
bool isValidSMETileVectorType(VectorType vType) {
  const bool isScalable2D = vType.getRank() == 2 && vType.allDimsScalable();
  if (!isScalable2D)
    return false;

  Type elementType = vType.getElementType();
  if (!isValidSMETileElementType(elementType))
    return false;

  // The rank is known to be 2 here, so checking both dimensions individually
  // is the same as comparing the whole shape.
  const int64_t tileDim = getSMETileSliceMinNumElts(elementType);
  return vType.getDimSize(0) == tileDim && vType.getDimSize(1) == tileDim;
}
|
|
|
|
/// Maps an SME tile vector type to the `ArmSMETileType` selected by its
/// element bit width (8 -> ZAB, 16 -> ZAH, 32 -> ZAS, 64 -> ZAD, 128 -> ZAQ).
/// Returns an empty optional if `type` is not a valid SME tile vector type.
std::optional<ArmSMETileType> getSMETileType(VectorType type) {
  if (!isValidSMETileVectorType(type))
    return std::nullopt;

  const unsigned elementBitWidth = type.getElementTypeBitWidth();
  if (elementBitWidth == 8)
    return ArmSMETileType::ZAB;
  if (elementBitWidth == 16)
    return ArmSMETileType::ZAH;
  if (elementBitWidth == 32)
    return ArmSMETileType::ZAS;
  if (elementBitWidth == 64)
    return ArmSMETileType::ZAD;
  if (elementBitWidth == 128)
    return ArmSMETileType::ZAQ;

  llvm_unreachable("unknown SME tile type");
}
|
|
|
|
/// Verifies the tile ID of `op`. Succeeds if `op` does not implement
/// `ArmSMETileOpInterface`, or implements it without a tile ID set. When a
/// tile ID is present, its type must be a 32-bit signless integer; otherwise
/// an op error is emitted.
LogicalResult verifyOperationHasValidTileId(Operation *op) {
  // Only tile ops can carry a tile ID; everything else passes trivially.
  if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
    // A tile ID that has not been set (yet) is acceptable.
    auto tileId = tileOp.getTileId();
    if (tileId && !tileId.getType().isSignlessInteger(32))
      return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
  }
  return success();
}
|
|
|
|
/// Builds an `scf.for` loop running from 0 to `dim0 * vscale` in steps of 1,
/// where `dim0` is the size of dimension 0 of `initTile`'s vector type. The
/// loop carries one value, initialised to `initTile`.
///
/// `makeLoopBody` is called once with the builder positioned at the start of
/// the loop body, receiving the induction variable (the tile slice index) and
/// the current carried tile; the value it returns is yielded as the next
/// tile. The rewriter's insertion point is restored before returning.
scf::ForOp createLoopOverTileSlices(
    PatternRewriter &rewriter, Location loc, Value initTile,
    std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
  OpBuilder::InsertionGuard insertionGuard(rewriter);

  // Ops are created in a fixed order so that the emitted IR keeps its layout.
  auto stepCst = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  auto tileType = llvm::cast<VectorType>(initTile.getType());
  auto minSliceCountCst =
      rewriter.create<arith::ConstantIndexOp>(loc, tileType.getDimSize(0));
  auto vscaleOp =
      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
  auto zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto sliceCount =
      rewriter.create<arith::MulIOp>(loc, minSliceCountCst, vscaleOp);

  auto sliceLoop = rewriter.create<scf::ForOp>(loc, zeroCst, sliceCount,
                                               stepCst, ValueRange{initTile});

  // Populate the loop body and yield the tile it produces.
  rewriter.setInsertionPointToStart(sliceLoop.getBody());
  Value tileSliceIndex = sliceLoop.getInductionVar();
  Value currentTile = sliceLoop.getRegionIterArg(0);
  Value updatedTile = makeLoopBody(rewriter, loc, tileSliceIndex, currentTile);
  rewriter.create<scf::YieldOp>(loc, updatedTile);

  return sliceLoop;
}
|
|
|
|
/// Returns true if `vType` is a rank-2, fully scalable vector with a valid SME
/// tile element type whose dimensions are both multiples of the tile
/// dimension (`getSMETileSliceMinNumElts` of the element type) and at least
/// one of which is strictly larger than it. A vector of exactly one tile
/// therefore yields false.
bool isMultipleOfSMETileVectorType(VectorType vType) {
  const bool isScalable2D = vType.getRank() == 2 && vType.allDimsScalable();
  if (!isScalable2D)
    return false;

  Type elemType = vType.getElementType();
  if (!isValidSMETileElementType(elemType))
    return false;

  const int64_t tileDim = getSMETileSliceMinNumElts(elemType);
  const int64_t numRows = vType.getDimSize(0);
  const int64_t numCols = vType.getDimSize(1);

  // Reject shapes that do not extend beyond a single tile in any dimension.
  if (numRows <= tileDim && numCols <= tileDim)
    return false;

  return numRows % tileDim == 0 && numCols % tileDim == 0;
}
|
|
|
|
/// Builds the square, rank-2 vector type for `elementType` whose two
/// dimensions both equal `getSMETileSliceMinNumElts(elementType)` and are
/// both passed as `true` in the scalable-dimensions argument.
VectorType getSMETileTypeForElement(Type elementType) {
  const int64_t tileDim = getSMETileSliceMinNumElts(elementType);
  return VectorType::get({tileDim, tileDim}, elementType, {true, true});
}
|
|
|
|
} // namespace mlir::arm_sme
|