[mlir][tensor] remove tensor.insert constant folding out of canonicalization (#142671)

Follow-up from https://github.com/llvm/llvm-project/pull/142458/.
In particular, this addresses concerns that indiscriminately folding tensor
constants can bloat the IR, since these constants can be arbitrarily large.
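
For illustration, with made-up shapes and values, the kind of blow-up this
avoids:

```
// %cst is large and may have other users besides the insert.
%cst = arith.constant dense<1.0> : tensor<1024x1024xf32>
%c0 = arith.constant 0 : index
%f2 = arith.constant 2.0 : f32
%ins = tensor.insert %f2 into %cst[%c0, %c0] : tensor<1024x1024xf32>

// Folding %ins during canonicalization would materialize a second,
// fully enumerated 1024x1024 constant, while %cst can stay live for
// its other uses, so the module may end up carrying both.
```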

Signed-off-by: Asra Ali <asraa@google.com>

Author: asraa
Date: 2025-06-05 16:53:33 -05:00
Committed by: GitHub
Commit: c66b72f8ce (parent 49386f40dd)

3 changed files with 0 additions and 92 deletions

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
   let hasFolder = 1;
   let hasVerifier = 1;
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

@@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// Pattern to fold an insert op of a constant destination and scalar to a new
-/// constant.
-///
-/// Example:
-/// ```
-/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// %c0 = arith.constant 0 : index
-/// %c4_f32 = arith.constant 4.0 : f32
-/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
-/// ```
-/// is rewritten into:
-/// ```
-/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// ```
-class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern<InsertOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    // Requires a ranked tensor type.
-    auto destType =
-        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
-    if (!destType)
-      return failure();
-
-    // Pattern requires constant indices
-    SmallVector<uint64_t, 8> indices;
-    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
-      auto indiceAttr = dyn_cast<Attribute>(indice);
-      if (!indiceAttr)
-        return failure();
-      indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
-    }
-
-    // Requires a constant scalar to insert
-    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
-    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
-    if (!scalarAttr)
-      return failure();
-
-    if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
-            insertOp.getDest().getDefiningOp())) {
-      if (auto sourceAttr =
-              llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
-        // Update the attribute at the inserted index.
-        auto sourceValues = sourceAttr.getValues<Attribute>();
-        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
-        std::vector<Attribute> updatedValues;
-        updatedValues.reserve(sourceAttr.getNumElements());
-        for (unsigned i = 0; i < sourceAttr.getNumElements(); ++i) {
-          updatedValues.push_back(i == flattenedIndex ? scalarAttr
-                                                      : sourceValues[i]);
-        }
-        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-            insertOp, sourceAttr.getType(),
-            DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
-                                   updatedValues));
-        return success();
-      }
-    }
-    return failure();
-  }
-};
-} // namespace
-
 void InsertOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "inserted");
@@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add<InsertOpConstantFold>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//

mlir/test/Dialect/Tensor/canonicalize.mlir

@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   return %ins_1 : tensor<4xf32>
 }
 
 // -----
 
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
-  // Fold an insert into a splat.
-  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
-  // CHECK-LITERAL:
-  // CHECK-NEXT: return %[[C4]]
-  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4_i32 = arith.constant 4 : i32
-  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
-  return %inserted : tensor<2x2xi32>
-}
-
-// -----
-
 // CHECK-LABEL: func @extract_from_tensor.cast