`memref.collapse_shape` has verification logic that requires a result dim to be static when all of the source dims collapsed into it are static. Cast folding may add static shape information to the source operand of `memref.collapse_shape`, which can turn a previously valid collapsing operation into an invalid one. Add the `CollapseShapeOpMemRefCastFolder` pattern to fix this. Also make a minor change to `convertReassociationIndicesToExprs` so that it takes a `context` instead of a `builder`, avoiding the extra steps needed to construct temporary builders. Reviewed By: nicolasvasilache, mravishankar Differential Revision: https://reviews.llvm.org/D106670
277 lines
9.8 KiB
C++
277 lines
9.8 KiB
C++
//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
|
|
/// Computes the reassociation that groups the dimensions of the higher-rank
/// of the two shaped types so that each group collapses into one dimension of
/// the lower-rank type.
///
/// The arguments may be given in either order: the function swaps them so
/// that the "source" is always the higher-rank type. Source dimensions are
/// consumed greedily, left to right, until their product matches the current
/// target dimension. A dynamic source dimension may only map one-to-one onto
/// a dynamic target dimension.
///
/// Returns llvm::None when both types have the same rank or when the scan
/// below cannot match the two shapes.
Optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                        ShapedType targetType) {
  // Make the sourceType greater rank than the targetType. If they are same
  // rank, then its an unsupported reshape op.
  if (sourceType.getRank() == targetType.getRank())
    return llvm::None;
  if (sourceType.getRank() < targetType.getRank())
    std::swap(sourceType, targetType);

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> targetShape = targetType.getShape();
  unsigned sourceDim = 0;
  SmallVector<ReassociationIndices> reassociationMap;
  reassociationMap.reserve(targetType.getRank());

  ReassociationIndices currIndices;
  int64_t prodOfCollapsedDims = 1;
  while (sourceDim < sourceShape.size()) {
    unsigned targetDim = reassociationMap.size();

    // If all the dimensions of the targetShape are exhausted, then the
    // remaining dims in the source shape must be all 1s. So for such cases, set
    // 1 as the target shape. The actual reassociation indices will be handled
    // later.
    int64_t currTargetShape =
        (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);

    // Greedily absorb static source dims while their running product is still
    // smaller than the target dim. The bounds check must come first so that
    // `sourceShape[sourceDim]` is never read past the end of the shape.
    while (sourceDim < sourceShape.size() &&
           sourceShape[sourceDim] != ShapedType::kDynamicSize &&
           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
      prodOfCollapsedDims *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }

    // Every source dim was consumed and the product is still smaller than the
    // target dim: the shapes cannot be matched. Bail out before indexing
    // `sourceShape` with an out-of-range `sourceDim` below.
    if (sourceDim == sourceShape.size())
      return llvm::None;

    // If the current expanded dimension is dynamic, then the collapsed
    // dimensions should also be dynamic and product of all previous unprocessed
    // dimensions of the expanded shape should be 1.
    if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
        (currTargetShape != ShapedType::kDynamicSize ||
         prodOfCollapsedDims != 1))
      return llvm::None;

    // If the collapsed dim is dynamic, the current expanded dim should also
    // be dynamic.
    if (currTargetShape == ShapedType::kDynamicSize &&
        sourceShape[sourceDim] != ShapedType::kDynamicSize)
      return llvm::None;

    // For static shapes, if the product of dimensions of the expanded shape
    // should match the collapsed dimension shape.
    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
      return llvm::None;

    currIndices.push_back(sourceDim++);
    // If the reassociation is empty but the currIndices is not, this by
    // definition is folding unit-dimensions with the result being scalar type.
    // So only append the `currIndices` if reassociation map is not empty.
    if (targetDim == targetShape.size()) {
      if (!reassociationMap.empty() && !currIndices.empty())
        reassociationMap.back().append(currIndices.begin(), currIndices.end());
      // Break out of the loops. We should be done here.
      break;
    }
    reassociationMap.emplace_back(ReassociationIndices{});
    std::swap(reassociationMap.back(), currIndices);
    prodOfCollapsedDims = 1;
  }
  // All the dimensions in the two shapes must have been processed.
  if (reassociationMap.size() != targetShape.size() ||
      sourceDim != sourceShape.size())
    return llvm::None;
  return reassociationMap;
}
|
|
|
|
/// Parses the custom assembly form shared by reshape-like operations:
///
///   operand `[` (`[` int (`,` int)* `]`)* `]` attr-dict? `:` type `into` type
///
/// The bracketed list of integer groups is stored on `result` as an array of
/// I64 array attributes under `getReassociationAttrName()`. The operand is
/// resolved against the first parsed type and the second parsed type becomes
/// the single result type. Returns failure() as soon as any piece fails to
/// parse.
ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the operand.
  OpAsmParser::OperandType src;
  if (parser.parseOperand(src))
    return failure();

  // Parse reassociation indices.
  Builder &b = parser.getBuilder();
  SmallVector<Attribute, 4> reassociation;
  if (parser.parseLSquare())
    return failure();

  while (true) {
    // A closing bracket here ends the (possibly empty) outer list.
    if (succeeded(parser.parseOptionalRSquare()))
      break;
    if (parser.parseLSquare())
      return failure();
    SmallVector<int64_t> indices;
    // Each inner group holds at least one integer, separated by commas.
    while (true) {
      int64_t index;
      if (parser.parseInteger(index))
        return failure();
      indices.push_back(index);

      if (succeeded(parser.parseOptionalComma()))
        continue;
      if (failed(parser.parseRSquare()))
        return failure();
      break;
    }
    reassociation.push_back(b.getI64ArrayAttr(indices));
    if (succeeded(parser.parseOptionalComma()))
      continue;
    if (failed(parser.parseRSquare()))
      return failure();
    break;
  }

  result.addAttribute(getReassociationAttrName(),
                      b.getArrayAttr(reassociation));

  // Parse optional attributes. A malformed attribute dictionary must fail the
  // whole parse instead of being silently ignored.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse types.
  Type srcType;
  Type resultType;
  if (parser.parseColon() || parser.parseType(srcType) ||
      parser.resolveOperand(src, srcType, result.operands) ||
      parser.parseKeyword("into") || parser.parseType(resultType))
    return failure();
  result.addTypes(resultType);
  return success();
}
|
|
|
|
/// Composes the reassociations of two back-to-back reshapes into a single
/// reassociation.
///
/// The two lists may be passed in either order: the longer one is treated as
/// the producer. Each index in a consumer group selects a producer group, and
/// the selected producer groups are concatenated to form the composed group.
///
/// Returns llvm::None when both lists have the same number of groups, or when
/// the total number of indices across the consumer groups differs from the
/// number of producer groups. Returns an empty reassociation when the consumer
/// list is empty (rank-0 result). `context` is not used by this function.
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context) {
  SmallVector<ReassociationIndices> composedIndices;
  // Make the producer the larger sized vector. If they are of same size, the
  // resulting reshape is not a supported reshape op.
  if (producerReassociations.size() == consumerReassociations.size())
    return llvm::None;
  if (producerReassociations.size() < consumerReassociations.size())
    std::swap(producerReassociations, consumerReassociations);

  // Handle the corner case of the result being a rank 0 shaped type. Return an
  // empty reassociation.
  if (consumerReassociations.empty())
    return composedIndices;

  // Seed the fold with a size_t so the sum is accumulated in the same type the
  // lambda operates on, rather than being narrowed to int on every step.
  size_t consumerDims = std::accumulate(
      consumerReassociations.begin(), consumerReassociations.end(), size_t(0),
      [](size_t all, ReassociationIndicesRef indices) {
        return all + indices.size();
      });
  if (producerReassociations.size() != consumerDims)
    return llvm::None;

  composedIndices.reserve(consumerReassociations.size());
  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
    ReassociationIndices reassociations;
    for (int64_t consumerIndex : consumerIndices) {
      for (int64_t producerIndex : producerReassociations[consumerIndex])
        reassociations.push_back(producerIndex);
    }
    composedIndices.push_back(std::move(reassociations));
  }
  return composedIndices;
}
|
|
|
|
/// Translates every reassociation group of dimension indices into a group of
/// affine dimension expressions created in `context`. Group order and the
/// order of indices inside each group are preserved.
SmallVector<SmallVector<AffineExpr, 2>, 2>
mlir::convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<SmallVector<AffineExpr, 2>, 2> exprGroups;
  for (const ReassociationIndices &group : reassociationIndices) {
    // Build each group in place at the back of the result.
    exprGroups.emplace_back();
    SmallVector<AffineExpr, 2> &groupExprs = exprGroups.back();
    groupExprs.reserve(group.size());
    for (int64_t dim : group)
      groupExprs.push_back(mlir::getAffineDimExpr(dim, context));
  }
  return exprGroups;
}
|
|
|
|
template <typename AffineExprTy>
|
|
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
|
|
unsigned pos = 0;
|
|
for (const auto &exprs : exprArrays) {
|
|
for (auto expr : exprs) {
|
|
expr.walk([&pos](AffineExpr e) {
|
|
if (auto d = e.dyn_cast<AffineExprTy>())
|
|
pos = std::max(pos, d.getPosition());
|
|
});
|
|
}
|
|
}
|
|
return pos;
|
|
}
|
|
|
|
/// Wraps `reassociation` into an array attribute with one I64 array attribute
/// per reassociation group, using `b` to create the attributes.
ArrayAttr mlir::getReassociationIndicesAttribute(
    OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<Attribute, 4> reassociationAttr;
  reassociationAttr.reserve(reassociation.size());
  // Iterate by const reference so that no index group is copied just to be
  // read.
  for (const ReassociationIndices &indices : reassociation)
    reassociationAttr.push_back(b.getI64ArrayAttr(indices));
  return b.getArrayAttr(reassociationAttr);
}
|
|
|
|
/// Extracts, for every group of affine expressions, the position of each
/// expression as a dimension index. Each expression is cast to AffineDimExpr,
/// so callers must only pass dimension expressions. `b` is not used.
SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
  SmallVector<ReassociationIndices, 2> groups;
  for (const auto &exprGroup : reassociationExprs) {
    // Fill the new group directly at the back of the result.
    groups.emplace_back();
    ReassociationIndices &positions = groups.back();
    positions.reserve(exprGroup.size());
    for (const auto &dimExpr : exprGroup)
      positions.push_back(dimExpr.cast<AffineDimExpr>().getPosition());
  }
  return groups;
}
|
|
|
|
/// Builds one affine map per reassociation group. Every map has the same
/// number of dimensions (one more than the largest dimension position used
/// anywhere in `reassociation`) and zero symbols. Each group must be
/// non-empty because the context is taken from its first expression.
SmallVector<AffineMap, 4>
mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
  const unsigned numDims = getMaxPosOfType<AffineDimExpr>(reassociation) + 1;
  // NOTE(review): this only rejects symbols at a position greater than 0; a
  // symbol at position 0 yields the same value as "no symbols" — confirm
  // whether that is intended.
  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
         "Expected symbol-less expressions");

  SmallVector<AffineMap, 4> result;
  result.reserve(reassociation.size());
  for (const auto &group : reassociation) {
    assert(!group.empty());
    MLIRContext *ctx = group[0].getContext();
    result.push_back(
        AffineMap::get(numDims, /*symbolCount=*/0, group, ctx));
  }
  return result;
}
|
|
/// Checks that `reassociation` is a well-formed list of reassociation maps:
/// all maps share the dimension count of the first map, none has symbols, and
/// the results of all maps, read in order, are exactly the dimension
/// expressions d0, d1, ..., d(N-1), where N is that dimension count.
///
/// An empty list is valid. On failure, if `invalidIndex` is non-null it
/// receives the index of the offending map; when the maps simply do not cover
/// all dimensions, it receives the index of the last map.
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
    return true;

  // Records the offending map index, if requested, and reports failure.
  auto reject = [invalidIndex](size_t mapIdx) {
    if (invalidIndex)
      *invalidIndex = static_cast<int>(mapIdx);
    return false;
  };

  const unsigned numDims = reassociation[0].getNumDims();
  unsigned expectedPos = 0;
  for (size_t mapIdx = 0, numMaps = reassociation.size(); mapIdx != numMaps;
       ++mapIdx) {
    AffineMap map = reassociation[mapIdx];
    if (map.getNumDims() != numDims || map.getNumSymbols() != 0)
      return reject(mapIdx);
    for (AffineExpr resultExpr : map.getResults()) {
      auto dimExpr = resultExpr.dyn_cast<AffineDimExpr>();
      // Results must be plain dimensions appearing in strictly increasing,
      // gap-free order across all maps.
      if (!dimExpr || dimExpr.getPosition() != expectedPos)
        return reject(mapIdx);
      ++expectedPos;
    }
  }
  // Every dimension must have been covered by some map.
  if (expectedPos != numDims)
    return reject(reassociation.size() - 1);
  return true;
}
|