[mlir][vector] Support complete folding in single pass for vector.insert/vector.extract (#142124)

### Description

This patch improves the folding efficiency of `vector.insert` and
`vector.extract` operations by not returning early after successfully
converting dynamic indices to static indices.

This PR also renames the test pass `TestConstantFold` to
`TestSingleFold` and adds comprehensive documentation explaining the
single-pass folding behavior.

### Motivation

Since the `OpBuilder::createOrFold` function only calls `fold` **once**,
the current `fold` methods of `vector.insert` and `vector.extract` may
leave the op in a state that can be folded further. For example,
consider the following un-folded IR:
```
%v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32>
%c0 = arith.constant 0 : index
%e2 = vector.extract %v1[%c0] : f32 from vector<128xf32>
```
If we use `createOrFold` to create the `vector.extract` op, then the
result will be:
```
%v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32>
%e2 = vector.extract %v1[0] : f32 from vector<128xf32>
```
But this is not the optimal result. `createOrFold` should have returned
`%e1`.
The reason is that the current `fold` implementation returns immediately
after `extractInsertFoldConstantOp` succeeds, causing the subsequent
folding logic to be skipped.

---------

Co-authored-by: Yang Bai <yangb@nvidia.com>
This commit is contained in:
Yang Bai
2025-06-19 00:26:04 +08:00
committed by GitHub
parent 0018921148
commit fe3933da15
13 changed files with 86 additions and 32 deletions

View File

@@ -2063,6 +2063,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
if (opChange) {
op.setStaticPosition(staticPosition);
op.getOperation()->setOperands(operands);
// Return the original result to indicate an in-place folding happened.
return op.getResult();
}
return {};
@@ -2146,11 +2147,12 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return getVector();
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
return res;
// Fold `arith.constant` indices into the `vector.extract` operation. Make
// sure that patterns requiring constant indices are added after this fold.
// Fold `arith.constant` indices into the `vector.extract` operation.
// Do not stop here as this fold may enable subsequent folds that require
// constant indices.
SmallVector<Value> operands = {getVector()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
@@ -2172,7 +2174,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
return OpFoldResult();
return inplaceFolded;
}
namespace {
@@ -3272,11 +3275,12 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getValueToStoreType() == getType())
return getValueToStore();
// Fold `arith.constant` indices into the `vector.insert` operation. Make
// sure that patterns requiring constant indices are added after this fold.
// Fold `arith.constant` indices into the `vector.insert` operation.
// Do not stop here as this fold may enable subsequent folds that require
// constant indices.
SmallVector<Value> operands = {getValueToStore(), getDest()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
@@ -3286,7 +3290,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
return res;
}
return {};
return inplaceFolded;
}
//===----------------------------------------------------------------------===//

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-constant-fold -split-input-file %s | FileCheck %s
// RUN: mlir-opt -test-single-fold -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @affine_apply
func.func @affine_apply(%variable : index) -> (index, index, index) {

View File

@@ -1,5 +1,5 @@
// RUN: mlir-opt \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
// RUN: --split-input-file \
// RUN: %s | FileCheck %s

View File

@@ -1,5 +1,5 @@
// RUN: mlir-opt \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
// RUN: %s | FileCheck %s
mesh.mesh @mesh_1d(shape = 2)

View File

@@ -1,5 +1,5 @@
// RUN: mlir-opt \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
// RUN: %s | FileCheck %s
mesh.mesh @mesh_1d_4(shape = 4)

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt --test-constant-fold %s | FileCheck %s
// RUN: mlir-opt --test-single-fold %s | FileCheck %s
// CHECK-LABEL: func @test_const
func.func @test_const(%arg0 : index) -> tensor<4xi32> {

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -split-input-file -test-constant-fold | FileCheck %s
// RUN: mlir-opt %s -split-input-file -test-single-fold | FileCheck %s
// CHECK-LABEL: fold_extract_transpose_negative
func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4xf16> {
@@ -11,3 +11,5 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
%2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
return %2 : vector<4x4xf16>
}

View File

@@ -0,0 +1,38 @@
// RUN: mlir-opt %s -split-input-file -test-single-fold | FileCheck %s
// The tests in this file verify that fold() methods can handle complex
// optimization scenarios without requiring multiple folding iterations.
// This is important because:
//
// 1. OpBuilder::createOrFold() only calls fold() once, so operations must
// be fully optimized in that single call
// 2. Multiple rounds of folding would incur higher performance costs,
// so it's more efficient to complete all optimizations in one pass
//
// These tests ensure that folding implementations are robust and complete,
// avoiding situations where operations are left in intermediate states
// that could be further optimized.
// CHECK-LABEL: fold_extract_in_single_pass
// CHECK-SAME: (%{{.*}}: vector<4xf16>, %[[ARG1:.+]]: f16)
func.func @fold_extract_in_single_pass(%arg0: vector<4xf16>, %arg1: f16) -> f16 {
%0 = vector.insert %arg1, %arg0 [1] : f16 into vector<4xf16>
%c1 = arith.constant 1 : index
// Verify that the fold is finished in a single pass even if the index is dynamic.
%1 = vector.extract %0[%c1] : f16 from vector<4xf16>
// CHECK: return %[[ARG1]] : f16
return %1 : f16
}
// -----
// CHECK-LABEL: fold_insert_in_single_pass
func.func @fold_insert_in_single_pass() -> vector<2xf16> {
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
%c1 = arith.constant 1 : index
%c2 = arith.constant 2.5 : f16
// Verify that the fold is finished in a single pass even if the index is dynamic.
// CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
%0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
return %0 : vector<2xf16>
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -split-input-file -test-constant-fold -mlir-print-debuginfo | FileCheck %s
// RUN: mlir-opt %s -split-input-file -test-single-fold -mlir-print-debuginfo | FileCheck %s
// CHECK-LABEL: func @fold_and_merge
func.func @fold_and_merge() -> (i32, i32) {

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -test-constant-fold | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -test-single-fold | FileCheck %s
// -----

View File

@@ -26,11 +26,11 @@ endif()
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
TestCompositePass.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
TestInliningCallback.cpp
TestMakeIsolatedFromAbove.cpp
TestSingleFold.cpp
TestTransformsOps.cpp
${MLIRTestTransformsPDLSrc}

View File

@@ -1,4 +1,4 @@
//===- TestConstantFold.cpp - Pass to test constant folding ---------------===//
//===- TestSingleFold.cpp - Pass to test single-pass folding --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,14 +12,23 @@
using namespace mlir;
namespace {
/// Simple constant folding pass.
struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
/// Test pass for single-pass constant folding.
///
/// This pass tests the behavior of operations when folded exactly once. Unlike
/// canonicalization passes that may apply multiple rounds of folding, this pass
/// ensures that each operation is folded at most once, which is useful for
/// testing scenarios where the fold implementation should handle complex cases
/// without requiring multiple iterations.
///
/// The pass also removes dead constants after folding to clean up unused
/// intermediate results.
struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
public RewriterBase::Listener {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConstantFold)
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSingleFold)
StringRef getArgument() const final { return "test-constant-fold"; }
StringRef getArgument() const final { return "test-single-fold"; }
StringRef getDescription() const final {
return "Test operation constant folding";
return "Test single-pass operation folding and dead constant elimination";
}
// All constants in the operation post folding.
SmallVector<Operation *> existingConstants;
@@ -39,18 +48,19 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
};
} // namespace
void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
void TestSingleFold::foldOperation(Operation *op, OperationFolder &helper) {
// Attempt to fold the specified operation, including handling unused or
// duplicated constants.
(void)helper.tryToFold(op);
}
void TestConstantFold::runOnOperation() {
void TestSingleFold::runOnOperation() {
existingConstants.clear();
// Collect and fold the operations within the operation.
SmallVector<Operation *, 8> ops;
getOperation()->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) { ops.push_back(op); });
getOperation()->walk<mlir::WalkOrder::PreOrder>(
[&](Operation *op) { ops.push_back(op); });
// Fold the constants in reverse so that the last generated constants from
// folding are at the beginning. This creates somewhat of a linear ordering to
@@ -70,6 +80,6 @@ void TestConstantFold::runOnOperation() {
namespace mlir {
namespace test {
void registerTestConstantFold() { PassRegistration<TestConstantFold>(); }
void registerTestSingleFold() { PassRegistration<TestSingleFold>(); }
} // namespace test
} // namespace mlir

View File

@@ -87,7 +87,6 @@ void registerTestCfAssertPass();
void registerTestCFGLoopInfoPass();
void registerTestComposeSubView();
void registerTestCompositePass();
void registerTestConstantFold();
void registerTestControlFlowSink();
void registerTestConvertToSPIRVPass();
void registerTestDataLayoutPropagation();
@@ -145,6 +144,7 @@ void registerTestSCFUtilsPass();
void registerTestSCFWhileOpBuilderPass();
void registerTestSCFWrapInZeroTripCheckPasses();
void registerTestShapeMappingPass();
void registerTestSingleFold();
void registerTestSliceAnalysisPass();
void registerTestSPIRVCPURunnerPipeline();
void registerTestSPIRVFuncSignatureConversion();
@@ -233,7 +233,6 @@ void registerTestPasses() {
mlir::test::registerTestCFGLoopInfoPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestCompositePass();
mlir::test::registerTestConstantFold();
mlir::test::registerTestControlFlowSink();
mlir::test::registerTestConvertToSPIRVPass();
mlir::test::registerTestDataLayoutPropagation();
@@ -291,6 +290,7 @@ void registerTestPasses() {
mlir::test::registerTestSCFWhileOpBuilderPass();
mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
mlir::test::registerTestShapeMappingPass();
mlir::test::registerTestSingleFold();
mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestSPIRVCPURunnerPipeline();
mlir::test::registerTestSPIRVFuncSignatureConversion();