[mlir][vector] add higher dimensional support to gather/scatter

Similar to mask-load/store and compress/expand, the gather and
scatter operation now allow for higher dimension uses. Note that
to support the mixed-type index, the new syntax is:
   vector.gather %base [%i,%j] [%kvector] ....
The first client of this generalization is the sparse compiler,
which needs to define scatter and gathers on dense operands
of higher dimensions too.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D97422
This commit is contained in:
Aart Bik
2021-02-25 18:04:39 -08:00
parent c62dabc3f5
commit df5ccf5a94
14 changed files with 244 additions and 115 deletions

View File

@@ -178,34 +178,21 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
return success();
}
// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
Value memref, MemRefType memRefType, Value &base) {
// Inspect stride and offset structure.
//
// TODO: flat memory only for now, generalize
//
// Add an index vector component to a base pointer. This almost always succeeds
// unless the last stride is non-unit or the memory space is not zero.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
Location loc, Value memref, Value base,
Value index, MemRefType memRefType,
VectorType vType, Value &ptrs) {
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
offset != 0 || memRefType.getMemorySpace() != 0)
return failure();
base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
return success();
}
// Helper that returns vector of pointers given a memref base with index vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
Location loc, Value memref, Value indices,
MemRefType memRefType, VectorType vType,
Type iType, Value &ptrs) {
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
if (failed(successStrides) || strides.back() != 1 ||
memRefType.getMemorySpace() != 0)
return failure();
auto pType = MemRefDescriptor(memref).getElementPtrType();
auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
return success();
}
@@ -435,19 +422,20 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto loc = gather->getLoc();
auto adaptor = vector::GatherOpAdaptor(operands);
MemRefType memRefType = gather.getMemRefType();
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
align)))
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Get index ptrs.
VectorType vType = gather.getVectorType();
Type iType = gather.getIndicesVectorType().getElementType();
// Resolve address.
Value ptrs;
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
gather.getMemRefType(), vType, iType, ptrs)))
VectorType vType = gather.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
adaptor.index_vec(), memRefType, vType, ptrs)))
return failure();
// Replace with the gather intrinsic.
@@ -469,19 +457,20 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
auto adaptor = vector::ScatterOpAdaptor(operands);
MemRefType memRefType = scatter.getMemRefType();
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
align)))
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Get index ptrs.
VectorType vType = scatter.getVectorType();
Type iType = scatter.getIndicesVectorType().getElementType();
// Resolve address.
Value ptrs;
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
scatter.getMemRefType(), vType, iType, ptrs)))
VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
adaptor.index_vec(), memRefType, vType, ptrs)))
return failure();
// Replace with the scatter intrinsic.
@@ -507,8 +496,8 @@ public:
// Resolve address.
auto vtype = typeConverter->convertType(expand.getVectorType());
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
@@ -530,8 +519,8 @@ public:
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
compress, adaptor.valueToStore(), ptr, adaptor.mask());