[[mlir][Vector] Add simple folders for vector.from_element/vector.to_elements (#144444)
This PR adds simple folders to remove no-op sequences of `vector.from_elements` and `vector.to_elements`.
This commit is contained in:
@@ -836,6 +836,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
|
||||
let arguments = (ins AnyVectorOfAnyRank:$source);
|
||||
let results = (outs Variadic<AnyType>:$elements);
|
||||
let assemblyFormat = "$source attr-dict `:` type($source)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Vector_FromElementsOp : Vector_Op<"from_elements", [
|
||||
@@ -873,6 +874,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
|
||||
let arguments = (ins Variadic<AnyType>:$elements);
|
||||
let results = (outs AnyFixedVectorOfAnyRank:$dest);
|
||||
let assemblyFormat = "$elements attr-dict `:` type($dest)";
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -2373,10 +2373,95 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
|
||||
return llvm::to_vector<4>(getVectorType().getShape());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToElementsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns true if all the `operands` are defined by `defOp`.
|
||||
/// Otherwise, returns false.
|
||||
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
|
||||
if (operands.empty())
|
||||
return false;
|
||||
|
||||
return llvm::all_of(operands, [&](Value operand) {
|
||||
Operation *currentDef = operand.getDefiningOp();
|
||||
return currentDef == defOp;
|
||||
});
|
||||
}
|
||||
|
||||
/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
|
||||
/// (%e0, %e1, ...). For example:
|
||||
///
|
||||
/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
|
||||
/// %1:3 = vector.to_elements %0 : vector<3xf32>
|
||||
/// user_op %1#0, %1#1, %1#2
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// user_op %a, %b, %c
|
||||
///
|
||||
static LogicalResult
|
||||
foldToElementsFromElements(ToElementsOp toElementsOp,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
auto fromElementsOp =
|
||||
toElementsOp.getSource().getDefiningOp<FromElementsOp>();
|
||||
if (!fromElementsOp)
|
||||
return failure();
|
||||
|
||||
llvm::append_range(results, fromElementsOp.getElements());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
return foldToElementsFromElements(*this, results);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FromElementsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
|
||||
///
|
||||
/// Case #1: Input and output vectors are the same.
|
||||
///
|
||||
/// %0:3 = vector.to_elements %a : vector<3xf32>
|
||||
/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
|
||||
/// user_op %1
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// user_op %a
|
||||
///
|
||||
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
|
||||
OperandRange fromElemsOperands = fromElementsOp.getElements();
|
||||
if (fromElemsOperands.empty())
|
||||
return {};
|
||||
|
||||
auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
|
||||
if (!toElementsOp)
|
||||
return {};
|
||||
|
||||
if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
|
||||
return {};
|
||||
|
||||
// Case #1: Input and output vectors are the same. Forward the input vector.
|
||||
Value toElementsInput = toElementsOp.getSource();
|
||||
if (fromElementsOp.getType() == toElementsInput.getType() &&
|
||||
llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
|
||||
return toElementsInput;
|
||||
}
|
||||
|
||||
// TODO: Support cases with different input and output shapes and different
|
||||
// number of elements.
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
|
||||
return foldFromElementsToElements(*this);
|
||||
}
|
||||
|
||||
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
|
||||
/// same SSA value. E.g.:
|
||||
///
|
||||
|
||||
@@ -3023,6 +3023,58 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @to_elements_from_elements_no_op(
|
||||
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32
|
||||
func.func @to_elements_from_elements_no_op(%a: f32, %b: f32) -> (f32, f32) {
|
||||
// CHECK-NOT: vector.from_elements
|
||||
// CHECK-NOT: vector.to_elements
|
||||
%0 = vector.from_elements %b, %a : vector<2xf32>
|
||||
%1:2 = vector.to_elements %0 : vector<2xf32>
|
||||
// CHECK: return %[[B]], %[[A]]
|
||||
return %1#0, %1#1 : f32, f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @from_elements_to_elements_no_op(
|
||||
// CHECK-SAME: %[[A:.*]]: vector<4x2xf32>
|
||||
func.func @from_elements_to_elements_no_op(%a: vector<4x2xf32>) -> vector<4x2xf32> {
|
||||
// CHECK-NOT: vector.from_elements
|
||||
// CHECK-NOT: vector.to_elements
|
||||
%0:8 = vector.to_elements %a : vector<4x2xf32>
|
||||
%1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<4x2xf32>
|
||||
// CHECK: return %[[A]]
|
||||
return %1 : vector<4x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @from_elements_to_elements_dup_elems(
|
||||
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
|
||||
func.func @from_elements_to_elements_dup_elems(%a: vector<4xf32>) -> vector<4x2xf32> {
|
||||
// CHECK: %[[TO_EL:.*]]:4 = vector.to_elements %[[A]]
|
||||
// CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#0, %[[TO_EL]]#1, %[[TO_EL]]#2
|
||||
%0:4 = vector.to_elements %a : vector<4xf32> // 4 elements
|
||||
%1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#0, %0#1, %0#2, %0#3 : vector<4x2xf32>
|
||||
// CHECK: return %[[FROM_EL]]
|
||||
return %1 : vector<4x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @from_elements_to_elements_shuffle(
|
||||
// CHECK-SAME: %[[A:.*]]: vector<4x2xf32>
|
||||
func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2xf32> {
|
||||
// CHECK: %[[TO_EL:.*]]:8 = vector.to_elements %[[A]]
|
||||
// CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#7, %[[TO_EL]]#0, %[[TO_EL]]#6
|
||||
%0:8 = vector.to_elements %a : vector<4x2xf32>
|
||||
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<4x2xf32>
|
||||
// CHECK: return %[[FROM_EL]]
|
||||
return %1 : vector<4x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @vector_insert_const_regression(
|
||||
// CHECK: llvm.mlir.undef
|
||||
// CHECK: vector.insert
|
||||
|
||||
Reference in New Issue
Block a user