[mlir][vector] add higher dimensional support to gather/scatter
Similar to mask-load/store and compress/expand, the gather and scatter operations now allow for higher-dimensional uses. Note that to support the mixed-type index, the new syntax is: vector.gather %base [%i,%j] [%kvector] .... The first client of this generalization is the sparse compiler, which needs to define gathers and scatters on dense operands of higher dimensions too.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D97422
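For illustration, a minimal sketch of the generalized syntax on a 2-D memref (the SSA names, shapes, and the trailing type list below are illustrative assumptions; only the %base [%i,%j] [%kvector] bracket structure is what this change introduces):

    %g = vector.gather %base[%i, %j] [%kvector], %mask, %pass_thru
        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

    vector.scatter %base[%i, %j] [%kvector], %mask, %value
        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>

Roughly, the scalar indices %i, %j select the starting position in the dense memref, while %kvector supplies the per-lane offsets along the innermost (unit-stride) dimension.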
@@ -178,34 +178,21 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Helper that returns the base address of a memref.
-static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
-                             Value memref, MemRefType memRefType, Value &base) {
-  // Inspect stride and offset structure.
-  //
-  // TODO: flat memory only for now, generalize
-  //
+// Add an index vector component to a base pointer. This almost always succeeds
+// unless the last stride is non-unit or the memory space is not zero.
+static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
+                                    Location loc, Value memref, Value base,
+                                    Value index, MemRefType memRefType,
+                                    VectorType vType, Value &ptrs) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
-  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
-      offset != 0 || memRefType.getMemorySpace() != 0)
-    return failure();
-  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
-  return success();
-}
-
-// Helper that returns vector of pointers given a memref base with index vector.
-static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                                    Location loc, Value memref, Value indices,
-                                    MemRefType memRefType, VectorType vType,
-                                    Type iType, Value &ptrs) {
-  Value base;
-  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+  if (failed(successStrides) || strides.back() != 1 ||
+      memRefType.getMemorySpace() != 0)
     return failure();
   auto pType = MemRefDescriptor(memref).getElementPtrType();
   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
-  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
+  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
   return success();
 }
 
@@ -435,19 +422,20 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = gather->getLoc();
     auto adaptor = vector::GatherOpAdaptor(operands);
+    MemRefType memRefType = gather.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
-                                  align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
-    // Get index ptrs.
-    VectorType vType = gather.getVectorType();
-    Type iType = gather.getIndicesVectorType().getElementType();
+    // Resolve address.
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
-                              gather.getMemRefType(), vType, iType, ptrs)))
+    VectorType vType = gather.getVectorType();
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
+                              adaptor.index_vec(), memRefType, vType, ptrs)))
       return failure();
 
     // Replace with the gather intrinsic.
@@ -469,19 +457,20 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);
+    MemRefType memRefType = scatter.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
-                                  align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
-    // Get index ptrs.
-    VectorType vType = scatter.getVectorType();
-    Type iType = scatter.getIndicesVectorType().getElementType();
+    // Resolve address.
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
-                              scatter.getMemRefType(), vType, iType, ptrs)))
+    VectorType vType = scatter.getVectorType();
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
+                              adaptor.index_vec(), memRefType, vType, ptrs)))
       return failure();
 
     // Replace with the scatter intrinsic.
@@ -507,8 +496,8 @@ public:
 
     // Resolve address.
     auto vtype = typeConverter->convertType(expand.getVectorType());
-    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                           adaptor.indices(), rewriter);
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
 
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
@@ -530,8 +519,8 @@ public:
     MemRefType memRefType = compress.getMemRefType();
 
     // Resolve address.
-    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                           adaptor.indices(), rewriter);
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
 
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
         compress, adaptor.valueToStore(), ptr, adaptor.mask());