[mlir][tensor] remove tensor.insert constant folding out of canonicalization (#142671)
Follow-up to https://github.com/llvm/llvm-project/pull/142458/, addressing in particular the concern that indiscriminately folding tensor constants can bloat the IR, since these constants can be arbitrarily large.

Signed-off-by: Asra Ali <asraa@google.com>
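For context, a minimal sketch of the bloat concern (hypothetical IR, not taken from the patch; `some.use` stands in for an arbitrary consumer): when the destination constant is large and still has other users, folding a single `tensor.insert` materializes a second, equally large dense constant just to change one element.

```mlir
// Before folding: one large splat constant, shared by the insert and another user.
func.func @before() -> tensor<1024x1024xf32> {
  %big = arith.constant dense<1.0> : tensor<1024x1024xf32>
  %c0 = arith.constant 0 : index
  %v = arith.constant 2.0 : f32
  %ins = tensor.insert %v into %big[%c0, %c0] : tensor<1024x1024xf32>
  "some.use"(%big) : (tensor<1024x1024xf32>) -> ()
  return %ins : tensor<1024x1024xf32>
}

// After the (now removed) canonicalization: %big must stay for its other user,
// and %ins becomes a second, non-splat dense constant holding all 1024x1024
// elements (roughly 4 MiB of attribute data, elided below) for one changed value.
func.func @after() -> tensor<1024x1024xf32> {
  %big = arith.constant dense<1.0> : tensor<1024x1024xf32>
  %ins = arith.constant dense<"..."> : tensor<1024x1024xf32>  // element data elided
  "some.use"(%big) : (tensor<1024x1024xf32>) -> ()
  return %ins : tensor<1024x1024xf32>
}
```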
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
   let hasFolder = 1;
   let hasVerifier = 1;
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-
-/// Pattern to fold an insert op of a constant destination and scalar to a new
-/// constant.
-///
-/// Example:
-/// ```
-/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// %c0 = arith.constant 0 : index
-/// %c4_f32 = arith.constant 4.0 : f32
-/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
-/// ```
-/// is rewritten into:
-/// ```
-/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-/// ```
-class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern<InsertOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    // Requires a ranked tensor type.
-    auto destType =
-        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
-    if (!destType)
-      return failure();
-
-    // Pattern requires constant indices
-    SmallVector<uint64_t, 8> indices;
-    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
-      auto indiceAttr = dyn_cast<Attribute>(indice);
-      if (!indiceAttr)
-        return failure();
-      indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
-    }
-
-    // Requires a constant scalar to insert
-    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
-    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
-    if (!scalarAttr)
-      return failure();
-
-    if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
-            insertOp.getDest().getDefiningOp())) {
-      if (auto sourceAttr =
-              llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
-        // Update the attribute at the inserted index.
-        auto sourceValues = sourceAttr.getValues<Attribute>();
-        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
-        std::vector<Attribute> updatedValues;
-        updatedValues.reserve(sourceAttr.getNumElements());
-        for (unsigned i = 0; i < sourceAttr.getNumElements(); ++i) {
-          updatedValues.push_back(i == flattenedIndex ? scalarAttr
-                                                      : sourceValues[i]);
-        }
-        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-            insertOp, sourceAttr.getType(),
-            DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
-                                   updatedValues));
-        return success();
-      }
-    }
-
-    return failure();
-  }
-};
-
-} // namespace
-
 void InsertOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "inserted");
@@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add<InsertOpConstantFold>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
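Note that `let hasFolder = 1;` is untouched, so `InsertOp::fold` (whose tail `return {};` appears in the context above) can still cover the one constant case that cannot bloat the IR: inserting a scalar into a splat constant that already holds that value, which simply returns the existing attribute. A minimal sketch of such a splat-preserving folder, written for illustration rather than copied from the file:

```cpp
// Illustrative only: fold tensor.insert when the destination is a splat
// constant and the inserted scalar equals the splat value. The result is the
// unchanged splat attribute, so no new large constant is ever materialized.
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return splatDest;
  return {};
}
```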
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   return %ins_1 : tensor<4xf32>
 }
 
-// -----
-
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
-  // Fold an insert into a splat.
-  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
-  // CHECK-LITERAL:
-  // CHECK-NEXT: return %[[C4]]
-  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4_i32 = arith.constant 4 : i32
-  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
-  return %inserted : tensor<2x2xi32>
-}
-
 // -----
 
 // CHECK-LABEL: func @extract_from_tensor.cast