diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index dbc1ac60e097..74c4c0a8835f 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -222,9 +222,59 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> { ]; } +def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> { + let description = [{ + Interface for operations that connect an iteration domain to operands via + affine maps. Provides methods to access indexing maps between iteration + domain and operand index spaces. + }]; + let cppNamespace = "::mlir::linalg"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the indexing maps attribute within the current operation. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"getIndexingMaps" + >, + InterfaceMethod< + /*desc=*/[{ + Return the indexing maps within the current operation. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getIndexingMapsArray", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto range = $_op.getIndexingMaps() + .template getAsValueRange(); + return {range.begin(), range.end()}; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the input or output indexing map for `opOperand`. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getMatchingIndexingMap", + /*args=*/(ins "OpOperand*":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opOperand->getOwner() == this->getOperation()); + auto indexingMaps = + $_op.getIndexingMaps().template getAsValueRange(); + return *(indexingMaps.begin() + opOperand->getOperandNumber()); + }] + >, + ]; +} + // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface. def LinalgStructuredInterface - : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> { + : OpInterface<"LinalgOp", [ + DestinationStyleOpInterface, + IndexingMapOpInterface + ]> { let cppNamespace = "::mlir::linalg"; let methods = [ //===------------------------------------------------------------------===// @@ -465,21 +515,6 @@ def LinalgStructuredInterface blockArgument.getArgNumber()); }] >, - InterfaceMethod< - /*desc=*/[{ - Return the input or output indexing map for `opOperand`. - }], - /*retTy=*/"AffineMap", - /*methodName=*/"getMatchingIndexingMap", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - auto indexingMaps = - $_op.getIndexingMaps().template getAsValueRange(); - return *(indexingMaps.begin() + opOperand->getOperandNumber()); - }] - >, InterfaceMethod< /*desc=*/[{ Return the indexing map for a `result`. @@ -576,27 +611,6 @@ def LinalgStructuredInterface /*methodBody=*/"", /*defaultImplementation=*/[{ return success(); }] >, - InterfaceMethod< - /*desc=*/[{ - Return the indexing maps attribute within the current operation. - }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"getIndexingMaps" - >, - InterfaceMethod< - /*desc=*/[{ - Return the indexing maps within the current operation. - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getIndexingMapsArray", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto range = $_op.getIndexingMaps() - .template getAsValueRange(); - return {range.begin(), range.end()}; - }] - >, InterfaceMethod< /*desc=*/[{ Return true if any of the operands has a dynamic shape. diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 59b7fdeef10b..a6dab03d6473 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -612,10 +612,9 @@ using PadSizeComputationFunction = const PadTilingInterfaceOptions &)>; /// Specific helper for Linalg ops. -FailureOr> -computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad, - ArrayRef iterationDomain, - const PadTilingInterfaceOptions &options); +FailureOr> computeIndexingMapOpInterfacePaddedShape( + RewriterBase &rewriter, OpOperand &operandToPad, + ArrayRef iterationDomain, const PadTilingInterfaceOptions &options); /// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`. /// @@ -632,7 +631,7 @@ rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, const PadTilingInterfaceOptions &constOptions, SmallVector &padOps, PadSizeComputationFunction computePaddingSizeFun = - &computeLinalgPaddedShape); + &computeIndexingMapOpInterfacePaddedShape); namespace detail { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index e627fc83f2ba..5d55adbf46f3 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2229,10 +2229,12 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, return diag; } - // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand - // map / C++ APIs to compute the effect of padding on operands. - if (!isa(targetOp.getOperation())) { - auto diag = emitSilenceableError() << "only LinalgOp supported atm"; + // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a + // loopsToOperand map / C++ APIs to compute the effect of padding on + // operands. + if (!isa(targetOp.getOperation())) { + auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops " + "supported atm"; diag.attachNote(target->getLoc()) << "target op"; return diag; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index a9d7bc64f2a6..5383ae48aeb3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -155,11 +155,13 @@ SmallVector linalg::computePaddedShape( return paddedShape; } -FailureOr> linalg::computeLinalgPaddedShape( +FailureOr> +linalg::computeIndexingMapOpInterfacePaddedShape( RewriterBase &rewriter, OpOperand &operandToPad, ArrayRef iterationDomain, const PadTilingInterfaceOptions &options) { - auto linalgOp = llvm::dyn_cast(operandToPad.getOwner()); - if (!linalgOp) + auto transferOp = + llvm::dyn_cast(operandToPad.getOwner()); + if (!transferOp) return failure(); // clang-format off @@ -173,7 +175,7 @@ FailureOr> linalg::computeLinalgPaddedShape( for (const Range &range : iterationDomain) loopUpperBounds.push_back(range.size); - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operandToPad); + AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad); return computePaddedShape( rewriter, cast>(operandToPad.get()), indexingMap, loopUpperBounds, options); @@ -255,7 +257,18 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, SmallVector newOperands; newOperands.reserve(opToPad->getNumOperands()); for (OpOperand &opOperand : opToPad->getOpOperands()) { - LLVM_DEBUG(DBGS() << "--start padding oprd: " << opOperand.get() << "\n"); + Value operand = opOperand.get(); + LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n"); + + // 2.a. Skip scalar-like operands. + Type operandType = operand.getType(); + if (!isa(operandType)) { + assert(!isa(operandType) || + isa(operandType) && + "Unexpected non-vector ShapedType"); + newOperands.push_back(operand); + continue; + } // 2.a. Compute padded shape. FailureOr> maybePaddedShape = computePaddingSizeFun(rewriter, opOperand, iterationDomain, options); @@ -266,14 +279,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, // 2.b. Expect proper `paddingValues`. // TODO: we may want to allow garbage padding in the future, in which case // we would just not assert. - assert(opOperand.getOperandNumber() < options.paddingValues.size() && - "--no padding value specified"); + if (opOperand.getOperandNumber() >= options.paddingValues.size()) { + return rewriter.notifyMatchFailure(opToPad, + "--no padding value specified"); + } Attribute paddingValueAttr = options.paddingValues[opOperand.getOperandNumber()]; // 2.c. Perform actual padding. Value paddedOperand = padOperand( - rewriter, opToPad, cast>(opOperand.get()), + rewriter, opToPad, cast>(operand), *maybePaddedShape, paddingValueAttr); LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n"); diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir index c361885693cb..f0a410fa4015 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir @@ -1,5 +1,33 @@ // RUN: mlir-opt --transform-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s +// CHECK-LABEL: pad_fill +// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32> +func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32> +{ + %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %fill = transform.structured.match ops{["linalg.fill"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + + // Tile to 5 then pad to 8 + %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] { + padding_values=[0.0 : f32, 0.0 : f32], + padding_dimensions=[0] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.yield + } +} + +// ----- + // CHECK-LABEL: pad_lhs func.func @pad_lhs( %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)