[mlir][LLVM] Add OpBuilder & to lookupOrCreateFn functions (#136421)

These functions are called from lowering patterns. All IR modifications
in a pattern must be performed through the provided rewriter, but these
functions used to instantiate a new `OpBuilder`, bypassing the provided
rewriter.
This commit is contained in:
Matthias Springer
2025-04-20 10:06:22 +02:00
committed by GitHub
parent 71037ee9de
commit 8553efd2e9
8 changed files with 120 additions and 98 deletions

View File

@@ -1570,13 +1570,13 @@ public:
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
switch (punct) {
case PrintPunctuation::Close:
return LLVM::lookupOrCreatePrintCloseFn(parent);
return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
case PrintPunctuation::Open:
return LLVM::lookupOrCreatePrintOpenFn(parent);
return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
case PrintPunctuation::Comma:
return LLVM::lookupOrCreatePrintCommaFn(parent);
return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
case PrintPunctuation::NewLine:
return LLVM::lookupOrCreatePrintNewlineFn(parent);
return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
default:
llvm_unreachable("unexpected punctuation");
}
@@ -1610,17 +1610,17 @@ private:
PrintConversion conversion = PrintConversion::None;
FailureOr<Operation *> printer;
if (printType.isF32()) {
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
} else if (printType.isF64()) {
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
} else if (printType.isF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
} else if (printType.isBF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
} else if (printType.isIndex()) {
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1630,7 +1630,7 @@ private:
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
} else {
return failure();
}
@@ -1643,7 +1643,7 @@ private:
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
printer = LLVM::lookupOrCreatePrintI64Fn(parent);
printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
} else {
return failure();
}