[mlir] Reapply "Loosen restrictions on folding dynamic reshapes" (#142827)
The original PR https://github.com/llvm/llvm-project/pull/137963 had a
nvidia bot failure. This appears to be a flaky test because rerunning
the build was successful.
This change needs commit 6f2ba47 to fix incorrect usage of
`getReassociationIndicesForCollapse`.
Reverts llvm/llvm-project#142639
Co-authored-by: Artem Gindinson <gindinson@roofline.ai>
This commit is contained in:
@@ -10,6 +10,10 @@
|
||||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
@@ -28,67 +32,329 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// A simple struct to represent ReassociationIndices as an inclusive interval.
|
||||
/// It's designed to be feasibly minimal, so the call sites should manage the
|
||||
/// validity of the range manually.
|
||||
struct ReassociationIndexRange {
|
||||
/// FIXME: Signed type is used for consistency with ReassociationIndices.
|
||||
/// We should consider refactoring all reassociation utilities to use unsigned
|
||||
/// types.
|
||||
int64_t leftIdx = 0, rightIdx = 0;
|
||||
|
||||
/// Util for manual checks of the range's validity
|
||||
LogicalResult verify() const {
|
||||
return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
|
||||
}
|
||||
|
||||
/// Checks range's containment within another range. Treats the edges
|
||||
/// non-exclusively.
|
||||
bool isInRange(const ReassociationIndexRange &outerRange) const {
|
||||
return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
|
||||
}
|
||||
|
||||
unsigned size() const {
|
||||
assert(succeeded(verify()));
|
||||
return rightIdx - leftIdx + 1;
|
||||
}
|
||||
bool containsSingleIndex() const { return size() == 1; }
|
||||
|
||||
/// Collects indices that do not overlap between this and another range.
|
||||
ReassociationIndices
|
||||
getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
|
||||
if (rightIdx < rhs.leftIdx) {
|
||||
// The intervals do not overlap - concatenate the indices from both.
|
||||
auto jointFullIndices = getFullIndices();
|
||||
jointFullIndices.append(rhs.getFullIndices());
|
||||
return jointFullIndices;
|
||||
}
|
||||
ReassociationIndices result;
|
||||
// Handle the chunk left of the overlapping range.
|
||||
int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
|
||||
int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
|
||||
llvm::append_range(result, llvm::seq(leftStart, leftEnd));
|
||||
// Handle the chunk right of the overlapping range. Symmetrically, we should
|
||||
// skip the edge of the overlap AND include the rightmost index.
|
||||
int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
|
||||
int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
|
||||
if (rightStart < rightEnd)
|
||||
llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Converts the range into ReassociationIndices.
|
||||
ReassociationIndices getFullIndices() const {
|
||||
ReassociationIndices result;
|
||||
for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
|
||||
result.push_back(idx);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
|
||||
/// sequence that can be collapsed into a dynamic dimension (at least one must
|
||||
/// be present in the source).
|
||||
/// By default, lazily returns once the first dynamic dimension has been found.
|
||||
/// Setting `matchGreedily` as `true` will also mark all subsequent
|
||||
/// source dimensions for collapsing into the target.
|
||||
static FailureOr<ReassociationIndexRange>
|
||||
findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
|
||||
int64_t sourceStartIdx,
|
||||
bool matchGreedily = false) {
|
||||
const unsigned numSourceDims = sourceShape.size();
|
||||
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
|
||||
std::optional<ReassociationIndexRange> resultRange = std::nullopt;
|
||||
|
||||
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
|
||||
for (; iterationRange.isInRange(sourceShapeAsRange);
|
||||
iterationRange.rightIdx++) {
|
||||
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
|
||||
if (sourceSize == ShapedType::kDynamic) {
|
||||
resultRange = iterationRange;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!resultRange)
|
||||
return failure();
|
||||
if (matchGreedily)
|
||||
resultRange->rightIdx = sourceShapeAsRange.rightIdx;
|
||||
return *resultRange;
|
||||
}
|
||||
|
||||
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
|
||||
/// sequence of static dimensions such that their product matches `targetSize`.
|
||||
/// By default, lazily returns once the product matches the target size. Setting
|
||||
/// `matchGreedily` as `true` will append all neighboring unit dimensions
|
||||
/// (dimensions of 1) to the match.
|
||||
static FailureOr<ReassociationIndexRange>
|
||||
findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
|
||||
int64_t sourceStartIdx, int64_t targetSize,
|
||||
bool matchGreedily = false) {
|
||||
const unsigned numSourceDims = sourceShape.size();
|
||||
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
|
||||
std::optional<ReassociationIndexRange> resultRange = std::nullopt;
|
||||
|
||||
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
|
||||
int64_t prodOfCollapsedDims = 1;
|
||||
while (iterationRange.isInRange(sourceShapeAsRange)) {
|
||||
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
|
||||
if (sourceSize == ShapedType::kDynamic) {
|
||||
// Reassociation for a static dim cannot include a dynamic dim. Reset
|
||||
// induction variables to essentially restart the loop from the next
|
||||
// source dimension.
|
||||
prodOfCollapsedDims = 1;
|
||||
iterationRange = {iterationRange.rightIdx + 1,
|
||||
iterationRange.rightIdx + 1};
|
||||
continue;
|
||||
}
|
||||
prodOfCollapsedDims *= sourceSize;
|
||||
// If the target size has been exceeded without matching, we need to shift
|
||||
// the range start right. From the start of the range, roll back the
|
||||
// multiplication until the target size exceeds the product again.
|
||||
while (prodOfCollapsedDims > targetSize &&
|
||||
!iterationRange.containsSingleIndex()) {
|
||||
int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
|
||||
prodOfCollapsedDims /= frontSourceSize;
|
||||
// Shrink the range rightwards
|
||||
iterationRange.leftIdx++;
|
||||
}
|
||||
// We could've reached the target size with the current dimension,
|
||||
// also as a result of the above shift to right.
|
||||
if (prodOfCollapsedDims == targetSize) {
|
||||
resultRange = iterationRange;
|
||||
break;
|
||||
}
|
||||
// Increment the iteration range
|
||||
iterationRange.rightIdx++;
|
||||
}
|
||||
if (!resultRange)
|
||||
return failure();
|
||||
if (matchGreedily) {
|
||||
// We now want to collect all unit dimensions directly after the target
|
||||
// product match. Advance the iterator to avoid OOB when the product match
|
||||
// happens at the last element.
|
||||
iterationRange.rightIdx++;
|
||||
while (iterationRange.isInRange(sourceShapeAsRange) &&
|
||||
sourceShape[iterationRange.rightIdx] == 1) {
|
||||
resultRange = iterationRange;
|
||||
iterationRange.rightIdx++;
|
||||
}
|
||||
}
|
||||
return *resultRange;
|
||||
}
|
||||
|
||||
/// Attempts to find a valid collapsing reassociation of `sourceShape` into
|
||||
/// `targetShape` through a simple traversal. If successful, an array of source
|
||||
/// index ranges is returned, correspondingly to each dimension in the target
|
||||
/// shape. The resulting indices shall fully cover the `sourceShape` without
|
||||
/// overlaps.
|
||||
///
|
||||
/// The algorithm is essentially a lazy one, searching for non-greedy matches -
|
||||
/// it will only yield a greedy match for the last target dimension.
|
||||
/// FIXME: The algorithm can only backtrack when it needs to append an offset
|
||||
/// for a static target dimension to the preceding dynamic one (this retains the
|
||||
/// linear complexity). As feasible, consider adding further backtracking
|
||||
/// routines to enable more reassociations, e.g.:
|
||||
/// - ?x2x?x2 into ?x2
|
||||
static FailureOr<SmallVector<ReassociationIndexRange>>
|
||||
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<int64_t> targetShape) {
|
||||
unsigned numSourceDims = sourceShape.size(),
|
||||
numTargetDims = targetShape.size();
|
||||
assert(numSourceDims > numTargetDims);
|
||||
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
|
||||
|
||||
SmallVector<ReassociationIndexRange> reassocRanges;
|
||||
reassocRanges.reserve(numTargetDims);
|
||||
// We'll iterate in strides of 2 to enable pseudo-backtracking for simple
|
||||
// cases, e.g.:
|
||||
// - ?x2x3x5 into ?x15
|
||||
std::optional<int64_t> prevTargetSize = std::nullopt;
|
||||
for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
|
||||
targetDimIdx < numTargetDims; ++targetDimIdx) {
|
||||
int64_t targetSize = targetShape[targetDimIdx];
|
||||
// Simply check if there are any subsequent target dimensions left - if not,
|
||||
// the match must be made greedily.
|
||||
bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
|
||||
FailureOr<ReassociationIndexRange> sourceRange;
|
||||
if (targetSize == ShapedType::kDynamic) {
|
||||
sourceRange = findReassociationRangeForDynamicDim(
|
||||
sourceShape, sourceDimIdx, shouldMatchGreedily);
|
||||
} else {
|
||||
sourceRange = findReassociationRangeForSize(
|
||||
sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
|
||||
}
|
||||
|
||||
// Run sanity checks on the returned index range.
|
||||
if (failed(sourceRange) || failed(sourceRange->verify()) ||
|
||||
!sourceRange->isInRange(sourceShapeAsRange))
|
||||
return failure();
|
||||
if (sourceRange->leftIdx > sourceDimIdx) {
|
||||
// If some source dimensions had to be skipped in order to find a match,
|
||||
// they must be collapsed into the directly preceding dynamic dimension.
|
||||
if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
|
||||
return failure();
|
||||
reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
|
||||
}
|
||||
|
||||
// Store the gathered information as required for the next iteration.
|
||||
prevTargetSize = targetSize;
|
||||
sourceDimIdx = sourceRange->rightIdx + 1;
|
||||
reassocRanges.push_back(*sourceRange);
|
||||
}
|
||||
// Fail if the source shape wasn't a full match for the target shape. We only
|
||||
// need to check the last recorded index - any other gaps should have been
|
||||
// mended by the main loop.
|
||||
if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
|
||||
return failure();
|
||||
return reassocRanges;
|
||||
}
|
||||
|
||||
/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
|
||||
/// the shapes right-to-left.
|
||||
static FailureOr<SmallVector<ReassociationIndexRange>>
|
||||
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<int64_t> targetShape,
|
||||
bool iterateRightToLeft) {
|
||||
if (!iterateRightToLeft)
|
||||
return findReassociationRangesForCollapse(sourceShape, targetShape);
|
||||
// NB: To iterate right-to-left, we currently reverse the shapes and then
|
||||
// reverse the result back. The reversed shapes must not be temporary, as
|
||||
// we're passing through an ArrayRef.
|
||||
// FIXME: It would be preferable to avoid the expensive copies. At the moment,
|
||||
// this approach is chosen for readability of the main implementation.
|
||||
std::vector<int64_t> sourceToReverse = sourceShape.vec(),
|
||||
targetToReverse = targetShape.vec();
|
||||
std::reverse(sourceToReverse.begin(), sourceToReverse.end());
|
||||
std::reverse(targetToReverse.begin(), targetToReverse.end());
|
||||
auto invertedRanges =
|
||||
findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
|
||||
if (failed(invertedRanges))
|
||||
return failure();
|
||||
SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
|
||||
unsigned numSourceDims = sourceShape.size();
|
||||
// We have received the ranges for inverted shapes. Now we have to invert
|
||||
// the ranges back to correspond with the original source shape.
|
||||
for (auto &range : rangesToInvert) {
|
||||
int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
|
||||
range.leftIdx = numSourceDims - 1 - invRightIdx;
|
||||
range.rightIdx = numSourceDims - 1 - invLeftIdx;
|
||||
}
|
||||
// Also invert the ordering of the ranges to correspond with the original
|
||||
// target shape.
|
||||
std::reverse(rangesToInvert.begin(), rangesToInvert.end());
|
||||
return rangesToInvert;
|
||||
}
|
||||
|
||||
std::optional<SmallVector<ReassociationIndices>>
|
||||
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<int64_t> targetShape) {
|
||||
if (sourceShape.size() <= targetShape.size())
|
||||
unsigned numSourceDims = sourceShape.size(),
|
||||
numTargetDims = targetShape.size();
|
||||
// We're supposed to search for a collapsing reassociation. If the sizes
|
||||
// match, there's no actual collapsing taking place - it's either a no-op or a
|
||||
// `tensor.reshape`-style reassociation (that would be beyond the scope of
|
||||
// this utility).
|
||||
if (numSourceDims <= numTargetDims)
|
||||
return std::nullopt;
|
||||
unsigned sourceDim = 0;
|
||||
SmallVector<ReassociationIndices> reassociationMap;
|
||||
reassociationMap.reserve(targetShape.size());
|
||||
|
||||
ReassociationIndices currIndices;
|
||||
int64_t prodOfCollapsedDims = 1;
|
||||
while (sourceDim < sourceShape.size()) {
|
||||
unsigned targetDim = reassociationMap.size();
|
||||
// If we have mapped all the target dimensions stop and handle the remaining
|
||||
// tail of size-1 dimensions explicitly.
|
||||
if (targetDim == targetShape.size())
|
||||
break;
|
||||
|
||||
int64_t currTargetShape = targetShape[targetDim];
|
||||
while (sourceDim < (sourceShape.size() - 1) &&
|
||||
sourceShape[sourceDim] != ShapedType::kDynamic &&
|
||||
prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
|
||||
prodOfCollapsedDims *= sourceShape[sourceDim];
|
||||
currIndices.push_back(sourceDim++);
|
||||
// Early handling for scalar target types.
|
||||
if (numTargetDims == 0) {
|
||||
ReassociationIndices allSourceIndices;
|
||||
allSourceIndices.reserve(numSourceDims);
|
||||
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
|
||||
++sourceDimIdx) {
|
||||
int64_t sourceSize = sourceShape[sourceDimIdx];
|
||||
// All source dimensions must be unit or dynamic.
|
||||
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
|
||||
return std::nullopt;
|
||||
allSourceIndices.push_back(sourceDimIdx);
|
||||
}
|
||||
|
||||
// If the current expanded dimension is dynamic, then the collapsed
|
||||
// dimensions should also be dynamic and product of all previous unprocessed
|
||||
// dimensions of the expanded shape should be 1.
|
||||
if (sourceShape[sourceDim] == ShapedType::kDynamic &&
|
||||
(currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
|
||||
return std::nullopt;
|
||||
|
||||
// If the collapsed dim is dynamic, the current expanded dim should also
|
||||
// be dynamic.
|
||||
if (currTargetShape == ShapedType::kDynamic &&
|
||||
sourceShape[sourceDim] != ShapedType::kDynamic)
|
||||
return std::nullopt;
|
||||
|
||||
// For static shapes, if the product of dimensions of the expanded shape
|
||||
// should match the collapsed dimension shape.
|
||||
if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
|
||||
return std::nullopt;
|
||||
|
||||
currIndices.push_back(sourceDim++);
|
||||
reassociationMap.emplace_back(ReassociationIndices{});
|
||||
std::swap(reassociationMap.back(), currIndices);
|
||||
prodOfCollapsedDims = 1;
|
||||
return SmallVector<ReassociationIndices>{allSourceIndices};
|
||||
}
|
||||
// All the dimensions in the target must have been processed.
|
||||
if (reassociationMap.size() != targetShape.size())
|
||||
|
||||
// Collect source ranges by iterating over the target shape left-to-right.
|
||||
FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
|
||||
findReassociationRangesForCollapse(sourceShape, targetShape);
|
||||
if (failed(maybeForwardRanges))
|
||||
return std::nullopt;
|
||||
// Process any remaining entries in the source shape. They all need to be
|
||||
// 1 or dynamic.
|
||||
for (; sourceDim < sourceShape.size(); sourceDim++) {
|
||||
if (sourceShape[sourceDim] != ShapedType::kDynamic &&
|
||||
sourceShape[sourceDim] != 1)
|
||||
return std::nullopt;
|
||||
// The map is empty when the target type is a scalar.
|
||||
if (!reassociationMap.empty())
|
||||
reassociationMap.back().push_back(sourceDim);
|
||||
auto &ranges = *maybeForwardRanges;
|
||||
// Now do the same in reverse. We need to get another valid reassociation
|
||||
// through some other strategy, and then compare the results in order to
|
||||
// disambiguate mixed subshapes, such as:
|
||||
// ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
|
||||
// This leads us to lose some of the reassociation opportunities that can only
|
||||
// be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
|
||||
// backtracking, the algorithm will fail right-to-left. However, this is the
|
||||
// best way to preserve correctness.
|
||||
FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
|
||||
findReassociationRangesForCollapse(sourceShape, targetShape,
|
||||
/*iterateRightToLeft=*/true);
|
||||
if (failed(maybeReverseRanges))
|
||||
return std::nullopt;
|
||||
auto &reverseRanges = *maybeReverseRanges;
|
||||
|
||||
if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
|
||||
return std::nullopt;
|
||||
// Now we can check for ambiguity of each target dimension's reassociation. If
|
||||
// successful, we put the full indices into our result map for the target
|
||||
// shape.
|
||||
SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
|
||||
for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
|
||||
++targetDimIdx) {
|
||||
ReassociationIndexRange &range = ranges[targetDimIdx];
|
||||
ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
|
||||
// Get non-overlapping indices between the ranges
|
||||
ReassociationIndices nonMatchingIndices =
|
||||
range.getNonOverlappingIndicesWith(reverseRange);
|
||||
// Unit dimensions can be collapsed wherever - this is the only ambiguity
|
||||
// that we allow.
|
||||
for (int64_t sourceDimIdx : nonMatchingIndices) {
|
||||
if (sourceShape[sourceDimIdx] != 1)
|
||||
return std::nullopt;
|
||||
}
|
||||
reassociationMap[targetDimIdx] = range.getFullIndices();
|
||||
}
|
||||
return reassociationMap;
|
||||
}
|
||||
|
||||
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @unpack_dynamic
|
||||
// CHECK-NOT: tensor.collapse
|
||||
// CHECK: linalg.unpack
|
||||
// CHECK: tensor.collapse
|
||||
// CHECK-NOT: linalg.unpack
|
||||
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
|
||||
%c32 = arith.constant 32 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
|
||||
@@ -1101,7 +1101,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
|
||||
func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
|
||||
-> tensor<?x4x?xf32> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
|
||||
: tensor<?x4x?xf32> into tensor<?x?xf32>
|
||||
@@ -1109,12 +1109,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
|
||||
: tensor<?x?xf32> into tensor<?x4x?xf32>
|
||||
return %1 : tensor<?x4x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: @fold_expand_of_collapse_dynamic
|
||||
// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
|
||||
// CHECK-NOT: tensor.{{.*}}_shape
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
|
||||
func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
|
||||
-> tensor<?x4x?xf32> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
|
||||
: tensor<?x4x?x2xf32> into tensor<?x?xf32>
|
||||
%1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
|
||||
: tensor<?x?xf32> into tensor<?x4x?xf32>
|
||||
return %1 : tensor<?x4x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
|
||||
// CHECK-NOT: tensor.expand_shape
|
||||
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
|
||||
// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
|
||||
// CHECK-NEXT: return %[[COLLAPSE]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
|
||||
-> tensor<?x?x?xf32> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
|
||||
: tensor<?x?x?xf32> into tensor<?x?xf32>
|
||||
@@ -1122,7 +1138,22 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
|
||||
: tensor<?x?xf32> into tensor<?x?x?xf32>
|
||||
return %1 : tensor<?x?x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
|
||||
// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
|
||||
// CHECK: tensor.collapse_shape
|
||||
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
|
||||
// CHECK: return %[[EXPAND]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index)
|
||||
-> tensor<?x?xf32> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
|
||||
: tensor<?x?x?xf32> into tensor<?xf32>
|
||||
%1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2]
|
||||
: tensor<?xf32> into tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic
|
||||
// CHECK: tensor.collapse_shape
|
||||
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
|
||||
// CHECK: return %[[EXPAND]]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
add_mlir_unittest(MLIRDialectUtilsTests
|
||||
StructuredOpsUtilsTest.cpp
|
||||
ReshapeOpsUtilsTest.cpp
|
||||
IndexingUtilsTest.cpp
|
||||
)
|
||||
mlir_target_link_libraries(MLIRDialectUtilsTests
|
||||
|
||||
203
mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
Normal file
203
mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
Normal file
@@ -0,0 +1,203 @@
|
||||
//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include <optional>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Helper to make constructing
|
||||
/// `std::optional<SmallVector<ReassociationIndices>>` more readable.
|
||||
static std::optional<SmallVector<ReassociationIndices>>
|
||||
makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
|
||||
return std::optional<SmallVector<ReassociationIndices>>(list);
|
||||
}
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, ScalarTest) {
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
|
||||
makeOptionalIndices({{0}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
|
||||
makeOptionalIndices({{0, 1}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
|
||||
makeOptionalIndices({{0}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
|
||||
ShapedType::kDynamic, 1,
|
||||
ShapedType::kDynamic},
|
||||
{}),
|
||||
makeOptionalIndices({{0, 1, 2, 3, 4}}));
|
||||
}
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt);
|
||||
EXPECT_EQ(
|
||||
getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}),
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, StaticTest) {
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
|
||||
makeOptionalIndices({{0, 1}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}),
|
||||
makeOptionalIndices({{0}, {1, 2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}),
|
||||
makeOptionalIndices({{0, 1}, {2}}));
|
||||
}
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, StaticTestFailure) {
|
||||
// No-op reassociation
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}),
|
||||
std::nullopt);
|
||||
// Invalid static reassociations
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}),
|
||||
std::nullopt);
|
||||
// Non-collapsing (expanding) reassociation
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}),
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) {
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}),
|
||||
makeOptionalIndices({{0, 1}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}),
|
||||
makeOptionalIndices({{0, 1, 2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}),
|
||||
makeOptionalIndices({{0, 1, 2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1, 1}, {1, 1, 1}),
|
||||
makeOptionalIndices({{0}, {1}, {2, 3}}));
|
||||
}
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, DynamicTest) {
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1},
|
||||
{ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
|
||||
{ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1, 2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1},
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1}, {2, 3, 4}}));
|
||||
EXPECT_EQ(
|
||||
getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{1, ShapedType::kDynamic, ShapedType::kDynamic},
|
||||
{1, ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0}, {1, 2}}));
|
||||
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10},
|
||||
{ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{1, ShapedType::kDynamic, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1, 2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10},
|
||||
{ShapedType::kDynamic, 10}),
|
||||
makeOptionalIndices({{0, 1, 2, 3}, {4}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
|
||||
{ShapedType::kDynamic, 20}),
|
||||
makeOptionalIndices({{0, 1}, {2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20},
|
||||
{ShapedType::kDynamic, 20}),
|
||||
makeOptionalIndices({{0, 1}, {2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}),
|
||||
makeOptionalIndices({{0, 1}, {2, 3, 4}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1},
|
||||
{ShapedType::kDynamic, 20, ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1}, {2}, {3, 4}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
|
||||
{ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1, 2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic, 1},
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0}, {1, 2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{1, ShapedType::kDynamic, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0, 1}, {2}}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 1, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}),
|
||||
makeOptionalIndices({{0}, {1, 2}}));
|
||||
}
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
|
||||
{ShapedType::kDynamic, 10}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 10, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{20, ShapedType::kDynamic, 10, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(
|
||||
getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
|
||||
ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, 10, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, 12, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
|
||||
{ShapedType::kDynamic, 32, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: Reassociation for the following examples can be computed, but isn't
|
||||
// supported by `getReassociationIndicesForCollapse`.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO: Fails because there's no backtracking when some source dimensions
|
||||
// remain unmatched at either edge.
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse(
|
||||
{ShapedType::kDynamic, 10, ShapedType::kDynamic, 10},
|
||||
{ShapedType::kDynamic, 10}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2},
|
||||
{1, ShapedType::kDynamic, 2}),
|
||||
std::nullopt);
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1},
|
||||
{2, ShapedType::kDynamic}),
|
||||
std::nullopt);
|
||||
}
|
||||
Reference in New Issue
Block a user