[mlir][LLVM] NFC - Refactor a lookupOrCreateFn to reuse common function creation.

Differential revision: https://reviews.llvm.org/D96488
This commit is contained in:
Nicolas Vasilache
2021-02-11 15:30:39 +00:00
parent 19b4d3ce27
commit e332c22cdf
5 changed files with 228 additions and 109 deletions

View File

@@ -10,6 +10,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
@@ -1311,11 +1312,14 @@ public:
Type eltType = vectorType ? vectorType.getElementType() : printType;
Operation *printer;
if (eltType.isF32()) {
printer = getPrintFloat(printOp);
printer =
LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
} else if (eltType.isF64()) {
printer = getPrintDouble(printOp);
printer =
LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
} else if (eltType.isIndex()) {
printer = getPrintU64(printOp);
printer =
LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1325,7 +1329,8 @@ public:
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
printer = getPrintU64(printOp);
printer = LLVM::lookupOrCreatePrintU64Fn(
printOp->getParentOfType<ModuleOp>());
} else {
return failure();
}
@@ -1338,7 +1343,8 @@ public:
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
printer = getPrintI64(printOp);
printer = LLVM::lookupOrCreatePrintI64Fn(
printOp->getParentOfType<ModuleOp>());
} else {
return failure();
}
@@ -1351,7 +1357,9 @@ public:
int64_t rank = vectorType ? vectorType.getRank() : 0;
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
emitCall(rewriter, printOp->getLoc(),
LLVM::lookupOrCreatePrintNewlineFn(
printOp->getParentOfType<ModuleOp>()));
rewriter.eraseOp(printOp);
return success();
}
@@ -1386,8 +1394,10 @@ private:
return;
}
emitCall(rewriter, loc, getPrintOpen(op));
Operation *printComma = getPrintComma(op);
emitCall(rewriter, loc,
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
Operation *printComma =
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
int64_t dim = vectorType.getDimSize(0);
for (int64_t d = 0; d < dim; ++d) {
auto reducedType =
@@ -1401,7 +1411,8 @@ private:
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
emitCall(rewriter, loc, getPrintClose(op));
emitCall(rewriter, loc,
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
}
// Helper to emit a call.
@@ -1410,46 +1421,6 @@ private:
rewriter.create<LLVM::CallOp>(loc, TypeRange(),
rewriter.getSymbolRefAttr(ref), params);
}
// Helper for printer method declaration (first hit) and lookup.
static Operation *getPrint(Operation *op, StringRef name,
ArrayRef<Type> params) {
auto module = op->getParentOfType<ModuleOp>();
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (func)
return func;
OpBuilder moduleBuilder(module.getBodyRegion());
return moduleBuilder.create<LLVM::LLVMFuncOp>(
op->getLoc(), name,
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
params));
}
// Helpers for method names.
Operation *getPrintI64(Operation *op) const {
return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
}
Operation *getPrintU64(Operation *op) const {
return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
}
Operation *getPrintFloat(Operation *op) const {
return getPrint(op, "printF32", Float32Type::get(op->getContext()));
}
Operation *getPrintDouble(Operation *op) const {
return getPrint(op, "printF64", Float64Type::get(op->getContext()));
}
Operation *getPrintOpen(Operation *op) const {
return getPrint(op, "printOpen", {});
}
Operation *getPrintClose(Operation *op) const {
return getPrint(op, "printClose", {});
}
Operation *getPrintComma(Operation *op) const {
return getPrint(op, "printComma", {});
}
Operation *getPrintNewline(Operation *op) const {
return getPrint(op, "printNewline", {});
}
};
/// Progressive lowering of ExtractStridedSliceOp to either: