This patch generalizes `tensor.expand_shape` and `memref.expand_shape` to consume the output shape as a list of SSA values. This enables us to implement generic reshape operations with dynamic shapes using `collapse_shape`/`expand_shape` pairs. The `output_shape` input to `expand_shape` follows the static/dynamic representation that is also used in `tensor.extract_slice`. Differential Revision: https://reviews.llvm.org/D140821 --------- Signed-off-by: Gaurav Shukla <gaurav.shukla@amd.com> Co-authored-by: Ramiro Leal-Cavazos <ramiroleal050@gmail.com>
16 lines
778 B
MLIR
16 lines
778 B
MLIR
// RUN: mlir-opt %s -generate-runtime-verification -cse | FileCheck %s

// Runtime verification of memref.expand_shape.
//
// The op below expands the single dynamic source dimension into the shape
// [%sz0, 5]. The expected output is:
//   1. an arith.remsi of the source dimension size by the static size 5;
//   2. an arith.cmpi eq of that remainder against 0;
//   3. a cf.assert on the comparison result.
// The last directive requires the original expand_shape to still follow the
// assert. It must also keep the dynamic size argument in output_shape. This
// makes the captured %sz0 value meaningful.

// CHECK-LABEL: func @expand_shape(
// CHECK-SAME: %[[m:.*]]: memref<?xf32>
// CHECK-SAME: %[[sz0:.*]]: index
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
// CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
// CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
// CHECK: cf.assert %[[cmpi]], "ERROR: Runtime op verification failed
// CHECK: memref.expand_shape %[[m]] {{.*}} output_shape [%[[sz0]], 5]
func.func @expand_shape(%m: memref<?xf32>, %sz0: index) -> memref<?x5xf32> {
  %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz0, 5] : memref<?xf32> into memref<?x5xf32>
  return %0 : memref<?x5xf32>
}
|