From 68f0bc6f2e869dc7d3e8394b99fba7052ad8116a Mon Sep 17 00:00:00 2001
From: Rik Huijzer
Date: Wed, 6 Dec 2023 07:35:18 +0100
Subject: [PATCH] [mlir] Fix a zero stride canonicalizer crash (#74200)

This PR fixes https://github.com/llvm/llvm-project/issues/73383 and is
another shot at the refactoring proposed in
https://github.com/llvm/llvm-project/pull/72885.

---------

Co-authored-by: Kai Sasaki
---
 .../mlir/Dialect/Utils/StaticValueUtils.h   | 30 +++++++++++++++++--
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp    | 17 ++++-------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp    | 17 +++++------
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 27 ++++++++++++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir  | 12 ++++++++
 5 files changed, 78 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 502ab93ddbfa..1dc0398494dc 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -139,12 +139,36 @@ SmallVector<int64_t>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
+/// Helper function to check whether the passed in `sizes` or `offsets` are
+/// valid. This can be used to re-check whether dimensions are still valid
+/// after constant folding the dynamic dimensions.
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
+
+/// Helper function to check whether the passed in `strides` are valid. This
+/// can be used to re-check whether dimensions are still valid after constant
+/// folding the dynamic dimensions.
+bool hasValidStrides(SmallVector<int64_t> strides);
+
 /// Returns "success" when any of the elements in `ofrs` is a constant value. In
 /// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened. If `onlyNonNegative` is set, only non-negative constant
-/// values are folded.
+/// folding happened. 
If `onlyNonNegative` and `onlyNonZero` are set, only
+/// non-negative and non-zero constant values are folded respectively.
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative = false);
+                                   bool onlyNonNegative = false,
+                                   bool onlyNonZero = false);
+
+/// Returns "success" when any of the elements in `offsetsOrSizes` is a
+/// constant value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Invalid values are not folded to avoid
+/// canonicalization crashes.
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
+
+/// Returns "success" when any of the elements in `strides` is a constant
+/// value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Invalid values are not folded to avoid
+/// canonicalization crashes.
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a397506629cf..93327a28234e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2582,17 +2582,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  // If one of the offsets or sizes is invalid, fail the canonicalization.
-  // These checks also occur in the verifier, but they are needed here
-  // because some dynamic dimensions may have been constant folded. 
-  for (int64_t offset : staticOffsets)
-    if (offset < 0 && !ShapedType::isDynamic(offset))
-      return {};
-  for (int64_t size : staticSizes)
-    if (size < 0 && !ShapedType::isDynamic(size))
-      return {};
-
+  if (!hasValidSizesOffsets(staticOffsets))
+    return {};
+  if (!hasValidSizesOffsets(staticSizes))
+    return {};
+  if (!hasValidStrides(staticStrides))
+    return {};
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f15695383d34..55f813df78b8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1447,13 +1447,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     SmallVector<int64_t> newShape;
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
-    for (int64_t newdim : newShape) {
-      // This check also occurs in the verifier, but we need it here too
-      // since intermediate passes may have replaced some dynamic dimensions
-      // by constants.
-      if (newdim < 0 && !ShapedType::isDynamic(newdim))
-        return failure();
-    }
+    if (!hasValidSizesOffsets(newShape))
+      return failure();
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
@@ -2549,9 +2544,9 @@ public:
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedStrides)))
+    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
+        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
+        failed(foldDynamicStrideList(mixedStrides)))
       return failure();
 
     // Create the new op in canonical form. 
@@ -2692,6 +2687,8 @@ struct InsertSliceOpSourceCastInserter final
         newSrcShape[i] = *constInt;
       }
     }
+    if (!hasValidSizesOffsets(newSrcShape))
+      return failure();
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index c7a3d8fc8eb2..0c8a88da789e 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
+  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0;
+  });
+}
+
+bool hasValidStrides(SmallVector<int64_t> strides) {
+  return llvm::none_of(strides, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value == 0;
+  });
+}
+
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative) {
+                                   bool onlyNonNegative, bool onlyNonZero) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
     if (ofr.is<Attribute>())
       continue;
@@ -267,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
     // Note: All ofrs have index type. 
if (onlyNonNegative && *getConstantIntValue(attr) < 0)
       continue;
+    if (onlyNonZero && *getConstantIntValue(attr) == 0)
+      continue;
     ofr = attr;
     valuesChanged = true;
   }
@@ -274,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
   return success(valuesChanged);
 }
 
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
+  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
+                              /*onlyNonZero=*/false);
+}
+
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
+  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
+                              /*onlyNonZero=*/true);
+}
+
 } // namespace mlir
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a1f8673638ff..d3406c630f6d 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?x256xf32, strided<[1024, 1], offset: 2304>> {
   return %0 : memref<?x256xf32, strided<[1024, 1], offset: 2304>>
 }
 
+// -----
+
+// CHECK-LABEL: func @no_fold_subview_with_zero_stride
+func.func @no_fold_subview_with_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
+  return %1 : memref<1xf32, strided<[?], offset: 1>>
+}
+
+// -----
+
 // CHECK-LABEL: func @no_fold_of_store
 //  CHECK:   %[[cst:.+]] = memref.cast %arg
 //  CHECK:   memref.store %[[cst]]