//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include using namespace mlir; Optional> mlir::getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType) { // Make the sourceType greater rank than the targetType. If they are same // rank, then its an unsupported reshape op. if (sourceType.getRank() == targetType.getRank()) return llvm::None; if (sourceType.getRank() < targetType.getRank()) std::swap(sourceType, targetType); ArrayRef sourceShape = sourceType.getShape(); ArrayRef targetShape = targetType.getShape(); unsigned sourceDim = 0; SmallVector reassociationMap; reassociationMap.reserve(targetType.getRank()); ReassociationIndices currIndices; int64_t prodOfCollapsedDims = 1; while (sourceDim < sourceShape.size()) { unsigned targetDim = reassociationMap.size(); // If all the dimensions of the targetShape are exhausted, then the // remaining dims in the source shape must be all 1s. So for such cases, set // 1 as the target shape. The actual reassociation indices will be handled // later. int64_t currTargetShape = (targetDim < targetType.getRank() ? targetShape[targetDim] : 1); while (sourceShape[sourceDim] != ShapedType::kDynamicSize && prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape && sourceDim < sourceShape.size()) { prodOfCollapsedDims *= sourceShape[sourceDim]; currIndices.push_back(sourceDim++); } // If the current expanded dimension is dynamic, then the collapsed // dimensions should also be dynamic and product of all previous unprocessed // dimensions of the expanded shape should be 1. if (sourceShape[sourceDim] == ShapedType::kDynamicSize && (currTargetShape != ShapedType::kDynamicSize || prodOfCollapsedDims != 1)) return llvm::None; // If the collapsed dim is dynamic, the current expanded dim should also // be dynamic. if (currTargetShape == ShapedType::kDynamicSize && sourceShape[sourceDim] != ShapedType::kDynamicSize) return llvm::None; // For static shapes, if the product of dimensions of the expanded shape // should match the collapsed dimension shape. if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) return llvm::None; currIndices.push_back(sourceDim++); // If the reassociation is empty but the currIndices is not, this by // definition is folding unit-dimensions with the result being scalar type. // So only append the `currIndices` if reassociation map is not empty. if (targetDim == targetShape.size()) { if (!reassociationMap.empty() && !currIndices.empty()) reassociationMap.back().append(currIndices.begin(), currIndices.end()); // Break out of the loops. We should be done here. break; } reassociationMap.emplace_back(ReassociationIndices{}); std::swap(reassociationMap.back(), currIndices); prodOfCollapsedDims = 1; } // All the dimensions in the two shapes must have been processed. if (reassociationMap.size() != targetShape.size() || sourceDim != sourceShape.size()) return llvm::None; return reassociationMap; } ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser, OperationState &result) { // Parse the operand. OpAsmParser::OperandType src; if (parser.parseOperand(src)) return failure(); // Parse reassociation indices. Builder &b = parser.getBuilder(); SmallVector reassociation; if (parser.parseLSquare()) return failure(); while (true) { if (succeeded(parser.parseOptionalRSquare())) break; if (parser.parseLSquare()) return failure(); SmallVector indices; while (true) { int64_t index; if (parser.parseInteger(index)) return failure(); indices.push_back(index); if (succeeded(parser.parseOptionalComma())) continue; if (failed(parser.parseRSquare())) return failure(); break; } reassociation.push_back(b.getI64ArrayAttr(indices)); if (succeeded(parser.parseOptionalComma())) continue; if (failed(parser.parseRSquare())) return failure(); break; } result.addAttribute(getReassociationAttrName(), b.getArrayAttr(reassociation)); // Parse optional attributes. parser.parseOptionalAttrDict(result.attributes); // Parse types. Type srcType; Type resultType; if (parser.parseColon() || parser.parseType(srcType) || parser.resolveOperand(src, srcType, result.operands) || parser.parseKeyword("into") || parser.parseType(resultType)) return failure(); result.addTypes(resultType); return success(); } Optional> mlir::composeReassociationIndices( ArrayRef producerReassociations, ArrayRef consumerReassociations, MLIRContext *context) { SmallVector composedIndices; // Make the producer the larger sized vector. If they are of same size, the // resulting reshape is not a supported reshape op. if (producerReassociations.size() == consumerReassociations.size()) return llvm::None; if (producerReassociations.size() < consumerReassociations.size()) std::swap(producerReassociations, consumerReassociations); // Handle the corner case of the result being a rank 0 shaped type. Return an // empty reassociation. if (consumerReassociations.empty()) return composedIndices; size_t consumerDims = std::accumulate( consumerReassociations.begin(), consumerReassociations.end(), 0, [](size_t all, ReassociationIndicesRef indices) { return all + indices.size(); }); if (producerReassociations.size() != consumerDims) return llvm::None; for (ReassociationIndicesRef consumerIndices : consumerReassociations) { ReassociationIndices reassociations; for (int64_t consumerIndex : consumerIndices) { for (int64_t producerIndex : producerReassociations[consumerIndex]) reassociations.push_back(producerIndex); } composedIndices.push_back(std::move(reassociations)); } return composedIndices; } SmallVector, 2> mlir::convertReassociationIndicesToExprs( MLIRContext *context, ArrayRef reassociationIndices) { SmallVector, 2> reassociationMaps; for (const auto &indices : reassociationIndices) { SmallVector reassociationMap; reassociationMap.reserve(indices.size()); for (int64_t index : indices) reassociationMap.push_back(mlir::getAffineDimExpr(index, context)); reassociationMaps.push_back(std::move(reassociationMap)); } return reassociationMaps; } template unsigned getMaxPosOfType(ArrayRef exprArrays) { unsigned pos = 0; for (const auto &exprs : exprArrays) { for (auto expr : exprs) { expr.walk([&pos](AffineExpr e) { if (auto d = e.dyn_cast()) pos = std::max(pos, d.getPosition()); }); } } return pos; } ArrayAttr mlir::getReassociationIndicesAttribute( OpBuilder &b, ArrayRef reassociation) { SmallVector reassociationAttr = llvm::to_vector<4>(llvm::map_range( reassociation, [&](ReassociationIndices indices) -> Attribute { return b.getI64ArrayAttr(indices).cast(); })); return b.getArrayAttr(reassociationAttr); } SmallVector mlir::convertReassociationMapsToIndices( OpBuilder &b, ArrayRef reassociationExprs) { SmallVector reassociationIndices; for (const auto &exprs : reassociationExprs) { ReassociationIndices indices; indices.reserve(exprs.size()); for (const auto &expr : exprs) indices.push_back(expr.cast().getPosition()); reassociationIndices.push_back(indices); } return reassociationIndices; } SmallVector mlir::getSymbolLessAffineMaps(ArrayRef reassociation) { unsigned maxDim = getMaxPosOfType(reassociation); assert(getMaxPosOfType(reassociation) == 0 && "Expected symbol-less expressions"); SmallVector maps; maps.reserve(reassociation.size()); for (const auto &exprs : reassociation) { assert(!exprs.empty()); maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); } return maps; } bool mlir::isReassociationValid(ArrayRef reassociation, int *invalidIndex) { if (reassociation.empty()) return true; unsigned nDims = reassociation[0].getNumDims(); unsigned nextExpectedDim = 0; for (auto it : llvm::enumerate(reassociation)) { auto m = it.value(); if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { if (invalidIndex) *invalidIndex = it.index(); return false; } for (auto e : m.getResults()) { auto d = e.dyn_cast(); if (!d || d.getPosition() != nextExpectedDim++) { if (invalidIndex) *invalidIndex = it.index(); return false; } } } if (nextExpectedDim != nDims) { if (invalidIndex) *invalidIndex = reassociation.size() - 1; return false; } return true; }