[mlir][Vector] Add constant folding for vector.from_elements operation (#145849)
### Summary This PR adds a new folding pattern for **vector.from_elements** that canonicalizes it to **arith.constant** when all input operands are constants. ### Implementation Details **Leverages FoldAdaptor capabilities**: Uses adaptor.getElements() to access **pre-computed** constant attributes, avoiding redundant pattern matching on operands. ### Example Transformation ``` Before: %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 %c3_i32 = arith.constant 3 : i32 %v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32> After: %v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32> ``` --------- Co-authored-by: Yang Bai <yangb@nvidia.com>
This commit is contained in:
@@ -398,6 +398,18 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Converts an IntegerAttr to have the specified type if needed.
|
||||
/// This handles cases where constant attributes have a different type than the
|
||||
/// target element type. If the input attribute is not an IntegerAttr or already
|
||||
/// has the correct type, returns it unchanged.
|
||||
static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
|
||||
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
|
||||
if (intAttr.getType() != expectedType)
|
||||
return IntegerAttr::get(expectedType, intAttr.getInt());
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CombiningKindAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -2464,8 +2476,37 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Fold vector.from_elements to a constant when all operands are constants.
|
||||
/// Example:
|
||||
/// %c1 = arith.constant 1 : i32
|
||||
/// %c2 = arith.constant 2 : i32
|
||||
/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
|
||||
/// =>
|
||||
/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
|
||||
///
|
||||
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
|
||||
ArrayRef<Attribute> elements) {
|
||||
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
|
||||
return {};
|
||||
|
||||
auto destVecType = fromElementsOp.getDest().getType();
|
||||
auto destEltType = destVecType.getElementType();
|
||||
// Constant attributes might have a different type than the return type.
|
||||
// Convert them before creating the dense elements attribute.
|
||||
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
|
||||
return convertIntegerAttr(attr, destEltType);
|
||||
});
|
||||
|
||||
return DenseElementsAttr::get(destVecType, convertedElements);
|
||||
}
|
||||
|
||||
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
|
||||
return foldFromElementsToElements(*this);
|
||||
if (auto res = foldFromElementsToElements(*this))
|
||||
return res;
|
||||
if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
|
||||
return res;
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
|
||||
@@ -3332,17 +3373,6 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
|
||||
|
||||
/// Converts the expected type to an IntegerAttr if there's
|
||||
/// a mismatch.
|
||||
auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
|
||||
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
|
||||
if (intAttr.getType() != expectedType)
|
||||
return IntegerAttr::get(expectedType, intAttr.getInt());
|
||||
}
|
||||
return attr;
|
||||
};
|
||||
|
||||
// The `convertIntegerAttr` method specifically handles the case
|
||||
// for `llvm.mlir.constant` which can hold an attribute with a
|
||||
// different type than the return type.
|
||||
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
|
||||
for (auto value : denseSource.getValues<Attribute>())
|
||||
insertedValues.push_back(convertIntegerAttr(value, destEltType));
|
||||
|
||||
@@ -3075,6 +3075,33 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @from_elements_all_elements_constant(
|
||||
func.func @from_elements_all_elements_constant() -> vector<2x2xi32> {
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%c2_i32 = arith.constant 2 : i32
|
||||
%c3_i32 = arith.constant 3 : i32
|
||||
// CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
|
||||
%res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
return %res : vector<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @from_elements_partial_elements_constant(
|
||||
// CHECK-SAME: %[[A:.*]]: f32
|
||||
func.func @from_elements_partial_elements_constant(%arg0: f32) -> vector<2xf32> {
|
||||
// CHECK: %[[C:.*]] = arith.constant 1.000000e+00 : f32
|
||||
%c = arith.constant 1.0 : f32
|
||||
// CHECK: %[[RES:.*]] = vector.from_elements %[[A]], %[[C]] : vector<2xf32>
|
||||
%res = vector.from_elements %arg0, %c : vector<2xf32>
|
||||
// CHECK: return %[[RES]]
|
||||
return %res : vector<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @vector_insert_const_regression(
|
||||
// CHECK: llvm.mlir.undef
|
||||
// CHECK: vector.insert
|
||||
|
||||
Reference in New Issue
Block a user