From d31ba5256327d30f264c2f671bf197877b242cde Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 24 Jun 2025 07:56:32 +0200 Subject: [PATCH] [mlir][Interface] Factor out common IndexingMapOpInterface behavior in a new generic interface (#145313) Refactor the verifiers to make use of the common bits and make `vector.contract` also use this interface. In the process, the confusingly named getStaticShape has disappeared. Note: the verifier for IndexingMapOpInterface is currently called manually from other verifiers as it was unclear how to avoid it taking precedence over more meaningful error messages --- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 1 + .../Dialect/Linalg/IR/LinalgInterfaces.td | 180 +++--------------- .../mlir/Dialect/Vector/IR/VectorOps.h | 1 + .../mlir/Dialect/Vector/IR/VectorOps.td | 12 ++ mlir/include/mlir/Interfaces/CMakeLists.txt | 1 + .../mlir/Interfaces/IndexingMapOpInterface.h | 27 +++ .../mlir/Interfaces/IndexingMapOpInterface.td | 153 +++++++++++++++ mlir/lib/Dialect/Linalg/IR/CMakeLists.txt | 1 + .../Dialect/Linalg/IR/LinalgInterfaces.cpp | 85 +-------- .../Linalg/IR/ValueBoundsOpInterfaceImpl.cpp | 1 + .../Linalg/Transforms/DropUnitDims.cpp | 7 +- .../Linalg/Transforms/Vectorization.cpp | 5 +- mlir/lib/Dialect/Vector/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +- mlir/lib/Interfaces/CMakeLists.txt | 2 + .../lib/Interfaces/IndexingMapOpInterface.cpp | 125 ++++++++++++ mlir/test/Dialect/Linalg/invalid.mlir | 2 +- 17 files changed, 369 insertions(+), 238 deletions(-) create mode 100644 mlir/include/mlir/Interfaces/IndexingMapOpInterface.h create mode 100644 mlir/include/mlir/Interfaces/IndexingMapOpInterface.td create mode 100644 mlir/lib/Interfaces/IndexingMapOpInterface.cpp diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h index 0f960fb5ad79..0ebbeea93755 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -20,6 +20,7 @@ #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/IndexingMapOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/RawOstreamExtras.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 74c4c0a8835f..ca1cba8747bd 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -14,6 +14,7 @@ #define LINALG_IR_LINALGINTERFACES include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/IndexingMapOpInterface.td" include "mlir/IR/OpBase.td" // The 'LinalgContractionOpInterface' provides access to the @@ -222,59 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> { ]; } -def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> { - let description = [{ - Interface for operations that connect an iteration domain to operands via - affine maps. Provides methods to access indexing maps between iteration - domain and operand index spaces. - }]; - let cppNamespace = "::mlir::linalg"; - let methods = [ - InterfaceMethod< - /*desc=*/[{ - Return the indexing maps attribute within the current operation. - }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"getIndexingMaps" - >, - InterfaceMethod< - /*desc=*/[{ - Return the indexing maps within the current operation. - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getIndexingMapsArray", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto range = $_op.getIndexingMaps() - .template getAsValueRange(); - return {range.begin(), range.end()}; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the input or output indexing map for `opOperand`. - }], - /*retTy=*/"AffineMap", - /*methodName=*/"getMatchingIndexingMap", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - auto indexingMaps = - $_op.getIndexingMaps().template getAsValueRange(); - return *(indexingMaps.begin() + opOperand->getOperandNumber()); - }] - >, - ]; -} - // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface. def LinalgStructuredInterface - : OpInterface<"LinalgOp", [ - DestinationStyleOpInterface, - IndexingMapOpInterface - ]> { + : OpInterface<"LinalgOp", + [DestinationStyleOpInterface, IndexingMapOpInterface] + > { let cppNamespace = "::mlir::linalg"; let methods = [ //===------------------------------------------------------------------===// @@ -464,30 +417,6 @@ def LinalgStructuredInterface return getBlock()->getArguments().take_back($_op.getNumDpsInits()); }] >, - InterfaceMethod< - /*desc=*/[{ - Return the `opOperand` shape or an empty vector for scalars or vectors - not wrapped within a tensor or a memref. - }], - /*retTy=*/"ArrayRef", - /*methodName=*/"getShape", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - Type t = opOperand->get().getType(); - // A VectorType is an elemental type, do not consider its rank for the operand. - if (isa(t)) - return {}; - if (auto shapedType = ::llvm::dyn_cast(t)) { - // Failsafe. - assert((isa(t) || isa(t)) && - "expected a ranked tensor or memref in LinalgInterface::getRank"); - return shapedType.getShape(); - } - return {}; - }] - >, InterfaceMethod< /*desc=*/[{ Return the block argument for an `opOperand`. @@ -620,7 +549,12 @@ def LinalgStructuredInterface /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::any_of(getStaticShape(), ShapedType::isDynamic); + for (OpOperand &opOperand : this->getOperation()->getOpOperands()) { + if (auto shapedType = dyn_cast(opOperand.get().getType())) { + if (ShapedType::isDynamicShape(shapedType.getShape())) return true; + } + } + return false; }] >, InterfaceMethod< @@ -738,53 +672,6 @@ def LinalgStructuredInterface //===------------------------------------------------------------------===// // Linalg generalization hooks. //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Hook to provide a custom AffineMap used to compute all the operand - subshapes given loop bounds. This is used to answer the question: "given - an iteration space over the codomain, what are the subshapes of the - operands involved in the computation". - The default behavior is to just concatenate all the indexing maps. - A custom AffineMap allows providing a map that can be used to - compute subshapes even in cases where the concatenation of indexing maps - (i.e. the data traversal order) is not a simple permutation of the loop - traversal order. It is then possible to define ops with skewed data - traversal order for which we can still easily compute hyperrectangular - loop bounds and subviews. - }], - /*retTy=*/"AffineMap", - /*methodName=*/"getLoopsToShapesMap", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto maps = $_op.getIndexingMapsArray(); - return concatAffineMaps(maps, $_op.getContext()); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Hook to provide a custom AffineMap used to construct the - hyperrectangular loop iteration space given all the operand subshapes. - This is used to answer the question: - "Given a list of operand ranges, what is the subportion of the iteration - space involved in the computation". - This is the inverse problem of `getLoopsToShapesMap`. - Return the empty AffineMap when such an AffineMap cannot be constructed. - The default behavior is based on a very simple inference procedure that - only works with permutation affine maps. - A more advanced Tensor-Comprehension like inference is possible but has - proven to be ambiguous in unfavorable case. - A safer and more robust alternative is to allow each op to define - its own AffineMap. - }], - /*retTy=*/"AffineMap", - /*methodName=*/"getShapesToLoopsMap", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return inversePermutation(getLoopsToShapesMap()); - }] - >, InterfaceMethod< /*desc=*/[{ Checks if the given operands can be dropped, and the remaining @@ -798,39 +685,30 @@ def LinalgStructuredInterface return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands); }] >, + //===------------------------------------------------------------------===// + // IndexingMapOpInterface interface methods implementation. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Like `getShape`, but only returns statically-known information, without - generating any new IR. For each shape dimension, returns >=0 if that - dimension is statically known, or ShapedType::kDynamic otherwise. + Return the `opOperand` shape or an empty vector for scalars or vectors + not wrapped within a tensor or a memref. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getStaticShape", - /*args=*/(ins), + /*retTy=*/"ArrayRef", + /*methodName=*/"getShape", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector res; - for (OpOperand &opOperand : this->getOperation()->getOpOperands()) - llvm::append_range(res, getShape(&opOperand)); - return res; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Returns the statically-known loop ranges. Composes - `getShapesToLoopsMap()` with the result of `getStaticShape`. - Returns ShapedType::kDynamic for non-statically-known loop ranges. - This is expected to be called by a valid Linalg op - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getStaticLoopRanges", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector viewSizes = getStaticShape(); - AffineMap invertedMap = getShapesToLoopsMap(); - assert(invertedMap && "expected a valid Linalg op to call the method"); - return invertedMap.compose(viewSizes); + Type t = opOperand->get().getType(); + // A VectorType is an elemental type, do not consider its rank for the operand. + if (isa(t)) + return {}; + if (auto shapedType = ::llvm::dyn_cast(t)) { + // Failsafe. + assert((isa(t) || isa(t)) && + "expected a ranked tensor or memref in LinalgInterface::getRank"); + return shapedType.getShape(); + } + return {}; }] >, //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 98fb6075cbf3..364c1728715e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -25,6 +25,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/IndexingMapOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 926a92eff2eb..02e62930a742 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -21,6 +21,7 @@ include "mlir/Dialect/Vector/IR/Vector.td" include "mlir/Dialect/Vector/IR/VectorAttributes.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/IndexingMapOpInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -33,6 +34,7 @@ include "mlir/IR/EnumAttr.td" // than the current set: {*, +}. def Vector_ContractionOp : Vector_Op<"contract", [ + IndexingMapOpInterface, Pure, PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>, PredOpTrait<"third operand acc and result have same element type", @@ -207,6 +209,16 @@ def Vector_ContractionOp : .template getAsValueRange(); return {range.begin(), range.end()}; } + + //===------------------------------------------------------------------===// + // IndexingMapOpInterface interface methods implementation. + //===------------------------------------------------------------------===// + ArrayRef getShape(OpOperand * opOperand) { + Type t = opOperand->get().getType(); + if (auto vt = dyn_cast(t)) + return vt.getShape(); + return {}; + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index d81298bb4daf..067e0511e4e7 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(FunctionInterfaces) +add_mlir_interface(IndexingMapOpInterface) add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) diff --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h new file mode 100644 index 000000000000..40252613a21f --- /dev/null +++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h @@ -0,0 +1,27 @@ +//===- IndexingMapOpInterface.h ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_ +#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_ + +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace detail { +/// Verify that `op` conforms to the invariants of StructuredOpInterface +LogicalResult verifyIndexingMapOpInterface(Operation *op); +} // namespace detail +} // namespace mlir + +/// Include the generated interface declarations. +#include "mlir/Interfaces/IndexingMapOpInterface.h.inc" + +#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td new file mode 100644 index 000000000000..fdcc183d9921 --- /dev/null +++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td @@ -0,0 +1,153 @@ +//===- IndexingMapOpInterface.td - Interface Declaration -*- tablegen -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for the IndexingMapOpInterface. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE +#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE + +include "mlir/IR/OpBase.td" + +def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> { + let description = [{ + Interface for operations that connect an iteration domain to operands via + affine maps. Provides methods to access indexing maps between iteration + domain and operand index spaces. + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the indexing maps attribute within the current operation. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"getIndexingMaps" + >, + InterfaceMethod< + /*desc=*/[{ + Return the indexing maps within the current operation. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getIndexingMapsArray", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto range = $_op.getIndexingMaps() + .template getAsValueRange(); + return {range.begin(), range.end()}; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the input or output indexing map for `opOperand`. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getMatchingIndexingMap", + /*args=*/(ins "OpOperand*":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opOperand->getOwner() == this->getOperation()); + auto indexingMaps = + $_op.getIndexingMaps().template getAsValueRange(); + return *(indexingMaps.begin() + opOperand->getOperandNumber()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Hook to provide a custom AffineMap used to compute all the operand + subshapes given loop bounds. This is used to answer the question: "given + an iteration space over the codomain, what are the subshapes of the + operands involved in the computation". + The default behavior is to just concatenate all the indexing maps. + A custom AffineMap allows providing a map that can be used to + compute subshapes even in cases where the concatenation of indexing maps + (i.e. the data traversal order) is not a simple permutation of the loop + traversal order. It is then possible to define ops with skewed data + traversal order for which we can still easily compute hyperrectangular + loop bounds and subviews. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getLoopsToShapesMap", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto maps = $_op.getIndexingMapsArray(); + return concatAffineMaps(maps, $_op.getContext()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Hook to provide a custom AffineMap used to construct the + hyperrectangular loop iteration space given all the operand subshapes. + This is used to answer the question: + "Given a list of operand ranges, what is the subportion of the iteration + space involved in the computation". + This is the inverse problem of `getLoopsToShapesMap`. + Return the empty AffineMap when such an AffineMap cannot be constructed. + The default behavior is based on a very simple inference procedure that + only works with permutation affine maps. + A more advanced Tensor-Comprehension like inference is possible but has + proven to be ambiguous in unfavorable case. + A safer and more robust alternative is to allow each op to define + its own AffineMap. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getShapesToLoopsMap", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return inversePermutation($_op.getLoopsToShapesMap()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns the static shape of the underlying operand (note this is + op-specific behavior). + Returns ShapedType::kDynamic for non-statically-known loop ranges. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getStaticOperandShape", + /*args=*/(ins "OpOperand*":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + SmallVector res; + llvm::append_range(res, $_op.getShape(opOperand)); + return res; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Returns loop ranges by composing `getShapesToLoopsMap()` with the + flattened list of operand shapes. + Returns ShapedType::kDynamic for non-statically-known loop ranges. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getStaticLoopRanges", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + SmallVector allShapesSizes; + for (OpOperand &opOperand : this->getOperation()->getOpOperands()) + llvm::append_range(allShapesSizes, $_op.getShape(&opOperand)); + AffineMap invertedMap = $_op.getShapesToLoopsMap(); + assert(invertedMap && "expected a valid op"); + return invertedMap.compose(allShapesSizes); + }] + > + ]; + let extraClassDeclaration = [{ + // Verifier implementation for IndexingMapOpInterface. + // This must be called manually as part of other verifiers so that the + // verification order, and meaningful error messages, are not preempted. + LogicalResult verifyImpl(); + }]; +} + +#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt index b4aeb44ac8fa..ec433284e17a 100644 --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgDialect MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRFunctionInterfaces + MLIRIndexingMapOpInterface MLIRInferTypeOpInterface MLIRIR MLIRParser diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 139e9901b0a2..ca7f31dd6b51 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -1251,38 +1251,20 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { if (failed(linalgOp.verifyIndexingMapRequiredAttributes())) return failure(); - // All input/output operands must be indexed. - if (static_cast(linalgOp.getIndexingMapsArray().size()) != - linalgOp->getNumOperands()) - return op->emitOpError("expected the number of indexing_map (") - << linalgOp.getIndexingMapsArray().size() - << ") to be equal to the number of input/output operands (" - << linalgOp->getNumOperands() << ")"; + // Delayed calling of IndexingMapOpInterface::verifyImpl. + if (failed(cast(op).verifyImpl())) + return failure(); // Set this flag if this op has user defined maps. This is required to guard // the below error condition which assume default indexing maps. for (OpOperand &opOperand : linalgOp->getOpOperands()) { AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); - - // Symbols disallowed. - if (indexingMap.getNumSymbols() != 0) - return op->emitOpError("unexpected symbols in indexing_map #") - << opOperand.getOperandNumber(); - // Domain must be consistent. unsigned numLoops = linalgOp.getNumLoops(); if (indexingMap.getNumDims() != numLoops) return op->emitOpError("expected indexing_map #") << opOperand.getOperandNumber() << " to have " << numLoops << " dim(s) to match the number of loops"; - - int64_t rank = linalgOp.getRank(&opOperand); - - if (indexingMap.getNumResults() != rank) - return op->emitOpError("expected operand rank (") - << rank << ") to match the result rank of indexing_map #" - << opOperand.getOperandNumber() << " (" - << indexingMap.getNumResults() << ")"; } SmallVector redDims; linalgOp.getReductionDims(redDims); @@ -1290,67 +1272,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { if (!linalgOp.getShapesToLoopsMap()) return op->emitOpError("expected the shape-to-loops map to be non-null"); - // Check if given shapes match to inferred shapes. - SmallVector endLoopRangeValues = linalgOp.getStaticLoopRanges(); - SmallVector startLoopRangeValues(endLoopRangeValues.size(), 0); - // Verify only static cases since we can't get exact dimension sizes and - // loop ranges for dynamic cases in this stage. - if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { - for (int64_t &range : endLoopRangeValues) - range -= 1; - for (OpOperand &opOperand : linalgOp->getOpOperands()) { - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); - SmallVector startIndices = - indexingMap.compose(startLoopRangeValues); - SmallVector endIndices = - indexingMap.compose(endLoopRangeValues); - ArrayRef shape = linalgOp.getShape(&opOperand); - for (auto dim : llvm::seq(0, shape.size())) { - // Ignore dynamic dimension or the case that the dimension size is 0 - if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) - continue; - - // The first index or last index should be the maximum or the minimum in - // the inferred index ranges since the range is increasing or - // decreasing. The size of dimensions of input/output operands and the - // maximum value + 1 in the inferred range should be the same. But, for - // now we check if the inferred ranges are in boundary of input/output - // operands' size or not in case that Affine Expressions are complicated - // such as d0 * 3 - // + d1 since it is not easy to handle the issues. - // Found the case that this solution can't check, for example, (d0, d1) - // -> (d1 - d0) - int64_t inferredDimSize = - std::max(startIndices[dim], endIndices[dim]) + 1; - if (std::min(startIndices[dim], endIndices[dim]) < 0) { - std::string mapStr; - { - llvm::raw_string_ostream os(mapStr); - os << indexingMap; - } - return op->emitOpError( - "unexpected result less than 0 at expression #") - << dim << " in " << mapStr; - } - if (isa(indexingMap.getResult(dim))) { - if (inferredDimSize != shape[dim]) { - return op->emitOpError("inferred input/output operand #") - << opOperand.getOperandNumber() << " has shape's dimension #" - << dim << " to be " << inferredDimSize << ", but found " - << shape[dim]; - } - } else { - if (inferredDimSize > shape[dim]) { - return op->emitOpError("inferred input/output operand #") - << opOperand.getOperandNumber() << " has shape's dimension #" - << dim << " to be greater than or equal to " - << inferredDimSize << ", but found " << shape[dim]; - } - } - } - } - } - // Check the region has exactly one block. if (linalgOp->getNumRegions() != 1 || !llvm::hasSingleElement(linalgOp->getRegion(0))) diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp index f56ef485069f..8d6d9dc690b5 100644 --- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Interfaces/IndexingMapOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 5e6dde36d7f9..c5d9a729a413 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -398,7 +398,10 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, return rewriter.notifyMatchFailure(genericOp, "invalid indexing maps for operation"); } - SmallVector dims = genericOp.getStaticShape(); + + SmallVector allShapesSizes; + for (OpOperand &opOperand : genericOp->getOpOperands()) + llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand)); // 1a. Get the allowed list of dimensions to drop from the `options`. SmallVector allowedUnitDims = options.controlFn(genericOp); @@ -411,7 +414,7 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, llvm::SmallDenseSet unitDims; for (const auto &expr : enumerate(invertedMap.getResults())) { if (AffineDimExpr dimExpr = dyn_cast(expr.value())) { - if (dims[dimExpr.getPosition()] == 1 && + if (allShapesSizes[dimExpr.getPosition()] == 1 && unitDimsFilter.count(expr.index())) unitDims.insert(expr.index()); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index ff28bd7c4834..ff8e0b8977ae 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -31,6 +31,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" @@ -2217,7 +2218,9 @@ static LogicalResult vectorizeLinalgOpPrecondition( LinalgOp linalgOp, ArrayRef inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv) { // tensor with dimension of 0 cannot be vectorized. - if (llvm::is_contained(linalgOp.getStaticShape(), 0)) + if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) { + return llvm::is_contained(linalgOp.getShape(&operand), 0); + })) return failure(); // Check API contract for input vector sizes. if (!inputVectorSizes.empty() && diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt index 204462ffd047..d464230c8772 100644 --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRVectorDialect MLIRDataLayoutInterfaces MLIRDestinationStyleOpInterface MLIRDialectUtils + MLIRIndexingMapOpInterface MLIRIR MLIRMaskableOpInterface MLIRMaskingOpInterface diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ee9ab61b670c..5e0f36064be3 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1063,7 +1063,8 @@ LogicalResult ContractionOp::verify() { if (!isSupportedCombiningKind(getKind(), elementType)) return emitOpError("unsupported contraction type"); - return success(); + // Delayed calling of IndexingMapOpInterface::verifyImpl. + return cast(this->getOperation()).verifyImpl(); } // MaskableOpInterface methods. diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index a25694cfff5f..af923d98c76f 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES DestinationStyleOpInterface.cpp FunctionImplementation.cpp FunctionInterfaces.cpp + IndexingMapOpInterface.cpp InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp @@ -62,6 +63,7 @@ add_mlir_library(MLIRFunctionInterfaces MLIRIR ) +add_mlir_interface_library(IndexingMapOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp new file mode 100644 index 000000000000..f3c12aed8df8 --- /dev/null +++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp @@ -0,0 +1,125 @@ +//===- IndexingMapOpInterface.cpp -- IndexingMapOpInterface impl ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/IndexingMapOpInterface.h" + +using namespace mlir; + +namespace mlir { +#include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc" +} // namespace mlir + +LogicalResult mlir::IndexingMapOpInterface::verifyImpl() { + // All input/output operands must be indexed. + if (static_cast(getIndexingMapsArray().size()) != + getOperation()->getNumOperands()) + return this->emitOpError("expected the number of indexing_map (") + << getIndexingMapsArray().size() + << ") to be equal to the number of input/output operands (" + << getOperation()->getNumOperands() << ")"; + + AffineMap invertedMap = getShapesToLoopsMap(); + if (!invertedMap) { + std::string str; + llvm::raw_string_ostream os(str); + getLoopsToShapesMap().print(os); + return this->emitOpError("invalid indexing maps are non-invertible: ") + << "(" << str << ")"; + } + + SmallVector endLoopRangeValues = getStaticLoopRanges(); + + // Set this flag if this op has user defined maps. This is required to guard + // the below error condition which assume default indexing maps. + for (OpOperand &opOperand : getOperation()->getOpOperands()) { + AffineMap indexingMap = getMatchingIndexingMap(&opOperand); + + // Symbols disallowed. + if (indexingMap.getNumSymbols() != 0) + return getOperation()->emitOpError("unexpected symbols in indexing_map #") + << opOperand.getOperandNumber(); + + // Domain must be consistent. + if (indexingMap.getNumDims() != endLoopRangeValues.size()) + return getOperation()->emitOpError("expected indexing_map #") + << opOperand.getOperandNumber() << " to have " + << endLoopRangeValues.size() + << " dim(s) to match the number of loops"; + + SmallVector shape = getStaticOperandShape(&opOperand); + int64_t rank = shape.size(); + + if (indexingMap.getNumResults() != rank) + return getOperation()->emitOpError("expected operand rank (") + << rank << ") to match the result rank of indexing_map #" + << opOperand.getOperandNumber() << " (" + << indexingMap.getNumResults() << ")"; + } + + // Check if given shapes match to inferred shapes. + SmallVector startLoopRangeValues(endLoopRangeValues.size(), 0); + // Verify only static cases since we can't get exact dimension sizes and + // loop ranges for dynamic cases in this stage. + if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { + // Exclusive end range. + for (int64_t &range : endLoopRangeValues) + range -= 1; + for (OpOperand &opOperand : getOperation()->getOpOperands()) { + AffineMap indexingMap = getMatchingIndexingMap(&opOperand); + SmallVector startIndices = + indexingMap.compose(startLoopRangeValues); + SmallVector endIndices = indexingMap.compose(endLoopRangeValues); + SmallVector shape = getStaticOperandShape(&opOperand); + for (auto dim : llvm::seq(0, shape.size())) { + // Ignore dynamic dimension or the case that the dimension size is 0 + if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) + continue; + + // The first index or last index should be the maximum or the minimum in + // the inferred index ranges since the range is increasing or + // decreasing. The size of dimensions of input/output operands and the + // maximum value + 1 in the inferred range should be the same. But, for + // now we check if the inferred ranges are in boundary of input/output + // operands' size or not in case that Affine Expressions are complicated + // such as d0 * 3 + // + d1 since it is not easy to handle the issues. + // Found the case that this solution can't check, for example, (d0, d1) + // -> (d1 - d0) + int64_t inferredDimSize = + std::max(startIndices[dim], endIndices[dim]) + 1; + if (std::min(startIndices[dim], endIndices[dim]) < 0) { + std::string mapStr; + { + llvm::raw_string_ostream os(mapStr); + os << indexingMap; + } + return this->emitOpError( + "unexpected result less than 0 at expression #") + << dim << " in " << mapStr; + } + if (isa(indexingMap.getResult(dim))) { + if (inferredDimSize != shape[dim]) { + return this->emitOpError("inferred input/output operand #") + << opOperand.getOperandNumber() << " has shape's dimension #" + << dim << " to be " << inferredDimSize << ", but found " + << shape[dim]; + } + } else { + if (inferredDimSize > shape[dim]) { + return this->emitOpError("inferred input/output operand #") + << opOperand.getOperandNumber() << " has shape's dimension #" + << dim << " to be greater than or equal to " + << inferredDimSize << ", but found " << shape[dim]; + } + } + } + } + } + + return success(); +} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index c0c5f785e856..ca40301f04fa 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -151,7 +151,7 @@ func.func @generic_result_0_element_type(%arg0: memref(off + i)>>, %arg1: memref(off + i)>>) { - // expected-error @+1 {{expected the shape-to-loops map to be non-null}} + // expected-error @+1 {{invalid indexing maps are non-invertible: ((d0, d1) -> (d0 + d1, d0 + d1))}} linalg.generic { indexing_maps = [ affine_map<(i, j) -> (i + j)>,