[MLIR] Add options to allow insert_slice/extract_slice lowering of pack/unpack ops (#117340)
This PR adds the options below. Both default to true, so the original lowering behavior of the pack and unpack ops is unchanged.

- lowerPadLikeWithInsertSlice on packOp (default = true)
- lowerUnpadLikeWithExtractSlice on unPackOp (default = true)

The motivation for this PR is finer-grained control over the lowering of pack and unpack ops. This is useful in particular when we want to guarantee that no additional insert_slice and extract_slice ops interfere with tiling. With the original lowering pipeline, packOp and unPackOp may be lowered to insert_slice and extract_slice when the high dimensions are unit dimensions and no transpose is involved. Under such circumstances, these insert_slice and extract_slice ops block producer/consumer fusion in tile-and-fuse transforms. With this PR, that lowering path can be disabled so that consumer fusion goes through as expected.
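For example, to opt out of the insert_slice shortcut when lowering a pad-like pack, the new attribute is set on the transform op. A minimal transform-dialect sketch, mirroring the tests added below (handle and payload names are illustrative):

  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
  // Disabling the pad-like shortcut keeps the pad + expand_shape + transpose form,
  // so the resulting linalg.transpose remains available for producer fusion.
  transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
      : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)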
@@ -559,7 +559,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
     Return handles to the newly produced pad, expand_shape and transpose ops.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
+                       DefaultValuedAttr<BoolAttr, "true">:$lowerPadLikeWithInsertSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
                       Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -599,7 +600,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
     Return handles to the newly produced empty, transpose, collapse_shape and extract_slice ops.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
+                       DefaultValuedAttr<BoolAttr, "true">:$lowerUnpadLikeWithExtractSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
                       Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
@@ -1121,7 +1121,8 @@ struct LowerPackResult {
 
 /// Rewrite pack as pad + reshape + transpose.
 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                     tensor::PackOp packOp);
+                                     tensor::PackOp packOp,
+                                     bool lowerPadLikeWithInsertSlice = true);
 
 struct LowerUnPackOpResult {
   tensor::EmptyOp emptyOp;
@@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult {
 };
 
 /// Rewrite pack as empty + transpose + reshape + extract_slice.
-FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                           tensor::UnPackOp unPackOp);
+FailureOr<LowerUnPackOpResult>
+lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+            bool lowerUnpadLikeWithExtractSlice = true);
 
 /// Struct to hold the result of a `pack` call.
 struct PackResult {
@@ -1176,7 +1176,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
+  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
+  FailureOr<LowerPackResult> res =
+      lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
   if (failed(res)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "cannot lower to pad + expand + transpose";
@@ -1196,7 +1198,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
+  FailureOr<LowerUnPackOpResult> res =
+      lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
   if (failed(res)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError()
@@ -217,7 +217,8 @@ private:
 } // namespace
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
-                                             tensor::PackOp packOp) {
+                                             tensor::PackOp packOp,
+                                             bool lowerPadLikeWithInsertSlice) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -295,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
-  if (packOp.isLikePad()) {
+  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
@@ -351,8 +352,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   return LowerPackResult{padOp, reshapeOp, transposeOp};
 }
 
-FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
-                                                   tensor::UnPackOp unPackOp) {
+FailureOr<LowerUnPackOpResult>
+linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+                    bool lowerUnpadLikeWithExtractSlice) {
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -362,7 +364,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
-  if (unPackOp.isLikeUnPad()) {
+  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
     ArrayRef<int64_t> destShape = destTensorType.getShape();
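To make the two paths that these flags select concrete, consider a pad-like pack, i.e. all untiled outer dimensions are unit and no transpose is involved. A hedged sketch with illustrative shapes (not output copied from the lowering):

// A pad-like pack: unit outer dims, no transpose (shapes are illustrative).
%pack = tensor.pack %src padding_value(%cst : f32)
    inner_dims_pos = [0, 1] inner_tiles = [136, 64]
    into %dest : tensor<129x47xf32> -> tensor<1x1x136x64xf32>

// With lowerPadLikeWithInsertSlice = true (default), the pad-like shortcut fires:
//   %padded = tensor.pad %src ... : tensor<129x47xf32> to tensor<136x64xf32>
//   %result = tensor.insert_slice %padded into %dest ...
// With lowerPadLikeWithInsertSlice = false, the general path is taken instead,
// which keeps the ops fusible with tile-and-fuse transforms:
//   %padded   = tensor.pad %src ... : tensor<129x47xf32> to tensor<136x64xf32>
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       : tensor<136x64xf32> into tensor<1x136x1x64xf32>
//   %result   = linalg.transpose ins(%expanded : tensor<1x136x1x64xf32>)
//       outs(%dest : tensor<1x1x136x64xf32>) permutation = [0, 2, 1, 3]

The unpad-like unpack case is symmetric: with lowerUnpadLikeWithExtractSlice = false, the empty + transpose + collapse_shape + extract_slice path is used instead of a single extract_slice.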
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as pack_as_pad, but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
+// be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
+func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
+  // CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
+  // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
+  // CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+      : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false} : (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as unpack_as_pad, but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
+// be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
+func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.unpack is lowered to tensor.extract_slice + linalg.transpose + tensor.collapse_shape
+  // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+      : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false} : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)
@@ -572,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
-        : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
     transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
         -> (!transform.op<"tensor.empty">,
             !transform.op<"linalg.transpose">,
@@ -627,9 +687,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: @unpack_with_outer_dims_perm
 // CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
-// CHECK: %[[TRAN:.*]] = linalg.transpose
-// CHECK-SAME:  ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
-// CHECK-SAME:  outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
 // CHECK-SAME: permutation = [1, 3, 0, 2]
 // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
 // CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32>
@@ -638,7 +698,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK: linalg.copy ins(%[[SLICE]]
 // CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
 func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
-   %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
+  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
     inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
   return %unpack : tensor<32x64xf32>
 }
@@ -0,0 +1,240 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s
+
+// For the pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated.
+// This allows linalg.transpose to be fused as a producer operation. In the testcase below, linalg.transpose
+// as a producer operation is fused into the scf.forall loop.
+
+module {
+  // CHECK-LABEL: func @fuse_pack_as_producer
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         %[[PRODUCER:.*]] = linalg.transpose
+  // CHECK:         linalg.generic {{.*}} ins(%[[PRODUCER]]
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower the pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+          : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
+          : (!transform.op<"tensor.pack">)
+          -> (!transform.op<"tensor.pad">,
+              !transform.op<"tensor.expand_shape">,
+              !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+          : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+// For the pack op, by default lowerPadLikeWithInsertSlice = true, which generates an insert_slice and blocks fusion.
+// In the testcase below, tensor.insert_slice as a producer operation cannot be fused into the scf.forall loop.
+
+module {
+  // CHECK-LABEL: func @fuse_pack_as_producer_blocked_by_insert_slice
+  // CHECK:       %[[PRODUCER:.*]] = tensor.insert_slice
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         linalg.generic {{.*}} ins(%[[PRODUCER]]
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower the pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+          : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack
+          : (!transform.op<"tensor.pack">)
+          -> (!transform.op<"tensor.pad">,
+              !transform.op<"tensor.expand_shape">,
+              !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+          : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+// For the unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
+// This allows linalg.transpose to be fused as a consumer operation. In the testcase below, linalg.transpose
+// as a consumer operation is fused into the scf.forall loop.
+module {
+  // CHECK-LABEL: func @fuse_unpack_as_consumer
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         %[[CONSUMER:.*]] = linalg.generic
+  // CHECK:         linalg.transpose ins(%[[CONSUMER]]
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower the unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+          : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
+          : (!transform.op<"tensor.unpack">)
+          -> (!transform.op<"tensor.empty">,
+              !transform.op<"linalg.transpose">,
+              !transform.op<"tensor.collapse_shape">,
+              !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+          : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      transform.test.fuse_consumer %slice_op
+          : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+// For the unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates an extract_slice and blocks fusion.
+// In the testcase below, tensor.extract_slice as a consumer operation cannot be fused into the scf.forall loop.
+module {
+  // CHECK-LABEL: func @fuse_unpack_as_consumer_blocked_by_extract_slice
+  // CHECK:       %[[CONSUMER:.*]] = scf.forall {{.*}} {
+  // CHECK:         %[[ADDF:.*]] = linalg.generic
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:           tensor.parallel_insert_slice %[[ADDF]]
+  // CHECK:       }
+  // CHECK:       tensor.extract_slice %[[CONSUMER]]
+  func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower the unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+          : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack
+          : (!transform.op<"tensor.unpack">)
+          -> (!transform.op<"tensor.empty">,
+              !transform.op<"linalg.transpose">,
+              !transform.op<"tensor.collapse_shape">,
+              !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+          : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // is not a qualified consumer operation. Forcing it would yield a "could not fetch consumer
+      // to fuse" error.
+      transform.yield
+    }
+  }
+}