[MLIR] Add optional cached symbol tables to LLVM conversion patterns (#144032)
This PR allows to optionally speed up the lookup of symbols by providing a `SymbolTableCollection` instance to the interested conversion patterns. It is follow-up on the discussion about symbol / symbol table management carried on [Discourse](https://discourse.llvm.org/t/symbol-table-as-first-class-citizen-in-builders/86813).
This commit is contained in:
@@ -1595,8 +1595,14 @@ private:
|
||||
};
|
||||
|
||||
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
|
||||
SymbolTableCollection *symbolTables = nullptr;
|
||||
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
|
||||
explicit VectorPrintOpConversion(
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
SymbolTableCollection *symbolTables = nullptr)
|
||||
: ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
|
||||
symbolTables(symbolTables) {}
|
||||
|
||||
// Lowering implementation that relies on a small runtime support library,
|
||||
// which only needs to provide a few printing methods (single value for all
|
||||
@@ -1643,13 +1649,17 @@ public:
|
||||
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
|
||||
switch (punct) {
|
||||
case PrintPunctuation::Close:
|
||||
return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
|
||||
return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
|
||||
symbolTables);
|
||||
case PrintPunctuation::Open:
|
||||
return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
|
||||
return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
|
||||
symbolTables);
|
||||
case PrintPunctuation::Comma:
|
||||
return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
|
||||
return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
|
||||
symbolTables);
|
||||
case PrintPunctuation::NewLine:
|
||||
return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
|
||||
return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
|
||||
symbolTables);
|
||||
default:
|
||||
llvm_unreachable("unexpected punctuation");
|
||||
}
|
||||
@@ -1683,17 +1693,17 @@ private:
|
||||
PrintConversion conversion = PrintConversion::None;
|
||||
FailureOr<Operation *> printer;
|
||||
if (printType.isF32()) {
|
||||
printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
|
||||
printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
|
||||
} else if (printType.isF64()) {
|
||||
printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
|
||||
printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
|
||||
} else if (printType.isF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
|
||||
printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
|
||||
} else if (printType.isBF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
|
||||
printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
|
||||
} else if (printType.isIndex()) {
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
|
||||
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
@@ -1703,7 +1713,8 @@ private:
|
||||
if (width <= 64) {
|
||||
if (width < 64)
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
|
||||
printer =
|
||||
LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
@@ -1716,7 +1727,8 @@ private:
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
else if (width < 64)
|
||||
conversion = PrintConversion::SignExt64;
|
||||
printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
|
||||
printer =
|
||||
LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user