From 6c9be27b526fe1742755778948d0129ace92d357 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Wed, 2 Jul 2025 09:09:03 +0200 Subject: [PATCH] [mlir][tensor] Fold identity `reshape` of 0d-tensors (#146375) Just like 1d-tensors, reshapes of 0d-tensors (aka scalars) are always no-folds as they only have one possible layout. This PR adds logic to the `fold` implementation to optimize these away as is currently implemented for 1d tensors. --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +++--- mlir/test/Dialect/Tensor/canonicalize.mlir | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 22a25fd1a5af..0430e6fc6c63 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1872,9 +1872,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { if (!sourceTy || !resultTy || sourceTy != resultTy) return {}; - // If the source and result are both 1D tensors and have the same type, the - // reshape has no effect, even if the tensor is dynamically shaped. - if (sourceTy.getRank() == 1) + // If the source and result are both 0D or 1D tensors and have the same type, + // the reshape has no effect, even if the tensor is dynamically shaped. + if (sourceTy.getRank() <= 1) return source; if (auto fromElements = getShape().getDefiningOp()) { diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 3f9236095138..95c5b8c91edf 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -971,6 +971,17 @@ func.func @fold_reshape_1d(%input: tensor, %shape: tensor<1xindex>) -> te // ----- +// CHECK-LABEL: func @fold_reshape_0d +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<0xindex> +// CHECK: return %[[INPUT]] +func.func @fold_reshape_0d(%input: tensor, %shape: tensor<0xindex>) -> tensor { + %0 = tensor.reshape %input(%shape) : (tensor, tensor<0xindex>) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: func @fold_extract_constant_splat // CHECK-NOT: tensor.extract_slice // CHECK: arith.constant dense<42> : tensor<4x4xi32>