[MLIR] Extend vector.gather to support n-D result
Currently vector.gather only supports reading memory into a 1-D result vector. This patch extends it to support an n-D result vector, with the indices, masks, and passthroughs given as n-D vectors as well. As we are trying to vectorize tensor.extract with vector.gather (https://github.com/iree-org/iree/issues/9198), it will need to gather the elements into an n-D vector. Having vector.gather with n-D results allows us to avoid flattening and reshaping at the vectorization stage. The backends can then decide the optimal way to lower the vector.gather op. Note that this is different from n-D gathering, which is about reading n-D memory with n-D indices. The indices here are still only 1-D offsets on the base. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D131905
This commit is contained in:
committed by
Diego Caballero
parent
6ca17b58f5
commit
0cbfd6fd16
@@ -91,24 +91,28 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Add an index vector component to a base pointer. This almost always succeeds
|
||||
// unless the last stride is non-unit or the memory space is not zero.
|
||||
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value memref, Value base,
|
||||
Value index, MemRefType memRefType,
|
||||
VectorType vType, Value &ptrs) {
|
||||
// Check if the last stride is non-unit or the memory space is not zero.
|
||||
static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
|
||||
if (failed(successStrides) || strides.back() != 1 ||
|
||||
memRefType.getMemorySpaceAsInt() != 0)
|
||||
return failure();
|
||||
auto pType = MemRefDescriptor(memref).getElementPtrType();
|
||||
auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
|
||||
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Add an index vector component to a base pointer.
//
// Builds a fixed-length vector of `vLen` element pointers by emitting a single
// LLVM GEP that applies `index` to `base`, and returns the GEP result.
//
// `memRefType` is only used to assert that the memref type passes
// isMemRefTypeSupported. `llvmMemref` is wrapped in a MemRefDescriptor solely
// to obtain the element pointer type.
// NOTE(review): callers visible in this diff pass the converted index vector
// as `index`; confirm no other callers exist.
|
||||
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
|
||||
MemRefType memRefType, Value llvmMemref, Value base,
|
||||
Value index, uint64_t vLen) {
|
||||
// This helper does not report failure: an unsupported memref type only
// trips this assertion.
assert(succeeded(isMemRefTypeSupported(memRefType)) &&
|
||||
"unsupported memref type");
|
||||
// Element pointer type taken from the memref descriptor.
auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
|
||||
// Fixed vector type holding `vLen` such pointers.
auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
|
||||
return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
|
||||
}
|
||||
|
||||
// Casts a strided element pointer to a vector pointer. The vector pointer
|
||||
// will be in the same address space as the incoming memref type.
|
||||
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
|
||||
@@ -257,29 +261,53 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = gather->getLoc();
|
||||
MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
|
||||
assert(memRefType && "The base should be bufferized");
|
||||
|
||||
if (failed(isMemRefTypeSupported(memRefType)))
|
||||
return failure();
|
||||
|
||||
auto loc = gather->getLoc();
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
|
||||
return failure();
|
||||
|
||||
// Resolve address.
|
||||
Value ptrs;
|
||||
VectorType vType = gather.getVectorType();
|
||||
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
|
||||
adaptor.getIndexVec(), memRefType, vType, ptrs)))
|
||||
return failure();
|
||||
Value base = adaptor.getBase();
|
||||
|
||||
// Replace with the gather intrinsic.
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
|
||||
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
|
||||
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
|
||||
return success();
|
||||
auto llvmNDVectorTy = adaptor.getIndexVec().getType();
|
||||
// Handle the simple case of 1-D vector.
|
||||
if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
|
||||
auto vType = gather.getVectorType();
|
||||
// Resolve address.
|
||||
Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
|
||||
adaptor.getIndexVec(),
|
||||
/*vLen=*/vType.getDimSize(0));
|
||||
// Replace with the gather intrinsic.
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
|
||||
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
|
||||
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
|
||||
return success();
|
||||
}
|
||||
|
||||
auto callback = [align, memRefType, base, ptr, loc, &rewriter](
|
||||
Type llvm1DVectorTy, ValueRange vectorOperands) {
|
||||
// Resolve address.
|
||||
Value ptrs = getIndexedPtrs(
|
||||
rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
|
||||
LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
|
||||
// Create the gather intrinsic.
|
||||
return rewriter.create<LLVM::masked_gather>(
|
||||
loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
|
||||
/*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
|
||||
};
|
||||
ValueRange vectorOperands = {adaptor.getIndexVec(), adaptor.getMask(),
|
||||
adaptor.getPassThru()};
|
||||
return LLVM::detail::handleMultidimensionalVectors(
|
||||
gather, vectorOperands, *getTypeConverter(), callback, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -295,19 +323,21 @@ public:
|
||||
auto loc = scatter->getLoc();
|
||||
MemRefType memRefType = scatter.getMemRefType();
|
||||
|
||||
if (failed(isMemRefTypeSupported(memRefType)))
|
||||
return failure();
|
||||
|
||||
// Resolve alignment.
|
||||
unsigned align;
|
||||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
|
||||
return failure();
|
||||
|
||||
// Resolve address.
|
||||
Value ptrs;
|
||||
VectorType vType = scatter.getVectorType();
|
||||
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
|
||||
adaptor.getIndexVec(), memRefType, vType, ptrs)))
|
||||
return failure();
|
||||
Value ptrs =
|
||||
getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
|
||||
adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
|
||||
|
||||
// Replace with the scatter intrinsic.
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
|
||||
|
||||
Reference in New Issue
Block a user