[mlir][LLVM] NFC - Refactor a lookupOrCreateFn to reuse common function creation.
Differential revision: https://reviews.llvm.org/D96488
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
@@ -1311,11 +1312,14 @@ public:
|
||||
Type eltType = vectorType ? vectorType.getElementType() : printType;
|
||||
Operation *printer;
|
||||
if (eltType.isF32()) {
|
||||
printer = getPrintFloat(printOp);
|
||||
printer =
|
||||
LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
|
||||
} else if (eltType.isF64()) {
|
||||
printer = getPrintDouble(printOp);
|
||||
printer =
|
||||
LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
|
||||
} else if (eltType.isIndex()) {
|
||||
printer = getPrintU64(printOp);
|
||||
printer =
|
||||
LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
|
||||
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
@@ -1325,7 +1329,8 @@ public:
|
||||
if (width <= 64) {
|
||||
if (width < 64)
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
printer = getPrintU64(printOp);
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(
|
||||
printOp->getParentOfType<ModuleOp>());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
@@ -1338,7 +1343,8 @@ public:
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
else if (width < 64)
|
||||
conversion = PrintConversion::SignExt64;
|
||||
printer = getPrintI64(printOp);
|
||||
printer = LLVM::lookupOrCreatePrintI64Fn(
|
||||
printOp->getParentOfType<ModuleOp>());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
@@ -1351,7 +1357,9 @@ public:
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
|
||||
conversion);
|
||||
emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
|
||||
emitCall(rewriter, printOp->getLoc(),
|
||||
LLVM::lookupOrCreatePrintNewlineFn(
|
||||
printOp->getParentOfType<ModuleOp>()));
|
||||
rewriter.eraseOp(printOp);
|
||||
return success();
|
||||
}
|
||||
@@ -1386,8 +1394,10 @@ private:
|
||||
return;
|
||||
}
|
||||
|
||||
emitCall(rewriter, loc, getPrintOpen(op));
|
||||
Operation *printComma = getPrintComma(op);
|
||||
emitCall(rewriter, loc,
|
||||
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
|
||||
Operation *printComma =
|
||||
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
|
||||
int64_t dim = vectorType.getDimSize(0);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
auto reducedType =
|
||||
@@ -1401,7 +1411,8 @@ private:
|
||||
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(rewriter, loc, getPrintClose(op));
|
||||
emitCall(rewriter, loc,
|
||||
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
|
||||
}
|
||||
|
||||
// Helper to emit a call.
|
||||
@@ -1410,46 +1421,6 @@ private:
|
||||
rewriter.create<LLVM::CallOp>(loc, TypeRange(),
|
||||
rewriter.getSymbolRefAttr(ref), params);
|
||||
}
|
||||
|
||||
// Helper for printer method declaration (first hit) and lookup.
|
||||
static Operation *getPrint(Operation *op, StringRef name,
|
||||
ArrayRef<Type> params) {
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
|
||||
if (func)
|
||||
return func;
|
||||
OpBuilder moduleBuilder(module.getBodyRegion());
|
||||
return moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||
op->getLoc(), name,
|
||||
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
|
||||
params));
|
||||
}
|
||||
|
||||
// Helpers for method names.
|
||||
Operation *getPrintI64(Operation *op) const {
|
||||
return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
|
||||
}
|
||||
Operation *getPrintU64(Operation *op) const {
|
||||
return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
|
||||
}
|
||||
Operation *getPrintFloat(Operation *op) const {
|
||||
return getPrint(op, "printF32", Float32Type::get(op->getContext()));
|
||||
}
|
||||
Operation *getPrintDouble(Operation *op) const {
|
||||
return getPrint(op, "printF64", Float64Type::get(op->getContext()));
|
||||
}
|
||||
Operation *getPrintOpen(Operation *op) const {
|
||||
return getPrint(op, "printOpen", {});
|
||||
}
|
||||
Operation *getPrintClose(Operation *op) const {
|
||||
return getPrint(op, "printClose", {});
|
||||
}
|
||||
Operation *getPrintComma(Operation *op) const {
|
||||
return getPrint(op, "printComma", {});
|
||||
}
|
||||
Operation *getPrintNewline(Operation *op) const {
|
||||
return getPrint(op, "printNewline", {});
|
||||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of ExtractStridedSliceOp to either:
|
||||
|
||||
Reference in New Issue
Block a user