[MLIR][VectorToLLVM] Handle scalable dim in createVectorLengthValue() (#93361)
LLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics. For a scalable vector type, this should be caculated as VectorScaleOp multiplied by base vector length, e.g.: for <[4]xf32> we should return: vscale * 4.
This commit is contained in:
@@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
|
||||
llvmType);
|
||||
}
|
||||
|
||||
/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
|
||||
/// Creates a value with the 1-D vector shape provided in `llvmType`.
|
||||
/// This is used as effective vector length by some intrinsics supporting
|
||||
/// dynamic vector lengths at runtime.
|
||||
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
|
||||
@@ -532,9 +532,20 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
|
||||
auto vShape = vType.getShape();
|
||||
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
|
||||
|
||||
return rewriter.create<LLVM::ConstantOp>(
|
||||
Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, rewriter.getI32Type(),
|
||||
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
|
||||
|
||||
if (!vType.getScalableDims()[0])
|
||||
return baseVecLength;
|
||||
|
||||
// For a scalable vector type, create and return `vScale * baseVecLength`.
|
||||
Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
|
||||
vScale =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
|
||||
Value scalableVecLength =
|
||||
rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
|
||||
return scalableVecLength;
|
||||
}
|
||||
|
||||
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
|
||||
|
||||
Reference in New Issue
Block a user