[mlir][Vector] Add vector.to_elements op (#141457)
This PR introduces the `vector.to_elements` op, which decomposes a
vector into its scalar elements. This operation is symmetrical to the
existing `vector.from_elements`.
Examples:
```
// Decompose a 0-D vector.
%0 = vector.to_elements %v0 : vector<f32>
// %0 = %v0[0]
// Decompose a 1-D vector.
%0:2 = vector.to_elements %v1 : vector<2xf32>
// %0#0 = %v1[0]
// %0#1 = %v1[1]
// Decompose a 2-D.
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
// %0#0 = %v2[0, 0]
// %0#1 = %v2[0, 1]
// %0#2 = %v2[0, 2]
// %0#3 = %v2[1, 0]
// %0#4 = %v2[1, 1]
// %0#5 = %v2[1, 2]
```
This op is aimed at reducing code size when modeling "structured" vector
extractions and simplifying canonicalizations of large sequences of
`vector.extract` and `vector.insert` ops into `vector.shuffle` and other
sophisticated ops that can re-arrange vector elements.
This commit is contained in:
@@ -2787,6 +2787,11 @@ private:
|
||||
void handleTypesMatchConstraint(
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
|
||||
|
||||
/// Check for inferable type resolution based on
|
||||
/// `ShapedTypeMatchesElementCountAndTypes` constraint.
|
||||
void handleShapedTypeMatchesElementCountAndTypesConstraint(
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
|
||||
|
||||
/// Returns an argument or attribute with the given name that has been seen
|
||||
/// within the format.
|
||||
ConstArgument findSeenArg(StringRef name);
|
||||
@@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
|
||||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
|
||||
} else if (def.isSubClassOf("TypesMatchWith")) {
|
||||
handleTypesMatchConstraint(variableTyResolver, def);
|
||||
} else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
|
||||
handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver,
|
||||
def);
|
||||
} else if (!op.allResultTypesKnown()) {
|
||||
// This doesn't check the name directly to handle
|
||||
// DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||
@@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint(
|
||||
variableTyResolver[rhsName] = {arg, transformer};
|
||||
}
|
||||
|
||||
void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint(
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
|
||||
StringRef shapedArg = def.getValueAsString("shaped");
|
||||
StringRef elementsArg = def.getValueAsString("elements");
|
||||
|
||||
// Check if the 'shaped' argument is seen, then we can infer the 'elements'
|
||||
// types.
|
||||
if (ConstArgument arg = findSeenArg(shapedArg)) {
|
||||
variableTyResolver[elementsArg] = {
|
||||
arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
|
||||
"ShapedType>($_self).getNumElements(), "
|
||||
"::llvm::cast<::mlir::ShapedType>($_self).getElementType())"};
|
||||
}
|
||||
|
||||
// Type inference in the opposite direction is not possible as the actual
|
||||
// shaped type can't be inferred from the variadic elements.
|
||||
}
|
||||
|
||||
ConstArgument OpFormatParser::findSeenArg(StringRef name) {
|
||||
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
|
||||
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
|
||||
|
||||
Reference in New Issue
Block a user