[mlir] Fix a zero stride canonicalizer crash (#74200)
This PR fixes https://github.com/llvm/llvm-project/issues/73383 and is another shot at the refactoring proposed in https://github.com/llvm/llvm-project/pull/72885. --------- Co-authored-by: Kai Sasaki <lewuathe@gmail.com>
This commit is contained in:
@@ -139,12 +139,36 @@ SmallVector<int64_t>
|
||||
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
|
||||
llvm::function_ref<bool(Attribute, Attribute)> compare);
|
||||
|
||||
/// Helper function to check whether the passed in `sizes` or `offsets` are
|
||||
/// valid. This can be used to re-check whether dimensions are still valid
|
||||
/// after constant folding the dynamic dimensions.
|
||||
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
|
||||
|
||||
/// Helper function to check whether the passed in `strides` are valid. This
|
||||
/// can be used to re-check whether dimensions are still valid after constant
|
||||
/// folding the dynamic dimensions.
|
||||
bool hasValidStrides(SmallVector<int64_t> strides);
|
||||
|
||||
/// Returns "success" when any of the elements in `ofrs` is a constant value. In
|
||||
/// that case the value is replaced by an attribute. Returns "failure" when no
|
||||
/// folding happened. If `onlyNonNegative` is set, only non-negative constant
|
||||
/// values are folded.
|
||||
/// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only
|
||||
/// non-negative and non-zero constant values are folded respectively.
|
||||
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
|
||||
bool onlyNonNegative = false);
|
||||
bool onlyNonNegative = false,
|
||||
bool onlyNonZero = false);
|
||||
|
||||
/// Returns "success" when any of the elements in `offsetsOrSizes` is a
|
||||
/// constant value. In that case the value is replaced by an attribute. Returns
|
||||
/// "failure" when no folding happened. Invalid values are not folded to avoid
|
||||
/// canonicalization crashes.
|
||||
LogicalResult
|
||||
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
|
||||
|
||||
/// Returns "success" when any of the elements in `strides` is a constant
|
||||
/// value. In that case the value is replaced by an attribute. Returns
|
||||
/// "failure" when no folding happened. Invalid values are not folded to avoid
|
||||
/// canonicalization crashes.
|
||||
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
|
||||
|
||||
/// Return the number of iterations for a loop with a lower bound `lb`, upper
|
||||
/// bound `ub` and step `step`.
|
||||
|
||||
@@ -2582,17 +2582,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
|
||||
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
|
||||
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
|
||||
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
|
||||
|
||||
// If one of the offsets or sizes is invalid, fail the canonicalization.
|
||||
// These checks also occur in the verifier, but they are needed here
|
||||
// because some dynamic dimensions may have been constant folded.
|
||||
for (int64_t offset : staticOffsets)
|
||||
if (offset < 0 && !ShapedType::isDynamic(offset))
|
||||
return {};
|
||||
for (int64_t size : staticSizes)
|
||||
if (size < 0 && !ShapedType::isDynamic(size))
|
||||
return {};
|
||||
|
||||
if (!hasValidSizesOffsets(staticOffsets))
|
||||
return {};
|
||||
if (!hasValidSizesOffsets(staticSizes))
|
||||
return {};
|
||||
if (!hasValidStrides(staticStrides))
|
||||
return {};
|
||||
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
|
||||
staticSizes, staticStrides);
|
||||
}
|
||||
|
||||
@@ -1447,13 +1447,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
|
||||
SmallVector<int64_t> newShape;
|
||||
operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
|
||||
|
||||
for (int64_t newdim : newShape) {
|
||||
// This check also occurs in the verifier, but we need it here too
|
||||
// since intermediate passes may have replaced some dynamic dimensions
|
||||
// by constants.
|
||||
if (newdim < 0 && !ShapedType::isDynamic(newdim))
|
||||
return failure();
|
||||
}
|
||||
if (!hasValidSizesOffsets(newShape))
|
||||
return failure();
|
||||
|
||||
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
|
||||
return failure();
|
||||
@@ -2549,9 +2544,9 @@ public:
|
||||
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
|
||||
|
||||
// No constant operands were folded, just return;
|
||||
if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
|
||||
failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
|
||||
failed(foldDynamicIndexList(mixedStrides)))
|
||||
if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
|
||||
failed(foldDynamicOffsetSizeList(mixedSizes)) &&
|
||||
failed(foldDynamicStrideList(mixedStrides)))
|
||||
return failure();
|
||||
|
||||
// Create the new op in canonical form.
|
||||
@@ -2692,6 +2687,8 @@ struct InsertSliceOpSourceCastInserter final
|
||||
newSrcShape[i] = *constInt;
|
||||
}
|
||||
}
|
||||
if (!hasValidSizesOffsets(newSrcShape))
|
||||
return failure();
|
||||
|
||||
RankedTensorType newSrcType =
|
||||
RankedTensorType::get(newSrcShape, srcType.getElementType());
|
||||
|
||||
@@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
|
||||
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
|
||||
}
|
||||
|
||||
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
|
||||
return llvm::none_of(sizesOrOffsets, [](int64_t value) {
|
||||
return !ShapedType::isDynamic(value) && value < 0;
|
||||
});
|
||||
}
|
||||
|
||||
bool hasValidStrides(SmallVector<int64_t> strides) {
|
||||
return llvm::none_of(strides, [](int64_t value) {
|
||||
return !ShapedType::isDynamic(value) && value == 0;
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
|
||||
bool onlyNonNegative) {
|
||||
bool onlyNonNegative, bool onlyNonZero) {
|
||||
bool valuesChanged = false;
|
||||
for (OpFoldResult &ofr : ofrs) {
|
||||
if (ofr.is<Attribute>())
|
||||
@@ -267,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
|
||||
// Note: All ofrs have index type.
|
||||
if (onlyNonNegative && *getConstantIntValue(attr) < 0)
|
||||
continue;
|
||||
if (onlyNonZero && *getConstantIntValue(attr) == 0)
|
||||
continue;
|
||||
ofr = attr;
|
||||
valuesChanged = true;
|
||||
}
|
||||
@@ -274,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
|
||||
return success(valuesChanged);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
|
||||
return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
|
||||
/*onlyNonZero=*/false);
|
||||
}
|
||||
|
||||
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
|
||||
return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
|
||||
/*onlyNonZero=*/true);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @no_fold_subview_zero_stride
|
||||
// CHECK: %[[SUBVIEW:.+]] = memref.subview
|
||||
// CHECK: return %[[SUBVIEW]]
|
||||
func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
|
||||
return %1 : memref<1xf32, strided<[?], offset: 1>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @no_fold_of_store
|
||||
// CHECK: %[[cst:.+]] = memref.cast %arg
|
||||
// CHECK: memref.store %[[cst]]
|
||||
|
||||
Reference in New Issue
Block a user