[mlir][vector] Add 1D vector.deinterleave lowering (#93042)
This patch implements the lowering of vector.deinterleave
for 1D vectors.
For fixed vector types, the operation is lowered to two
llvm shufflevector operations: one for the even-indexed
elements and the other for the odd-indexed elements. A poison
operation supplies the second vector operand that each
shufflevector requires.
For scalable vectors, the llvm vector.deinterleave2
intrinsic is used for lowering. The intrinsic returns its
two results packed in a struct, so each result is obtained
by extracting it from that struct.
This commit is contained in:
@@ -1761,6 +1761,70 @@ struct VectorInterleaveOpLowering
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a `vector.deinterleave`.
|
||||
/// This supports fixed-sized vectors and scalable vectors.
|
||||
struct VectorDeinterleaveOpLowering
|
||||
: public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
VectorType resultType = deinterleaveOp.getResultVectorType();
|
||||
VectorType sourceType = deinterleaveOp.getSourceVectorType();
|
||||
auto loc = deinterleaveOp.getLoc();
|
||||
|
||||
// Note: n-D deinterleave operations should be lowered to the 1-D before
|
||||
// converting to LLVM.
|
||||
if (resultType.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(deinterleaveOp,
|
||||
"DeinterleaveOp not rank 1");
|
||||
|
||||
if (resultType.isScalable()) {
|
||||
auto llvmTypeConverter = this->getTypeConverter();
|
||||
auto deinterleaveResults = deinterleaveOp.getResultTypes();
|
||||
auto packedOpResults =
|
||||
llvmTypeConverter->packOperationResults(deinterleaveResults);
|
||||
auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
|
||||
loc, packedOpResults, adaptor.getSource());
|
||||
|
||||
auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, intrinsic->getResult(0), 0);
|
||||
auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, intrinsic->getResult(0), 1);
|
||||
|
||||
rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
|
||||
return success();
|
||||
}
|
||||
// Lower fixed-size deinterleave to two shufflevectors. While the
|
||||
// vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
|
||||
// langref still recommends fixed-vectors use shufflevector, see:
|
||||
// https://llvm.org/docs/LangRef.html#id889.
|
||||
int64_t resultVectorSize = resultType.getNumElements();
|
||||
SmallVector<int32_t> evenShuffleMask;
|
||||
SmallVector<int32_t> oddShuffleMask;
|
||||
|
||||
evenShuffleMask.reserve(resultVectorSize);
|
||||
oddShuffleMask.reserve(resultVectorSize);
|
||||
|
||||
for (int i = 0; i < sourceType.getNumElements(); ++i) {
|
||||
if (i % 2 == 0)
|
||||
evenShuffleMask.push_back(i);
|
||||
else
|
||||
oddShuffleMask.push_back(i);
|
||||
}
|
||||
|
||||
auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
|
||||
auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
loc, adaptor.getSource(), poison, evenShuffleMask);
|
||||
auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
loc, adaptor.getSource(), poison, oddShuffleMask);
|
||||
|
||||
rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
@@ -1785,8 +1849,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
|
||||
VectorSplatOpLowering, VectorSplatNdOpLowering,
|
||||
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
|
||||
MaskedReductionOpConversion, VectorInterleaveOpLowering>(
|
||||
converter);
|
||||
MaskedReductionOpConversion, VectorInterleaveOpLowering,
|
||||
VectorDeinterleaveOpLowering>(converter);
|
||||
// Transfer ops with rank > 1 are handled by VectorToSCF.
|
||||
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user