[mlir][Vector] Introduce 'vector.load' and 'vector.store' ops
This patch adds the 'vector.load' and 'vector.store' ops to the Vector dialect [1]. These operations model *contiguous* vector loads and stores from/to memory. Their semantics are similar to the 'affine.vector_load' and 'affine.vector_store' counterparts but without the affine constraints. The most relevant feature is that these new vector operations may perform a vector load/store on memrefs with a non-vector element type, unlike 'std.load' and 'std.store' ops. This opens the representation to model more generic vector load/store scenarios: unaligned vector loads/stores, perform scalar and vector memory access on the same memref, decouple memory allocation constraints from memory accesses, etc [1]. These operations will also facilitate the progressive lowering of both Affine vector loads/stores and Vector transfer reads/writes for those that read/write contiguous slices from/to memory. In particular, this patch adds the 'vector.load' and 'vector.store' ops to the Vector dialect, implements their lowering to the LLVM dialect, and changes the lowering of 'affine.vector_load' and 'affine.vector_store' ops to the new vector ops. The lowering of Vector transfer reads/writes will be implemented in the future, probably as an independent pass. The API of 'vector.maskedload' and 'vector.maskedstore' has also been changed slightly to align it with the transfer read/write ops and the vector new ops. This will improve reusability among all these operations. For example, the lowering of 'vector.load', 'vector.store', 'vector.maskedload' and 'vector.maskedstore' to the LLVM dialect is implemented with a single template conversion pattern. [1] https://llvm.discourse.group/t/memref-type-and-data-layout/ Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D96185
This commit is contained in:
@@ -357,64 +357,72 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.maskedload.
|
||||
class VectorMaskedLoadOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
|
||||
/// Overloaded utility that replaces a vector.load, vector.store,
|
||||
/// vector.maskedload and vector.maskedstore with their respective LLVM
|
||||
/// couterparts.
|
||||
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
|
||||
vector::LoadOpAdaptor adaptor,
|
||||
VectorType vectorTy, Value ptr, unsigned align,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
|
||||
}
|
||||
|
||||
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
|
||||
vector::MaskedLoadOpAdaptor adaptor,
|
||||
VectorType vectorTy, Value ptr, unsigned align,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
|
||||
loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
|
||||
}
|
||||
|
||||
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
|
||||
vector::StoreOpAdaptor adaptor,
|
||||
VectorType vectorTy, Value ptr, unsigned align,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
|
||||
ptr, align);
|
||||
}
|
||||
|
||||
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
|
||||
vector::MaskedStoreOpAdaptor adaptor,
|
||||
VectorType vectorTy, Value ptr, unsigned align,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
|
||||
storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
|
||||
}
|
||||
|
||||
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
|
||||
/// vector.maskedstore.
|
||||
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
|
||||
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
|
||||
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
|
||||
matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = load->getLoc();
|
||||
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
|
||||
MemRefType memRefType = load.getMemRefType();
|
||||
// Only 1-D vectors can be lowered to LLVM.
|
||||
VectorType vectorTy = loadOrStoreOp.getVectorType();
|
||||
if (vectorTy.getRank() > 1)
|
||||
return failure();
|
||||
|
||||
auto loc = loadOrStoreOp->getLoc();
|
||||
auto adaptor = LoadOrStoreOpAdaptor(operands);
|
||||
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
|
||||
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
|
||||
return failure();
|
||||
|
||||
// Resolve address.
|
||||
auto vtype = typeConverter->convertType(load.getResultVectorType());
|
||||
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
|
||||
auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
|
||||
.template cast<VectorType>();
|
||||
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
|
||||
adaptor.indices(), rewriter);
|
||||
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
|
||||
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
|
||||
load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
|
||||
rewriter.getI32IntegerAttr(align));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.maskedstore.
|
||||
class VectorMaskedStoreOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = store->getLoc();
|
||||
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
|
||||
MemRefType memRefType = store.getMemRefType();
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
|
||||
return failure();
|
||||
|
||||
// Resolve address.
|
||||
auto vtype = typeConverter->convertType(store.getValueVectorType());
|
||||
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
|
||||
adaptor.indices(), rewriter);
|
||||
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
|
||||
store, adaptor.value(), ptr, adaptor.mask(),
|
||||
rewriter.getI32IntegerAttr(align));
|
||||
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1511,8 +1519,14 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorInsertOpConversion,
|
||||
VectorPrintOpConversion,
|
||||
VectorTypeCastOpConversion,
|
||||
VectorMaskedLoadOpConversion,
|
||||
VectorMaskedStoreOpConversion,
|
||||
VectorLoadStoreConversion<vector::LoadOp,
|
||||
vector::LoadOpAdaptor>,
|
||||
VectorLoadStoreConversion<vector::MaskedLoadOp,
|
||||
vector::MaskedLoadOpAdaptor>,
|
||||
VectorLoadStoreConversion<vector::StoreOp,
|
||||
vector::StoreOpAdaptor>,
|
||||
VectorLoadStoreConversion<vector::MaskedStoreOp,
|
||||
vector::MaskedStoreOpAdaptor>,
|
||||
VectorGatherOpConversion,
|
||||
VectorScatterOpConversion,
|
||||
VectorExpandLoadOpConversion,
|
||||
|
||||
Reference in New Issue
Block a user