[MLIR] Add optional cached symbol tables to LLVM conversion patterns (#144032)

This PR allows to optionally speed up the lookup of symbols by providing a `SymbolTableCollection` instance to the interested conversion patterns. It is follow-up on the discussion about symbol / symbol table management carried on [Discourse](https://discourse.llvm.org/t/symbol-table-as-first-class-citizen-in-builders/86813).
This commit is contained in:
Michele Scuttari
2025-06-21 10:55:44 +02:00
committed by GitHub
parent 0921bfd81d
commit bb372963df
11 changed files with 359 additions and 176 deletions

View File

@@ -1595,8 +1595,14 @@ private:
};
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
SymbolTableCollection *symbolTables = nullptr;
public:
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
explicit VectorPrintOpConversion(
const LLVMTypeConverter &typeConverter,
SymbolTableCollection *symbolTables = nullptr)
: ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
symbolTables(symbolTables) {}
// Lowering implementation that relies on a small runtime support library,
// which only needs to provide a few printing methods (single value for all
@@ -1643,13 +1649,17 @@ public:
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
switch (punct) {
case PrintPunctuation::Close:
return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
symbolTables);
case PrintPunctuation::Open:
return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
symbolTables);
case PrintPunctuation::Comma:
return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
symbolTables);
case PrintPunctuation::NewLine:
return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
symbolTables);
default:
llvm_unreachable("unexpected punctuation");
}
@@ -1683,17 +1693,17 @@ private:
PrintConversion conversion = PrintConversion::None;
FailureOr<Operation *> printer;
if (printType.isF32()) {
printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
} else if (printType.isF64()) {
printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
} else if (printType.isF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
} else if (printType.isBF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
} else if (printType.isIndex()) {
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1703,7 +1713,8 @@ private:
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
printer =
LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
} else {
return failure();
}
@@ -1716,7 +1727,8 @@ private:
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
printer =
LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
} else {
return failure();
}