[mlir][Vector] Introduce 'vector.load' and 'vector.store' ops

This patch adds the 'vector.load' and 'vector.store' ops to the Vector
dialect [1]. These operations model *contiguous* vector loads and stores
from/to memory. Their semantics are similar to the 'affine.vector_load' and
'affine.vector_store' counterparts but without the affine constraints. The
most relevant feature is that these new vector operations may perform a vector
load/store on memrefs with a non-vector element type, unlike 'std.load' and
'std.store' ops. This opens the representation to more generic vector
load/store scenarios: unaligned vector loads/stores, mixing scalar and vector
memory accesses on the same memref, decoupling memory allocation constraints
from memory accesses, etc. [1]. These operations will also facilitate the
progressive lowering of both Affine vector loads/stores and Vector transfer
reads/writes, for the cases that read/write contiguous slices from/to memory.
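
For illustration, a contiguous vector load and store on a memref with a scalar
element type could be written as follows (a hypothetical example, not taken
from the patch's tests; names and shapes are made up):

  // Hypothetical example: 8 contiguous f32 elements loaded from and stored
  // back to a memref whose element type is a plain scalar.
  func @copy_chunk(%mem : memref<64xf32>, %i : index) {
    %v = vector.load %mem[%i] : memref<64xf32>, vector<8xf32>
    vector.store %v, %mem[%i] : memref<64xf32>, vector<8xf32>
    return
  }

The same %mem can still be accessed with scalar 'std.load'/'std.store' ops,
which would not be possible if the vector shape were encoded in the memref
element type (e.g., memref<8xvector<8xf32>>).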

More specifically, this patch defines the 'vector.load' and 'vector.store' ops
in the Vector dialect, implements their lowering to the LLVM dialect, and
changes the lowering of the 'affine.vector_load' and 'affine.vector_store' ops
to target the new vector ops. The lowering of Vector transfer reads/writes to
these new ops will be implemented in the future, probably as an independent
pass. The API of 'vector.maskedload' and 'vector.maskedstore' has also been
changed slightly to align it with the transfer read/write ops and the new
vector ops. This improves reusability
among all these operations. For example, the lowering of 'vector.load',
'vector.store', 'vector.maskedload' and 'vector.maskedstore' to the LLVM dialect
is implemented with a single template conversion pattern.
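
As a rough sketch of the affine-side change (a hypothetical example with
made-up names and shapes, not taken from the patch's tests), an access such as

  // Hypothetical input to the affine lowering (-lower-affine).
  %v = affine.vector_load %mem[%i + 4] : memref<128xf32>, vector<8xf32>

is now lowered to the new op, roughly as

  // Affine map expansion followed by the new contiguous vector load.
  %c4 = constant 4 : index
  %idx = addi %i, %c4 : index
  %v = vector.load %mem[%idx] : memref<128xf32>, vector<8xf32>

Stores follow the same pattern with 'vector.store'.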

[1] https://llvm.discourse.group/t/memref-type-and-data-layout/

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96185
Author: Diego Caballero
Date:   2021-02-12 19:41:46 +02:00
Commit: ee66e43a96 (parent 98754e2909)

8 changed files with 414 additions and 116 deletions

@@ -357,64 +357,72 @@ public:
   }
 };
 
-/// Conversion pattern for a vector.maskedload.
-class VectorMaskedLoadOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
+/// Overloaded utility that replaces a vector.load, vector.store,
+/// vector.maskedload and vector.maskedstore with their respective LLVM
+/// counterparts.
+static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
+                                 vector::LoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
+}
+
+static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
+                                 vector::MaskedLoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
+}
+
+static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
+                                 vector::StoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
+                                             ptr, align);
+}
+
+static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
+                                 vector::MaskedStoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
+}
+
+/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
+/// vector.maskedstore.
+template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
+class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
-  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
+  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = load->getLoc();
-    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
-    MemRefType memRefType = load.getMemRefType();
+    // Only 1-D vectors can be lowered to LLVM.
+    VectorType vectorTy = loadOrStoreOp.getVectorType();
+    if (vectorTy.getRank() > 1)
+      return failure();
+
+    auto loc = loadOrStoreOp->getLoc();
+    auto adaptor = LoadOrStoreOpAdaptor(operands);
+    MemRefType memRefTy = loadOrStoreOp.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
       return failure();
 
     // Resolve address.
-    auto vtype = typeConverter->convertType(load.getResultVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
+                     .template cast<VectorType>();
+    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                                adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
 
-    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
-        rewriter.getI32IntegerAttr(align));
-    return success();
-  }
-};
-
-/// Conversion pattern for a vector.maskedstore.
-class VectorMaskedStoreOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = store->getLoc();
-    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
-    MemRefType memRefType = store.getMemRefType();
-
-    // Resolve alignment.
-    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
-
-    // Resolve address.
-    auto vtype = typeConverter->convertType(store.getValueVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                               adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
-
-    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-        store, adaptor.value(), ptr, adaptor.mask(),
-        rewriter.getI32IntegerAttr(align));
+    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
     return success();
   }
 };
@@ -1511,8 +1519,14 @@ void mlir::populateVectorToLLVMConversionPatterns(
                   VectorInsertOpConversion,
                   VectorPrintOpConversion,
                   VectorTypeCastOpConversion,
-                  VectorMaskedLoadOpConversion,
-                  VectorMaskedStoreOpConversion,
+                  VectorLoadStoreConversion<vector::LoadOp,
+                                            vector::LoadOpAdaptor>,
+                  VectorLoadStoreConversion<vector::MaskedLoadOp,
+                                            vector::MaskedLoadOpAdaptor>,
+                  VectorLoadStoreConversion<vector::StoreOp,
+                                            vector::StoreOpAdaptor>,
+                  VectorLoadStoreConversion<vector::MaskedStoreOp,
+                                            vector::MaskedStoreOpAdaptor>,
                   VectorGatherOpConversion,
                   VectorScatterOpConversion,
                   VectorExpandLoadOpConversion,