[mlir][Vector] Add vector.to_elements op (#141457)

This PR introduces the `vector.to_elements` op, which decomposes a
vector into its scalar elements. This operation is symmetrical to the
existing `vector.from_elements`.

Examples:

```
    // Decompose a 0-D vector.
    %0 = vector.to_elements %v0 : vector<f32>
    // %0 = %v0[0]

    // Decompose a 1-D vector.
    %0:2 = vector.to_elements %v1 : vector<2xf32>
    // %0#0 = %v1[0]
    // %0#1 = %v1[1]

    // Decompose a 2-D.
    %0:6 = vector.to_elements %v2 : vector<2x3xf32>
    // %0#0 = %v2[0, 0]
    // %0#1 = %v2[0, 1]
    // %0#2 = %v2[0, 2]
    // %0#3 = %v2[1, 0]
    // %0#4 = %v2[1, 1]
    // %0#5 = %v2[1, 2]
```

This op is aimed at reducing code size when modeling "structured" vector
extractions and simplifying canonicalizations of large sequences of
`vector.extract` and `vector.insert` ops into `vector.shuffle` and other
sophisticated ops that can re-arrange vector elements.
This commit is contained in:
Diego Caballero
2025-06-18 13:45:43 -07:00
committed by GitHub
parent b85e92990f
commit 7aecd7ecac
6 changed files with 184 additions and 25 deletions

View File

@@ -2787,6 +2787,11 @@ private:
void handleTypesMatchConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
/// Check for inferable type resolution based on
/// `ShapedTypeMatchesElementCountAndTypes` constraint.
void handleShapedTypeMatchesElementCountAndTypesConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
/// Returns an argument or attribute with the given name that has been seen
/// within the format.
ConstArgument findSeenArg(StringRef name);
@@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
handleTypesMatchConstraint(variableTyResolver, def);
} else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver,
def);
} else if (!op.allResultTypesKnown()) {
// This doesn't check the name directly to handle
// DeclareOpInterfaceMethods<InferTypeOpInterface>
@@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint(
variableTyResolver[rhsName] = {arg, transformer};
}
void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
StringRef shapedArg = def.getValueAsString("shaped");
StringRef elementsArg = def.getValueAsString("elements");
// Check if the 'shaped' argument is seen, then we can infer the 'elements'
// types.
if (ConstArgument arg = findSeenArg(shapedArg)) {
variableTyResolver[elementsArg] = {
arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
"ShapedType>($_self).getNumElements(), "
"::llvm::cast<::mlir::ShapedType>($_self).getElementType())"};
}
// Type inference in the opposite direction is not possible as the actual
// shaped type can't be inferred from the variadic elements.
}
ConstArgument OpFormatParser::findSeenArg(StringRef name) {
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;