[mlir][tensor] Fold identity reshape of 0d-tensors (#146375)
Just like 1d-tensors, reshapes of 0d-tensors (aka scalars) are always no-folds as they only have one possible layout. This PR adds logic to the `fold` implementation to optimize these away as is currently implemented for 1d tensors.
This commit is contained in:
@@ -1872,9 +1872,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
|
||||
if (!sourceTy || !resultTy || sourceTy != resultTy)
|
||||
return {};
|
||||
|
||||
// If the source and result are both 1D tensors and have the same type, the
|
||||
// reshape has no effect, even if the tensor is dynamically shaped.
|
||||
if (sourceTy.getRank() == 1)
|
||||
// If the source and result are both 0D or 1D tensors and have the same type,
|
||||
// the reshape has no effect, even if the tensor is dynamically shaped.
|
||||
if (sourceTy.getRank() <= 1)
|
||||
return source;
|
||||
|
||||
if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
|
||||
|
||||
@@ -971,6 +971,17 @@ func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> te
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_reshape_0d
|
||||
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<f32>
|
||||
// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<0xindex>
|
||||
// CHECK: return %[[INPUT]]
|
||||
func.func @fold_reshape_0d(%input: tensor<f32>, %shape: tensor<0xindex>) -> tensor<f32> {
|
||||
%0 = tensor.reshape %input(%shape) : (tensor<f32>, tensor<0xindex>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_extract_constant_splat
|
||||
// CHECK-NOT: tensor.extract_slice
|
||||
// CHECK: arith.constant dense<42> : tensor<4x4xi32>
|
||||
|
||||
Reference in New Issue
Block a user