[mlir][Vector] Clean up populateVectorToLLVMConversionPatterns (#119975)
Clean up `populateVectorToLLVMConversionPatterns` so that it populates only conversion patterns. All rewrite patterns that do not lower to LLVM should be populated into a separate greedy pattern rewrite. The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions. Depends on #119973.
This commit is contained in:
committed by
GitHub
parent
59890c1334
commit
0693b9e9cc
@@ -1475,16 +1475,17 @@ public:
|
||||
|
||||
/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
|
||||
/// Non-scalable versions of this operation are handled in Vector Transforms.
|
||||
class VectorCreateMaskOpRewritePattern
|
||||
: public OpRewritePattern<vector::CreateMaskOp> {
|
||||
class VectorCreateMaskOpConversion
|
||||
: public OpConversionPattern<vector::CreateMaskOp> {
|
||||
public:
|
||||
explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
|
||||
bool enableIndexOpt)
|
||||
: OpRewritePattern<vector::CreateMaskOp>(context),
|
||||
explicit VectorCreateMaskOpConversion(MLIRContext *context,
|
||||
bool enableIndexOpt)
|
||||
: OpConversionPattern<vector::CreateMaskOp>(context),
|
||||
force32BitVectorIndices(enableIndexOpt) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = op.getType();
|
||||
if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
|
||||
return failure();
|
||||
@@ -1495,7 +1496,7 @@ public:
|
||||
loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
|
||||
/*isScalable=*/true));
|
||||
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
|
||||
op.getOperand(0));
|
||||
adaptor.getOperands()[0]);
|
||||
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
|
||||
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
indices, bounds);
|
||||
@@ -1896,16 +1897,19 @@ struct VectorScalableStepOpLowering
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::vector::populateVectorRankReducingFMAPattern(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
bool reassociateFPReductions, bool force32BitVectorIndices) {
|
||||
// This function populates only ConversionPatterns, not RewritePatterns.
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
|
||||
populateVectorInsertExtractStridedSliceTransforms(patterns);
|
||||
populateVectorStepLoweringPatterns(patterns);
|
||||
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
|
||||
patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
|
||||
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
|
||||
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
|
||||
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
||||
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
|
||||
@@ -1922,8 +1926,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
MaskedReductionOpConversion, VectorInterleaveOpLowering,
|
||||
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
|
||||
VectorScalableStepOpLowering>(converter);
|
||||
// Transfer ops with rank > 1 are handled by VectorToSCF.
|
||||
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
|
||||
}
|
||||
|
||||
void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
||||
|
||||
Reference in New Issue
Block a user