[mlir][Vector] Clean up populateVectorToLLVMConversionPatterns (#119975)

Clean up `populateVectorToLLVMConversionPatterns` so that it populates
only conversion patterns. All rewrite patterns that do not lower to LLVM
should be populated into a separate greedy pattern rewrite.

The current combination of rewrite patterns and conversion patterns
triggered an edge case when merging the 1:1 and 1:N dialect conversions.

Depends on #119973.
This commit is contained in:
Matthias Springer
2024-12-17 11:37:17 +01:00
committed by GitHub
parent 59890c1334
commit 0693b9e9cc
6 changed files with 41 additions and 22 deletions

View File

@@ -1475,16 +1475,17 @@ public:
/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpRewritePattern
: public OpRewritePattern<vector::CreateMaskOp> {
class VectorCreateMaskOpConversion
: public OpConversionPattern<vector::CreateMaskOp> {
public:
explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
bool enableIndexOpt)
: OpRewritePattern<vector::CreateMaskOp>(context),
explicit VectorCreateMaskOpConversion(MLIRContext *context,
bool enableIndexOpt)
: OpConversionPattern<vector::CreateMaskOp>(context),
force32BitVectorIndices(enableIndexOpt) {}
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstType = op.getType();
if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
return failure();
@@ -1495,7 +1496,7 @@ public:
loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
/*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
op.getOperand(0));
adaptor.getOperands()[0]);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indices, bounds);
@@ -1896,16 +1897,19 @@ struct VectorScalableStepOpLowering
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
RewritePatternSet &patterns) {
patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
}
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool force32BitVectorIndices) {
// This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1926,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
VectorScalableStepOpLowering>(converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(