[mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (#137389)
In ConvertVectorToLLVM, the only option for setting alignment of `vector.gather`, `vector.scatter`, and the `vector.load/store` ops was to extract it from the datatype of the memref type. However, this is insufficient for hardware backends requiring alignment of vector types. This PR introduces the `use-vector-alignment` option to the `ConvertVectorToLLVMPass`, which makes the pass use the alignment of the vector type of these operations instead of the alignment of the memref type. --------- Co-authored-by: Lily Orth-Smith <lorthsmith@microsoft.com>
This commit is contained in:
@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
|
||||
}
|
||||
|
||||
// Helper that returns data layout alignment of a vector.
|
||||
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
|
||||
VectorType vectorType, unsigned &align) {
|
||||
Type convertedVectorTy = typeConverter.convertType(vectorType);
|
||||
if (!convertedVectorTy)
|
||||
return failure();
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
align = LLVM::TypeToLLVMIRTranslator(llvmContext)
|
||||
.getPreferredAlignment(convertedVectorTy,
|
||||
typeConverter.getDataLayout());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Helper that returns data layout alignment of a memref.
|
||||
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
|
||||
MemRefType memrefType, unsigned &align) {
|
||||
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Helper to resolve the alignment for vector load/store, gather and scatter
|
||||
// ops. If useVectorAlignment is true, get the preferred alignment for the
|
||||
// vector type in the operation. This option is used for hardware backends with
|
||||
// vectorization. Otherwise, use the preferred alignment of the element type of
|
||||
// the memref. Note that if you choose to use vector alignment, the shape of the
|
||||
// vector type must be resolved before the ConvertVectorToLLVM pass is run.
|
||||
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
|
||||
VectorType vectorType,
|
||||
MemRefType memrefType, unsigned &align,
|
||||
bool useVectorAlignment) {
|
||||
if (useVectorAlignment) {
|
||||
if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Check if the last stride is non-unit and has a valid memory space.
|
||||
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
|
||||
const LLVMTypeConverter &converter) {
|
||||
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
|
||||
template <class LoadOrStoreOp>
|
||||
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
|
||||
public:
|
||||
explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
|
||||
bool useVectorAlign)
|
||||
: ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
|
||||
useVectorAlignment(useVectorAlign) {}
|
||||
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
@@ -240,8 +281,10 @@ public:
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
|
||||
return failure();
|
||||
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
|
||||
memRefTy, align, useVectorAlignment)))
|
||||
return rewriter.notifyMatchFailure(loadOrStoreOp,
|
||||
"could not resolve alignment");
|
||||
|
||||
// Resolve address.
|
||||
auto vtype = cast<VectorType>(
|
||||
@@ -252,12 +295,23 @@ public:
|
||||
rewriter);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// If true, use the preferred alignment of the vector type.
|
||||
// If false, use the preferred alignment of the element type
|
||||
// of the memref. This flag is intended for use with hardware
|
||||
// backends that require alignment of vector operations.
|
||||
const bool useVectorAlignment;
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.gather.
|
||||
class VectorGatherOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::GatherOp> {
|
||||
public:
|
||||
explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
|
||||
bool useVectorAlign)
|
||||
: ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
|
||||
useVectorAlignment(useVectorAlign) {}
|
||||
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
@@ -278,10 +332,9 @@ public:
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
|
||||
return rewriter.notifyMatchFailure(gather,
|
||||
"could not resolve memref alignment");
|
||||
}
|
||||
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
|
||||
memRefType, align, useVectorAlignment)))
|
||||
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
|
||||
|
||||
// Resolve address.
|
||||
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
|
||||
@@ -297,12 +350,24 @@ public:
|
||||
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// If true, use the preferred alignment of the vector type.
|
||||
// If false, use the preferred alignment of the element type
|
||||
// of the memref. This flag is intended for use with hardware
|
||||
// backends that require alignment of vector operations.
|
||||
const bool useVectorAlignment;
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.scatter.
|
||||
class VectorScatterOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
|
||||
public:
|
||||
explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
|
||||
bool useVectorAlign)
|
||||
: ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
|
||||
useVectorAlignment(useVectorAlign) {}
|
||||
|
||||
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
@@ -322,10 +387,10 @@ public:
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
|
||||
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
|
||||
memRefType, align, useVectorAlignment)))
|
||||
return rewriter.notifyMatchFailure(scatter,
|
||||
"could not resolve memref alignment");
|
||||
}
|
||||
"could not resolve alignment");
|
||||
|
||||
// Resolve address.
|
||||
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
|
||||
@@ -340,6 +405,13 @@ public:
|
||||
rewriter.getI32IntegerAttr(align));
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// If true, use the preferred alignment of the vector type.
|
||||
// If false, use the preferred alignment of the element type
|
||||
// of the memref. This flag is intended for use with hardware
|
||||
// backends that require alignment of vector operations.
|
||||
const bool useVectorAlignment;
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.expandload.
|
||||
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
bool reassociateFPReductions, bool force32BitVectorIndices) {
|
||||
bool reassociateFPReductions, bool force32BitVectorIndices,
|
||||
bool useVectorAlignment) {
|
||||
// This function populates only ConversionPatterns, not RewritePatterns.
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
|
||||
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
|
||||
patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
|
||||
VectorLoadStoreConversion<vector::MaskedLoadOp>,
|
||||
VectorLoadStoreConversion<vector::StoreOp>,
|
||||
VectorLoadStoreConversion<vector::MaskedStoreOp>,
|
||||
VectorGatherOpConversion, VectorScatterOpConversion>(
|
||||
converter, useVectorAlignment);
|
||||
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
|
||||
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
||||
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
|
||||
VectorInsertOpConversion, VectorPrintOpConversion,
|
||||
VectorTypeCastOpConversion, VectorScaleOpConversion,
|
||||
VectorLoadStoreConversion<vector::LoadOp>,
|
||||
VectorLoadStoreConversion<vector::MaskedLoadOp>,
|
||||
VectorLoadStoreConversion<vector::StoreOp>,
|
||||
VectorLoadStoreConversion<vector::MaskedStoreOp>,
|
||||
VectorGatherOpConversion, VectorScatterOpConversion,
|
||||
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
|
||||
VectorSplatOpLowering, VectorSplatNdOpLowering,
|
||||
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
|
||||
|
||||
Reference in New Issue
Block a user