[mlir][Vector] Add constant folding for vector.from_elements operation (#145849)

### Summary

This PR adds a new folding pattern for **vector.from_elements** that
canonicalizes it to **arith.constant** when all input operands are
constants.

### Implementation Details

**Leverages FoldAdaptor capabilities**: Uses adaptor.getElements() to
access **pre-computed** constant attributes, avoiding redundant pattern
matching on operands.

### Example Transformation
```
Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>

After:
%v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
```

---------

Co-authored-by: Yang Bai <yangb@nvidia.com>
This commit is contained in:
Yang Bai
2025-07-01 11:39:53 +08:00
committed by GitHub
parent 0a69c83421
commit 393a75ebb7
2 changed files with 69 additions and 12 deletions

View File

@@ -398,6 +398,18 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
return {};
}
/// Converts an IntegerAttr to have the specified type if needed.
/// This handles cases where constant attributes have a different type than the
/// target element type. If the input attribute is not an IntegerAttr or already
/// has the correct type, returns it unchanged.
static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
if (intAttr.getType() != expectedType)
return IntegerAttr::get(expectedType, intAttr.getInt());
}
return attr;
}
//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//
@@ -2464,8 +2476,37 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
return {};
}
/// Fold vector.from_elements to a constant when all operands are constants.
/// Example:
/// %c1 = arith.constant 1 : i32
/// %c2 = arith.constant 2 : i32
/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
/// =>
/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
///
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
ArrayRef<Attribute> elements) {
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
return {};
auto destVecType = fromElementsOp.getDest().getType();
auto destEltType = destVecType.getElementType();
// Constant attributes might have a different type than the return type.
// Convert them before creating the dense elements attribute.
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
return convertIntegerAttr(attr, destEltType);
});
return DenseElementsAttr::get(destVecType, convertedElements);
}
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
return foldFromElementsToElements(*this);
if (auto res = foldFromElementsToElements(*this))
return res;
if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
return res;
return {};
}
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
@@ -3332,17 +3373,6 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
/// Converts the expected type to an IntegerAttr if there's
/// a mismatch.
auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
if (intAttr.getType() != expectedType)
return IntegerAttr::get(expectedType, intAttr.getInt());
}
return attr;
};
// The `convertIntegerAttr` method specifically handles the case
// for `llvm.mlir.constant` which can hold an attribute with a
// different type than the return type.
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
for (auto value : denseSource.getValues<Attribute>())
insertedValues.push_back(convertIntegerAttr(value, destEltType));

View File

@@ -3075,6 +3075,33 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
// -----
// CHECK-LABEL: func @from_elements_all_elements_constant(
func.func @from_elements_all_elements_constant() -> vector<2x2xi32> {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
// CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
%res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
// CHECK: return %[[RES]]
return %res : vector<2x2xi32>
}
// -----
// CHECK-LABEL: func @from_elements_partial_elements_constant(
// CHECK-SAME: %[[A:.*]]: f32
func.func @from_elements_partial_elements_constant(%arg0: f32) -> vector<2xf32> {
// CHECK: %[[C:.*]] = arith.constant 1.000000e+00 : f32
%c = arith.constant 1.0 : f32
// CHECK: %[[RES:.*]] = vector.from_elements %[[A]], %[[C]] : vector<2xf32>
%res = vector.from_elements %arg0, %c : vector<2xf32>
// CHECK: return %[[RES]]
return %res : vector<2xf32>
}
// -----
// CHECK-LABEL: func @vector_insert_const_regression(
// CHECK: llvm.mlir.undef
// CHECK: vector.insert