[mlir][LLVM] Add OpBuilder & to lookupOrCreateFn functions (#136421)
These functions are called from lowering patterns. All IR modifications in a pattern must be performed through the provided rewriter, but these functions used to instantiate a new `OpBuilder`, bypassing the provided rewriter.
This commit is contained in:
committed by
GitHub
parent
71037ee9de
commit
8553efd2e9
@@ -1570,13 +1570,13 @@ public:
|
||||
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
|
||||
switch (punct) {
|
||||
case PrintPunctuation::Close:
|
||||
return LLVM::lookupOrCreatePrintCloseFn(parent);
|
||||
return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
|
||||
case PrintPunctuation::Open:
|
||||
return LLVM::lookupOrCreatePrintOpenFn(parent);
|
||||
return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
|
||||
case PrintPunctuation::Comma:
|
||||
return LLVM::lookupOrCreatePrintCommaFn(parent);
|
||||
return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
|
||||
case PrintPunctuation::NewLine:
|
||||
return LLVM::lookupOrCreatePrintNewlineFn(parent);
|
||||
return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
|
||||
default:
|
||||
llvm_unreachable("unexpected punctuation");
|
||||
}
|
||||
@@ -1610,17 +1610,17 @@ private:
|
||||
PrintConversion conversion = PrintConversion::None;
|
||||
FailureOr<Operation *> printer;
|
||||
if (printType.isF32()) {
|
||||
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
|
||||
printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
|
||||
} else if (printType.isF64()) {
|
||||
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
|
||||
printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
|
||||
} else if (printType.isF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
|
||||
printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
|
||||
} else if (printType.isBF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
|
||||
printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
|
||||
} else if (printType.isIndex()) {
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
|
||||
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
@@ -1630,7 +1630,7 @@ private:
|
||||
if (width <= 64) {
|
||||
if (width < 64)
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
@@ -1643,7 +1643,7 @@ private:
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
else if (width < 64)
|
||||
conversion = PrintConversion::SignExt64;
|
||||
printer = LLVM::lookupOrCreatePrintI64Fn(parent);
|
||||
printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user