[mlir][vector] Propagate scalability to gather/scatter ptrs vector (#97584)
In convert-vector-to-llvm the first operand (vector of pointers holding
all memory addresses to read) to the masked.gather (and scatter)
intrinsic has a fixed vector type.
This may result in intrinsics where the scalable flag has been dropped:
```
%0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
: (!llvm.vec<4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32>
```
Fortunately the operand is overloaded on the result type so we end up
with the correct IR when lowering to LLVM, but this is still incorrect.
This patch fixes it by propagating scalability.
This commit is contained in:
@@ -102,11 +102,14 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
|
||||
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
MemRefType memRefType, Value llvmMemref, Value base,
|
||||
Value index, uint64_t vLen) {
|
||||
Value index, VectorType vectorType) {
|
||||
assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
|
||||
"unsupported memref type");
|
||||
assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
|
||||
auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
|
||||
auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
|
||||
auto ptrsType =
|
||||
LLVM::getVectorType(pType, vectorType.getDimSize(0),
|
||||
/*isScalable=*/vectorType.getScalableDims()[0]);
|
||||
return rewriter.create<LLVM::GEPOp>(
|
||||
loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
|
||||
base, index);
|
||||
@@ -288,9 +291,9 @@ public:
|
||||
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
|
||||
auto vType = gather.getVectorType();
|
||||
// Resolve address.
|
||||
Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
|
||||
memRefType, base, ptr, adaptor.getIndexVec(),
|
||||
/*vLen=*/vType.getDimSize(0));
|
||||
Value ptrs =
|
||||
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
|
||||
base, ptr, adaptor.getIndexVec(), vType);
|
||||
// Replace with the gather intrinsic.
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
|
||||
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
|
||||
@@ -305,8 +308,7 @@ public:
|
||||
// Resolve address.
|
||||
Value ptrs = getIndexedPtrs(
|
||||
rewriter, loc, typeConverter, memRefType, base, ptr,
|
||||
/*index=*/vectorOperands[0],
|
||||
LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
|
||||
/*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
|
||||
// Create the gather intrinsic.
|
||||
return rewriter.create<LLVM::masked_gather>(
|
||||
loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
|
||||
@@ -343,9 +345,9 @@ public:
|
||||
VectorType vType = scatter.getVectorType();
|
||||
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
|
||||
adaptor.getIndices(), rewriter);
|
||||
Value ptrs = getIndexedPtrs(
|
||||
rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
|
||||
ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
|
||||
Value ptrs =
|
||||
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
|
||||
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
|
||||
|
||||
// Replace with the scatter intrinsic.
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
|
||||
|
||||
Reference in New Issue
Block a user