[mlir][tensor][memref] Enhance collapse(expand(src)) canonicalization pattern. (#145995)
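This teaches ComposeCollapseOfExpandOp to fold collapse_shape(expand_shape(src)) into a single expand_shape in the rank-increasing case even when the shapes involved are only partially static. The pattern now derives the output_shape operand of the new expand_shape by combining the sizes within each reassociation group of the collapse: static sizes are folded into one constant, and dynamic sizes are multiplied together with arith.muli. Roughly, mirroring the tests below (value names are illustrative, not from the patch):

    %e = tensor.expand_shape %src [[0, 1, 2, 3, 4]]
        output_shape [4, 2, %d2, %d3, 32] : tensor<?xf16> into tensor<4x2x?x?x32xf16>
    %c = tensor.collapse_shape %e [[0, 1], [2], [3, 4]]
        : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>

now canonicalizes to:

    %c32 = arith.constant 32 : index
    %sz = arith.muli %d3, %c32 : index
    %r = tensor.expand_shape %src [[0, 1, 2]]
        output_shape [8, %d2, %sz] : tensor<?xf16> into tensor<8x?x?xf16>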
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
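The new include is needed because composing a partially dynamic collapse(expand) now materializes arith.muli and arith.constant index ops for the computed output sizes, as the next hunk shows.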
@@ -305,8 +306,43 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
     } else if (srcRank < resultRank) {
+      // Compute the dynamic output shape for the new expand_shape op.
+      Location loc = collapseOp.getLoc();
+      SmallVector<OpFoldResult> origOutputShape =
+          expandOp.getMixedOutputShape();
+      SmallVector<OpFoldResult> newOutputShape;
+      for (const ReassociationIndices &indices :
+           collapseOp.getReassociationIndices()) {
+        int64_t numStaticElems = 1;
+        SmallVector<Value> dynamicSizes;
+        for (int64_t idx : indices) {
+          OpFoldResult size = origOutputShape[idx];
+          if (std::optional<int64_t> maybeCst = getConstantIntValue(size)) {
+            numStaticElems *= maybeCst.value();
+            continue;
+          }
+          dynamicSizes.push_back(cast<Value>(size));
+        }
+        if (dynamicSizes.empty()) {
+          newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
+          continue;
+        }
+
+        // There is at least one dynamic size, so we can initialize `result` to
+        // the first dynamic size.
+        Value result = dynamicSizes[0];
+        for (Value v : llvm::drop_begin(dynamicSizes))
+          result = rewriter.create<arith::MulIOp>(loc, result, v);
+        if (numStaticElems != 1) {
+          result = rewriter.create<arith::MulIOp>(
+              loc, result,
+              rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems));
+        }
+        newOutputShape.push_back(result);
+      }
       rewriter.replaceOpWithNewOp<ExpandOpTy>(
-          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+          collapseOp, resultType, expandOp.getSrc(), composedReassociation,
+          newOutputShape);
     } else {
       // Collapses/expansions that do not change the rank are not allowed. Use
       // a cast instead.
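For the partially dynamic test below, the three collapse groups resolve as follows: [0, 1] contains only static sizes, so 4 * 2 folds to the attribute 8 and no IR is emitted; [2] contains a single dynamic size, which is reused unchanged; [3, 4] mixes %arg2 with the static 32, so one multiply is created. A sketch of the IR emitted for that last group (value names are illustrative):

    %c32 = arith.constant 32 : index
    %collapsed_d2 = arith.muli %arg2, %c32 : index

This is why the CHECK lines below expect exactly one arith.muli.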
@@ -466,6 +466,24 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
 
 // -----
 
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %arg1: index, %arg2: index) -> memref<8x?x?xf16> {
+  %expanded = memref.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : memref<?xf16> into memref<4x2x?x?x32xf16>
+  %collapsed = memref.collapse_shape %expanded [[0, 1], [2], [3, 4]] : memref<4x2x?x?x32xf16> into memref<8x?x?xf16>
+  return %collapsed : memref<8x?x?xf16>
+}
+// CHECK: func @compose_collapse_of_expand_partially_dynamic
+// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+// CHECK: %[[COLLAPSED_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
+// CHECK-SAME: [0, 1, 2]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[COLLAPSED_D2]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @do_not_compose_collapse_of_expand_non_identity_layout(
     %arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
     -> memref<?xf32, strided<[?], offset: 0>> {
@@ -1243,6 +1243,24 @@ func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
 
 // -----
 
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %arg1: index, %arg2: index) -> tensor<8x?x?xf16> {
+  %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : tensor<?xf16> into tensor<4x2x?x?x32xf16>
+  %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]] : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>
+  return %collapsed : tensor<8x?x?xf16>
+}
+// CHECK: func @compose_collapse_of_expand_partially_dynamic
+// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+// CHECK: %[[COLLAPSED_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
+// CHECK-SAME: [0, 1, 2]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[COLLAPSED_D2]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
     -> tensor<1x1x1x1xf32> {
   %0 = tensor.collapse_shape %arg0 []