From 91bbebc7e118cceae1fc0e349de08094a3cd2fe7 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Fri, 27 Dec 2024 16:52:34 +0000
Subject: [PATCH] [mlir][scf] Add getPartialResultTilePosition to
 PartialReductionOpInterface (#120465)

This PR adds a new interface method, getPartialResultTilePosition, to
PartialReductionOpInterface that queries the position of the tile the tiled
operation computes for the partial result. Previously, tiling a reduction
dimension with SplitReductionOuterReduction produced wrong results when the
result has transposed parallel dimensions.

Other fixes needed to make this PR work:

- Instead of ad-hoc logic that placed the new reduction dimensions in the
  partial result based on the iteration space, the reduction dimensions are
  now always appended to the partial result tensor.
- Remove the usage of PartialReductionOpInterface in the Mesh dialect. The
  implementation only needed a neutral element for the reduction, but used
  PartialReductionOpInterface to obtain it, which is the wrong tool for the
  job; it was also passing the wrong sizes to it.
---
 .../mlir/Interfaces/TilingInterface.td        |  22 +++
 .../Transforms/MeshShardingInterfaceImpl.cpp  |  34 ++--
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 165 ++++++++++++------
 .../SCF/Transforms/TileUsingInterface.cpp     |  28 +--
 .../Linalg/transform-tile-reduction.mlir      |  67 +++++--
 5 files changed, 225 insertions(+), 91 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index b75fc5e806af..50b69b8f8d83 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
         /*defaultImplementation=*/[{
           return failure();
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of the partial result tile computed
+          by the tiled operation. This is the same as
+          TilingInterface::getResultTilePosition, but determines the result
+          tile position for the partial reduction.
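+
+          For example, for a linalg op whose result indexing map is
+          (d0, d1, d2) -> (d2, d0) and whose reduction dimension is d1, the
+          Linalg implementation below appends the reduction dimensions to
+          the partial result, giving the partial-result map
+          (d0, d1, d2) -> (d2, d0, d1); the offset returned for the appended
+          d1 position is then always 0, and its size is the reduction tile
+          size.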
+        }],
+        /*retType=*/"::llvm::LogicalResult",
+        /*methodName=*/"getPartialResultTilePosition",
+        /*args=*/(ins
+            "::mlir::OpBuilder &":$b,
+            "unsigned":$resultNumber,
+            "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
+            "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
+            "::mlir::ArrayRef<int>":$reductionDims),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >
   ];
 }

diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 5bf2f91c2c7b..92cfba2549a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
                       ArrayRef<MeshSharding> resultShardings,
                       SymbolTableCollection &symbolTable) {
-  for (const MeshSharding& sharding : operandShardings) {
+  for (const MeshSharding &sharding : operandShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
   }

-  for (const MeshSharding& sharding : resultShardings) {
+  for (const MeshSharding &sharding : resultShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
   }
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
 // the original operand.
 // The other processes would use the reduction operation neutral tensor.
 static Value createDestinationPassingStyleInitOperand(
-    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
-    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+    LinalgOp op, int operandNumber, Value spmdizedOperand,
+    ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+    ImplicitLocOpBuilder &builder) {
   Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
       meshOp.getSymName(), reductionMeshAxes, builder);
   Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
     builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
     SmallVector<OpFoldResult> shape =
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
-    PartialReductionOpInterface partialReductionIface =
-        llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    assert(op->getNumResults() == 1 && "Multiple results not supported.");
-    FailureOr<SmallVector<Value>> reductionNeutralTensor =
-        partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensor));
-    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
+
+    SmallVector<Operation *> combinerOps;
+    matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
+    assert(combinerOps.size() == 1);
+    std::optional<TypedAttr> neutralEl =
+        arith::getNeutralElement(combinerOps[0]);
+
+    Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
+                                                 neutralEl.value().getType());
+    Value constant =
+        builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
+    Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
+                     .getResult(0);
+
+    builder.create<scf::YieldOp>(fill);
   }
   return ifOp.getResult(0);
 }
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
     Value spmdizedInitOperand =
         spmdizationMap.lookup(op->getOperands()[operandIdx]);
     newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
-        op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+        op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
   return newOperands;
 }

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f86715a94b26..b7764da26a7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -324,7 +324,27 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//

-/// External model implementation of PartialReductionInterface for LinalgOps.
+/// Return an AffineMap for a partial result for the given result number,
+/// assuming the partial tiling strategy is outer-reduction loop +
+/// inner-parallel tile. The returned AffineMap can be used as the replacement
+/// AffineMap for the inner-parallel tile linalg op for the given result
+/// number.
+///
+/// The new AffineMap is the old AffineMap with the reduction dimensions
+/// appended at the end.
+static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
+                                           ArrayRef<int> reductionDims,
+                                           unsigned resultNumber) {
+  AffineMap map =
+      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
+  for (int redPos : reductionDims) {
+    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                           map.getNumResults());
+  }
+  return map;
+}
+
+/// External model implementation of PartialReductionInterface for
+/// LinalgOps.
 template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +358,24 @@ struct LinalgOpPartialReductionInterface
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");

+    // LinalgOp implements TilingInterface.
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    SmallVector<OpFoldResult> shape =
+        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
+                            [](Range x) { return x.size; });
+
+    SmallVector<OpFoldResult> tiledShape;
+    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+      if (isZeroIndex(tileSize)) {
+        tiledShape.push_back(dimSize);
+      } else {
+        tiledShape.push_back(tileSize);
+      }
+    }
+
     SmallVector<Value> inits;
     for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
          ++initIdx) {
-      // Insert the new parallel dimension based on the index of the reduction
-      // loops. This could be controlled by user for more flexibility.
       SmallVector<Operation *, 4> combinerOps;
       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps) ||
@@ -355,33 +388,19 @@ struct LinalgOpPartialReductionInterface
         return op->emitOpError(
             "Failed to get an identity value for the reduction operation.");

-      ArrayRef<int64_t> oldShape =
-          linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
-
-      // Calculate the new shape, we insert the new dimensions based on the
-      // index of the reduction dimensions.
-      SmallVector<int64_t> newOutputShape;
-      SmallVector<Value> dynamicDims;
-      int64_t currReductionDims = 0;
-      DenseSet<int> reductionDimsSet(reductionDims.begin(),
-                                     reductionDims.end());
-      for (int64_t idx :
-           llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
-        if (reductionDimsSet.contains(idx)) {
-          dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
-          currReductionDims++;
-          continue;
-        }
-        int64_t oldIdx = idx - currReductionDims;
-        int64_t dim = oldShape[oldIdx];
-        newOutputShape.push_back(dim);
-        if (ShapedType::isDynamic(dim))
-          dynamicDims.push_back(b.create<tensor::DimOp>(
-              loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+      // Append the new partial result dimensions.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
+      SmallVector<OpFoldResult> partialResultShape;
+      for (AffineExpr dimExpr : partialMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
-      Value emptyTensor = b.create<tensor::EmptyOp>(
-          loc, newOutputShape,
-          linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+
+      Type elType =
+          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Value emptyTensor =
+          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
       auto identityTensor =
           b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
@@ -407,11 +426,7 @@ struct LinalgOpPartialReductionInterface
       // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
       // this with a for range loop when we have it.
       AffineMap newMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      for (int redPos : reductionDims) {
-        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
-                                     newMap.getNumResults());
-      }
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
       newInitMaps.push_back(newMap);
     }
@@ -476,29 +491,75 @@ struct LinalgOpPartialReductionInterface
                        Location loc, ValueRange partialReduce,
                        ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    SmallVector<int64_t> reductionDimsInt64(reductionDims);
-    auto reduction = b.create<linalg::ReduceOp>(
-        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
-        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          int64_t numInits = linalgOp.getNumDpsInits();
-          SmallVector<Value> yieldedValues;
-          for (int idx : llvm::seq<int>(0, numInits)) {
+
+    // Permute the reduction dimensions in the same way the partial result
+    // map permutes them.
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<Operation *> mergeOperations;
+    SmallVector<Value> replacements;
+    for (int idx : llvm::seq(numInits)) {
+      // linalg.reduce's iteration space is the tiled result's iteration space
+      // (and not the tiled operation's iteration space). To account for this,
+      // permute the reduction dimensions based on the partial result map of
+      // the tiled result.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
+      SmallVector<int64_t> partialReductionDims;
+      for (auto [resultNum, dimExpr] :
+           llvm::enumerate(partialMap.getResults())) {
+        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+        if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+          partialReductionDims.push_back(resultNum);
+        }
+      }
+
+      Value partialResult = partialReduce[idx];
+      Value init = linalgOp.getDpsInits()[idx];
+
+      auto reduction = b.create<linalg::ReduceOp>(
+          loc, partialResult, init, partialReductionDims,
+          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
+            // Get the combiner op.
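+            // Note: linalg.reduce's combiner region receives the element of
+            // the partial result first and the element of the init
+            // (accumulator) second, hence inputs[0] and inputs[1] below.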
+            SmallVector<Operation *, 4> combinerOps;
+            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
             // Combine the input at idx and output at numInits + idx.
-            clonedReductionOp->setOperand(0, inputs[idx]);
-            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
-            // Yield.
-            yieldedValues.push_back(clonedReductionOp->getResult(0));
-          }
-          b.create<linalg::YieldOp>(loc, yieldedValues);
-        });
-    return MergeResult{
-        {reduction.getOperation()},
-        llvm::map_to_vector(reduction->getResults(),
-                            [](OpResult r) -> Value { return r; })};
+            clonedReductionOp->setOperand(0, inputs[0]);
+            clonedReductionOp->setOperand(1, inputs[1]);
+            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+          });
+
+      mergeOperations.push_back(reduction);
+      replacements.push_back(reduction->getResult(0));
+    }
+
+    return MergeResult{mergeOperations, replacements};
+  }
+
+  LogicalResult getPartialResultTilePosition(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &resultOffsets,
+      SmallVector<OpFoldResult> &resultSizes,
+      ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    AffineMap partialMap =
+        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
+    for (AffineExpr dimExpr : partialMap.getResults()) {
+      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+      resultSizes.push_back(sizes[dim]);
+
+      if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+        // Reduction dims are reduced, and are always written to the same
+        // place, so use offset 0 for them.
+        resultOffsets.push_back(b.getIndexAttr(0));
+      } else {
+        resultOffsets.push_back(offsets[dim]);
+      }
+    }
+
+    return success();
   }
 };

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2277989bf841..b548f8ce8b56 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                                     resultOffset, resultSize);
   case scf::SCFTilingOptions::ReductionTilingStrategy::
       PartialReductionOuterReduction: {
-    // TODO: This does not work for non identity accesses to the result tile.
-    // The proper fix is to add a getPartialResultTilePosition method to
-    // PartialReductionOpInterface.
-    resultOffset =
-        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
-    for (size_t i = 0; i < offsets.size(); i++) {
-      resultSize.push_back(
-          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only "
+              "supported for operations implementing "
+              "PartialReductionOpInterface");
     }
-    return success();
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+                                              resultOffset, resultSize,
+                                              reductionDims);
+  }
   default:
     return rewriter.notifyMatchFailure(op,
                                        "unhandled reduction tiling strategy");
   }
-  }
 }

diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index cce4b4efa61c..9d34c80822d0 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -32,8 +32,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK: %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
@@ -81,13 +80,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: func @reduction_tile_transpose
-// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
-// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
+// CHECK: tensor.empty(%{{.*}}) : tensor<?x5xf32>
+// CHECK: linalg.fill {{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: scf.for
-// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
+// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
 // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
-// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
-// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+// CHECK: scf.yield {{.*}} : tensor<?x5xf32>
 // CHECK: }
 // CHECK: linalg.reduce
 // CHECK: return
@@ -129,8 +128,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -183,9 +181,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 // CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]], %[[D2]]) : tensor<?x?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
 // CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
 // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -243,8 +239,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
 // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -422,8 +417,8 @@ func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %12, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
-      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %12, %2, %3, %4, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
   }
 }
@@ -444,4 +439,44 @@ module attributes {transform.with_named_sequence} {
 // CHECK: scf.yield %[[INSERT1]], %[[INSERT1]]
 // CHECK: linalg.reduce
 // CHECK: arith.addf
+// CHECK: linalg.reduce
 // CHECK: arith.maximumf
+
+// -----
+
+func.func @reduction_tile_multi_dim_transpose(%arg0: tensor<?x?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d2, d0)>],
+  iterator_types = ["parallel", "reduction", "parallel"]}
+  ins(%arg0 : tensor<?x?x?xf32>)
+  outs(%out : tensor<?x?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %42 = arith.addf %arg7, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<?x?xf32>
+  return %red : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+// CHECK: func @reduction_tile_multi_dim_transpose
+// CHECK: tensor.empty(%{{.*}}) : tensor<?x?x5xf32>
+// CHECK: linalg.fill {{.*}} : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
+// CHECK: scf.for
+// CHECK: %[[K:.*]] = affine.min
+// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0, 0] [%[[D2:.*]], %[[D0:.*]], %[[K]]] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?x?xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?x?xf32>) outs(%[[EXT]] : tensor<?x?x?xf32>)
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0, 0] [%[[D2]], %[[D0]], %[[K]]] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x5xf32>
+// CHECK: scf.yield {{.*}} : tensor<?x?x5xf32>
+// CHECK: }
+// CHECK: linalg.reduce
+// CHECK: return
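
A minimal sketch of the behavior this patch fixes, derived entirely from the
reduction_tile_multi_dim_transpose test above (illustrative only, not part of
the patch): tiling the d1 reduction of

  %red = linalg.generic
           {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                             affine_map<(d0, d1, d2) -> (d2, d0)>],
            iterator_types = ["parallel", "reduction", "parallel"]}
           ins(%arg0 : tensor<?x?x?xf32>) outs(%out : tensor<?x?xf32>) {...}

by 5 appends the reduction dimension to the transposed result map, so the
partial accumulator uses the map (d0, d1, d2) -> (d2, d0, d1) and has type
tensor<?x?x5xf32>. getPartialResultTilePosition then reports offset 0 for the
appended dimension, and each partial tile is inserted at [0, 0, 0] with sizes
[%d2, %d0, %k], rather than at the identity-access position the old code
assumed.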