[mlir][tensor] Fold identity reshape of 0d-tensors (#146375)

Just like 1d-tensors, reshapes of 0d-tensors (aka scalars) are always
no-ops, as they only have one possible layout. This PR extends the
`fold` implementation to optimize these away, as is already done for
1d tensors.
Author: Markus Böck
Date: 2025-07-02 09:09:03 +02:00
Committed by: GitHub
Parent: 9262ac3ee4
Commit: 6c9be27b52
2 changed files with 14 additions and 3 deletions


@@ -1872,9 +1872,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!sourceTy || !resultTy || sourceTy != resultTy)
     return {};
-  // If the source and result are both 1D tensors and have the same type, the
-  // reshape has no effect, even if the tensor is dynamically shaped.
-  if (sourceTy.getRank() == 1)
+  // If the source and result are both 0D or 1D tensors and have the same type,
+  // the reshape has no effect, even if the tensor is dynamically shaped.
+  if (sourceTy.getRank() <= 1)
     return source;
   if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {


@@ -971,6 +971,17 @@ func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> te
 // -----
+// CHECK-LABEL: func @fold_reshape_0d
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<0xindex>
+// CHECK: return %[[INPUT]]
+func.func @fold_reshape_0d(%input: tensor<f32>, %shape: tensor<0xindex>) -> tensor<f32> {
+  %0 = tensor.reshape %input(%shape) : (tensor<f32>, tensor<0xindex>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+// -----
 // CHECK-LABEL: func @fold_extract_constant_splat
 // CHECK-NOT: tensor.extract_slice
 // CHECK: arith.constant dense<42> : tensor<4x4xi32>