[mlir][memref] Add a new ReifyResultShapes pass (#145927)
This pass reifies the shapes of a subset of
`ReifyRankedShapedTypeOpInterface` ops with `tensor` results.
The pass currently only supports result shape type reification for:
- tensor::PadOp
- tensor::ConcatOp
It addresses a representation gap where implicit op semantics are needed
to infer static result types from dynamic
operands. But it does so by using `ReifyRankedShapedTypeOpInterface` as
the source of truth rather than the op itself.
As a consequence, this cannot generalize today.
TODO: in the future, we should consider coupling this information with
op "transfer functions" (e.g.
`IndexingMapOpInterface`) to provide a source of truth that can work
across result shape inference, canonicalization and
op verifiers.
The pass replaces the operations with their reified versions, when more
static information can be derived, and inserts
casts when results shapes are updated.
Example:
```mlir
#map = affine_map<(d0) -> (-d0 + 256)>
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
%0 = affine.apply #map(%arg1)
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
tensor.yield %arg0 : f32
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>
return %padded : tensor<1x?x64xf32>
}
// mlir-opt --reify-result-shapes
#map = affine_map<()[s0] -> (-s0 + 256)>
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
%0 = affine.apply #map()[%arg1]
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
tensor.yield %arg0 : f32
} : tensor<1x?x64xf32> to tensor<1x256x64xf32>
%cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
return %cast : tensor<1x?x64xf32>
}
```
---------
Co-authored-by: Fabian Mora <fabian.mora-cordero@amd.com>
This commit is contained in:
committed by
GitHub
parent
3355cca938
commit
08cf6ae537
31
mlir/test/Dialect/Tensor/reify-shapes.mlir
Normal file
31
mlir/test/Dialect/Tensor/reify-shapes.mlir
Normal file
@@ -0,0 +1,31 @@
|
||||
// RUN: mlir-opt -reify-result-shapes %s | FileCheck %s
|
||||
|
||||
// The test below checks concat op reification. In the first case, no cast is inserted while on the second a cast gets inserted.
|
||||
// CHECK-LABEL: func.func @concat_reification
|
||||
func.func @concat_reification(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>)
|
||||
-> (tensor<4x11x3xf32>, tensor<?x?x?xf32>) {
|
||||
// CHECK: %[[RES0:.*]] = tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
|
||||
%1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
|
||||
// CHECK: %[[V0:.*]] = tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
|
||||
// CHECK: %[[RES1:.*]] = tensor.cast %[[V0]] : tensor<4x7x?xf32> to tensor<?x?x?xf32>
|
||||
%2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: return %[[RES0]], %[[RES1]] : tensor<4x11x3xf32>, tensor<?x?x?xf32>
|
||||
return %1, %2 : tensor<4x11x3xf32>, tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @pad_reification
|
||||
func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
|
||||
%pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
|
||||
%es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
|
||||
: tensor<64x?x64xf32> to tensor<1x?x64xf32>
|
||||
|
||||
// CHECK: tensor.pad
|
||||
// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
|
||||
// CHECK: tensor.cast %{{.*}} : tensor<1x256x64xf32> to tensor<1x?x64xf32>
|
||||
%padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
|
||||
^bb0(%a: index, %b: index, %c: index):
|
||||
tensor.yield %cst : f32
|
||||
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>
|
||||
|
||||
return %padded : tensor<1x?x64xf32>
|
||||
}
|
||||
Reference in New Issue
Block a user