[MLIR][VectorToLLVM] Handle scalable dim in createVectorLengthValue() (#93361)

LLVM's Vector Predication Intrinsics require an explicit vector length
parameter:
https://llvm.org/docs/LangRef.html#vector-predication-intrinsics.

For a scalable vector type, this should be caculated as VectorScaleOp
multiplied by base vector length, e.g.: for <[4]xf32> we should return:
vscale * 4.
This commit is contained in:
Zhaoshi Zheng
2024-06-13 09:06:05 -07:00
committed by GitHub
parent 19b43e1757
commit abcbbe7114
2 changed files with 123 additions and 2 deletions

View File

@@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
llvmType);
}
/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
/// Creates a value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
@@ -532,9 +532,20 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
auto vShape = vType.getShape();
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
return rewriter.create<LLVM::ConstantOp>(
Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
if (!vType.getScalableDims()[0])
return baseVecLength;
// For a scalable vector type, create and return `vScale * baseVecLength`.
Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
vScale =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
Value scalableVecLength =
rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
return scalableVecLength;
}
/// Helper method to lower a `vector.reduction` op that performs an arithmetic