The 1:N type converter derives from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter, in preparation for merging the 1:1 and 1:N dialect conversion infrastructures. 1:1 target materializations (producing a single `Value`) remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a `SmallVector<Value>`). Internally, all target materializations are stored as 1:N materializations. The 1:N type converter is removed.

Note for LLVM integration: If you are using the `OneToNTypeConverter`, simply switch all occurrences to `TypeConverter`.

---------

Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
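For illustration, a minimal sketch of registering a 1:N target materialization on the merged `TypeConverter`. The callback shape shown is an assumption based on the description above; the exact parameter list may differ:

```cpp
TypeConverter converter;
// Assumed 1:N form: the callback receives all requested result types and
// returns one SSA value per result type. (Sketch only, not the exact API.)
converter.addTargetMaterialization(
    [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
       Location loc) -> SmallVector<Value> {
      // Materialize the N target values with a single unrealized cast that
      // forwards the source values.
      auto cast =
          builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
      return SmallVector<Value>(cast.getResults().begin(),
                                cast.getResults().end());
    });
```

The existing 1:1 registration (a callback producing a single `Value`) remains valid alongside this form.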
//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (`indices` and
/// `scalableOffsets` must be of equal length). For example, in the 2D case
/// this would return:
/// { indices[0] + scalableOffsets[0] * vscale,
///   indices[1] + scalableOffsets[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = builder.create<vector::VectorScaleOp>(loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = builder.create<arith::MulIOp>(
            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
        return builder.create<arith::AddIOp>(loc, index, offset);
      });
}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
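/// For example, extracting the mask for sub-tile (4, 4) from
/// `vector.create_mask %a, %b` yields
/// `vector.create_mask (%a - 4 * vscale), (%b - 4 * vscale)`, since the
/// create_mask operands are the bounds of the true region.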
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract where this tile starts
  // from the mask operands to get the parameters for this sub-tile.
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
      loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
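/// For example, a vector<[8]x[8]xf32> (with [4]x[4] f32 SME tiles, as f32
/// tile slices hold a minimum of 4 elements) contains
/// (8 * 8) / (4 * 4) = 4 tiles.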
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
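///
/// For example (conceptually, ignoring the 1:N result mapping), a splat
/// `arith.constant dense<1> : vector<[8]x[8]xi32>` becomes a single tile-sized
/// splat `arith.constant dense<1> : vector<[4]x[4]xi32>`, used four times in
/// place of the original result.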
struct LegalizeArithConstantOpsByDecomposition
    : public OneToNOpConversionPattern<arith::ConstantOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
                       adaptor.getResultMapping());

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
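///
/// For example, the sub-tile at (row, col) of an [8]x[8] outer product is the
/// outer product of the [4]-element slice of the LHS starting at `row` and the
/// [4]-element slice of the RHS starting at `col` (each extracted with
/// vector.scalable.extract), accumulating into the matching tile of the acc.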
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::OuterProductOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
          loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::MaskOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    if (auto outerProductOp =
            llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
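///
/// For example (conceptually), a read producing vector<[8]x[8]xf32> becomes
/// four reads producing vector<[4]x[4]xf32>, each with its indices offset to
/// the start of the corresponding sub-tile (see getSMESubTileIndices).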
struct LegalizeTransferReadOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferReadOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = rewriter.create<vector::TransferReadOp>(
          loc, smeTileType, readOp.getSource(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
    return success();
  }
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getVector();

    Value destTensorOrMemref = writeOp.getSource();
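    // Rewrite each tile write individually. With tensor semantics the writes
    // must be chained: each write consumes the tensor produced by the
    // previous one.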
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition because at this level we know each tile write
/// is disjoint, but that information is lost after decomposition (without
/// analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]           ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                     |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]           |- Store
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                    |  upper
///   vector.transfer_write %upper_slice,                             |  tile
///     %dest[%slice_idx + %y, %x], %upper_slice_mask                 |
///     : vector<[8]xi16>, memref<?x?xi16>                           ─┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                     ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]      |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                     |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]           |- Store
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                    |  lower
///   vector.transfer_write %lower_slice,                             |  tile
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask           |
///     : vector<[8]xi16>, memref<?x?xi16>                           ─┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto storeLoop =
        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getVector();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
                                                     writeOp.getIndices()[0]);
      auto storeCol =
          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = rewriter.create<vector::ExtractOp>(
            loc, mask, OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = rewriter.create<vector::ScalableExtractOp>(
              loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
      rewriter.create<vector::TransferWriteOp>(
          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
///  %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
///  ```
///
///  AFTER:
///  ```mlir
///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
///  ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
        frontMaskDim);
    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// A vector type where no fixed dimension comes after a scalable dimension.
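/// For example, vector<2x[4]xf32> is legal, but vector<[4]x2xf32> is not (a
/// fixed trailing dimension follows the scalable one).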
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %illegalRead = vector.transfer_read %memref[%a, %b]
///                  : memref<?x?xf32>, vector<[8]x4xf32>
///  %legalType = vector.transpose %illegalRead, [1, 0]
///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %legalType = vector.transfer_read %transpose[%c0, %c0]
///                  : memref<?x?xf32>, vector<4x[8]xf32>
///  ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

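  /// Returns the source of `op` if it is an arith extension op
  /// (extsi/extui/extf); otherwise returns nothing. Used to look through
  /// extends between the transpose and the transfer_read.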
  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through extend for transfer_read.
    Value maybeRead = transposeOp.getVector();
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
        });
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview = rewriter.create<memref::SubViewOp>(
        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
        strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = rewriter.create<vector::TransposeOp>(loc, mask,
                                                  transposeOp.getPermutation());
    }
    // - The source memref
    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = rewriter.create<memref::TransposeOp>(
        loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = rewriter.create<vector::TransferReadOp>(
        loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};

/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to
/// a legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is if we've got to this pass and we still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = shapeCastOp.getSourceVectorType();
    auto resultType = shapeCastOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(shapeCastOp,
                                         kMatchFailureNotIllegalToLegal);

    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
    // then dim 0 is scalable and dim 1 is fixed.
    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
      return rewriter.notifyMatchFailure(
          shapeCastOp, "expected source to be a 2D scalable vector with a "
                       "trailing unit dim");

    auto loc = shapeCastOp.getLoc();
    auto transpose = rewriter.create<vector::TransposeOp>(
        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

    if (resultType.getRank() == 1)
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
                                                       transpose);
    else
      rewriter.replaceOp(shapeCastOp, transpose);

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This workaround allows supporting these transposes when
/// ZA is available.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %transpose = vector.transpose %vec, [1, 0]
///     : vector<2x[4]xf32> to vector<[4]x2xf32>
///  vector.transfer_write %transpose, %dest[%y, %x]
///     : vector<[4]x2xf32>, memref<?x?xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
///  %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
///  %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
///  %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
///  %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
///  %c4_vscale = arith.muli %vscale, %c4 : index
///  %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
///  vector.transfer_write %4, %dest[%y, %x], %mask
///    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///    : vector<[4]x[4]xf32>, memref<?x?xf32>
///  ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
    Value destTensorOrMemref = writeOp.getSource();
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of
      // the rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector = rewriter.create<vector::ExtractOp>(
            loc, transposeOp.getVector(),
            rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          vector = rewriter.create<vector::ScalableExtractOp>(
              loc, smeSliceType, vector, smeTile.col);
        }
        tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
      }

      // 2. Transpose the tile position.
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);

      // 3. Compute mask for tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
                                                  transposedRow);
        maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
                                                  transposedCol);
        maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
      } else {
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = rewriter.create<vector::CreateMaskOp>(
          loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
      Value destCol =
          rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
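    // 1:N conversion: a vector type that is a multiple of an SME tile size
    // converts to `smeTileCount` copies of the SME tile type.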
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });

    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                 LiftIllegalVectorTransposeToMemory,
                 ConvertIllegalShapeCastOpsToTransposes,
                 LowerIllegalTransposeStoreViaZA>(context);
    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFuncTypeConversionPatterns(converter, patterns);
    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);

    if (failed(applyPartialOneToNConversion(getOperation(), converter,
                                            std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}