[mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (#137389)

In ConvertVectorToLLVM, the only option for setting alignment of
`vector.gather`, `vector.scatter`, and the `vector.load/store` ops was
to extract it from the datatype of the memref type. However, this is
insufficient for hardware backends requiring alignment of vector types.
This PR introduces the `use-vector-alignment` option to the
`ConvertVectorToLLVMPass`, which makes the pass use the alignment of the
vector type of these operations instead of the alignment of the memref
type.

---------

Co-authored-by: Lily Orth-Smith <lorthsmith@microsoft.com>
This commit is contained in:
Lily Orth-Smith
2025-05-02 12:54:26 -07:00
committed by GitHub
parent 1101b76732
commit 3715de976e
7 changed files with 206 additions and 19 deletions

View File

@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}
// Helper that returns data layout alignment of a vector.
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
VectorType vectorType, unsigned &align) {
Type convertedVectorTy = typeConverter.convertType(vectorType);
if (!convertedVectorTy)
return failure();
llvm::LLVMContext llvmContext;
align = LLVM::TypeToLLVMIRTranslator(llvmContext)
.getPreferredAlignment(convertedVectorTy,
typeConverter.getDataLayout());
return success();
}
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
return success();
}
// Helper to resolve the alignment for vector load/store, gather and scatter
// ops. If useVectorAlignment is true, get the preferred alignment for the
// vector type in the operation. This option is used for hardware backends with
// vectorization. Otherwise, use the preferred alignment of the element type of
// the memref. Note that if you choose to use vector alignment, the shape of the
// vector type must be resolved before the ConvertVectorToLLVM pass is run.
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
VectorType vectorType,
MemRefType memrefType, unsigned &align,
bool useVectorAlignment) {
if (useVectorAlignment) {
if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
return failure();
}
} else {
if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
return failure();
}
}
return success();
}
// Check if the last stride is non-unit and has a valid memory space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) {
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
bool useVectorAlign)
: ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -240,8 +281,10 @@ public:
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
return failure();
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
memRefTy, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(loadOrStoreOp,
"could not resolve alignment");
// Resolve address.
auto vtype = cast<VectorType>(
@@ -252,12 +295,23 @@ public:
rewriter);
return success();
}
private:
// If true, use the preferred alignment of the vector type.
// If false, use the preferred alignment of the element type
// of the memref. This flag is intended for use with hardware
// backends that require alignment of vector operations.
const bool useVectorAlignment;
};
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
: public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
bool useVectorAlign)
: ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -278,10 +332,9 @@ public:
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
return rewriter.notifyMatchFailure(gather,
"could not resolve memref alignment");
}
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -297,12 +350,24 @@ public:
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
return success();
}
private:
// If true, use the preferred alignment of the vector type.
// If false, use the preferred alignment of the element type
// of the memref. This flag is intended for use with hardware
// backends that require alignment of vector operations.
const bool useVectorAlignment;
};
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
bool useVectorAlign)
: ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -322,10 +387,10 @@ public:
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(scatter,
"could not resolve memref alignment");
}
"could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -340,6 +405,13 @@ public:
rewriter.getI32IntegerAttr(align));
return success();
}
private:
// If true, use the preferred alignment of the vector type.
// If false, use the preferred alignment of the element type
// of the memref. This flag is intended for use with hardware
// backends that require alignment of vector operations.
const bool useVectorAlignment;
};
/// Conversion pattern for a vector.expandload.
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool force32BitVectorIndices) {
bool reassociateFPReductions, bool force32BitVectorIndices,
bool useVectorAlignment) {
// This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
VectorLoadStoreConversion<vector::MaskedLoadOp>,
VectorLoadStoreConversion<vector::StoreOp>,
VectorLoadStoreConversion<vector::MaskedStoreOp>,
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorLoadStoreConversion<vector::LoadOp>,
VectorLoadStoreConversion<vector::MaskedLoadOp>,
VectorLoadStoreConversion<vector::StoreOp>,
VectorLoadStoreConversion<vector::MaskedStoreOp>,
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,