Files
clang-p2996/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
Luohao Wang e84f6b6a88 [mlir] Fix conflict of user defined reserved functions with internal prototypes (#123378)
On lowering from `memref` to LLVM, `malloc` and other intrinsic
functions from `libc` will be declared in the current module. A user's
redefinition of these reserved functions will poison the internal
analysis with a wrong prototype. This patch adds an assertion on the found
function's type and reports an error if it does not match the intended type.

Related to #120950


---------

Co-authored-by: Luohao Wang <Luohaothu@users.noreply.github.com>
2025-01-28 14:40:47 +01:00

70 lines
2.8 KiB
C++

//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/ArrayRef.h"
#include <atomic>
#include <string>
using namespace mlir;
using namespace llvm;
/// Returns a name derived from `symbolName` for which
/// `moduleOp.lookupSymbol` currently finds nothing.
///
/// If `symbolName` itself is not found it is returned unchanged. Otherwise
/// candidates of the form `<symbolName>_<N>` are tried until one is free.
/// `N` is drawn from a single counter shared by every call to this function,
/// so the suffixes produced for a given base name are unique but need not
/// start at 0 or be contiguous.
static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
                                            StringRef symbolName) {
  // The counter is function-local static state shared by all callers. It is
  // atomic so that concurrent callers do not race on the increment; relaxed
  // ordering is enough because only distinct values are required, not any
  // ordering with respect to other memory operations.
  static std::atomic<int> counter{0};
  std::string uniqueName = std::string(symbolName);
  while (moduleOp.lookupSymbol(uniqueName)) {
    // fetch_add returns the previous value, matching a post-increment.
    const int suffix = counter.fetch_add(1, std::memory_order_relaxed);
    uniqueName = std::string(symbolName) + "_" + std::to_string(suffix);
  }
  return uniqueName;
}
/// Emits a call that hands a constant, zero-terminated string to a runtime
/// print function.
///
/// The bytes of `string` (followed by '\n' when `addNewline` is set, and
/// always by a terminating '\0') are stored in a private, constant
/// `LLVM::GlobalOp` created at the start of `moduleOp`'s body. The global is
/// named after `symbolName`, made unique within the module via
/// `ensureSymbolNameIsUnique`. At the builder's original insertion point the
/// function then creates an `LLVM::AddressOfOp` for the global, an
/// `LLVM::GEPOp` with a single zero index, and an `LLVM::CallOp` to the
/// function returned by `LLVM::lookupOrCreatePrintStringFn(moduleOp,
/// runtimeFunctionName)`.
///
/// `typeConverter` is accepted but not referenced in this body.
///
/// Returns `failure()` when the printer function cannot be looked up or
/// created; the global, address-of and GEP ops built before that point are
/// not removed. Returns `success()` otherwise.
LogicalResult mlir::LLVM::createPrintStrCall(
    OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
    StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
    std::optional<StringRef> runtimeFunctionName) {
  MLIRContext *context = builder.getContext();

  // Assemble the payload: the message, an optional newline, and the
  // terminating zero byte.
  SmallVector<uint8_t> bytes(string.begin(), string.end());
  if (addNewline)
    bytes.push_back('\n');
  bytes.push_back('\0');
  const auto numBytes = bytes.size();

  // Types and initializer for the global holding the payload.
  auto tensorTy = RankedTensorType::get({static_cast<int64_t>(numBytes)},
                                        builder.getI8Type());
  auto initializer = DenseElementsAttr::get(tensorTy, llvm::ArrayRef(bytes));
  auto globalTy =
      LLVM::LLVMArrayType::get(IntegerType::get(context, 8), numBytes);
  auto addrTy = LLVM::LLVMPointerType::get(context);

  // Place the global at the top of the module, then return to where the
  // caller was inserting so the remaining ops land there.
  const auto callerInsertionPoint = builder.saveInsertionPoint();
  builder.setInsertionPointToStart(moduleOp.getBody());
  std::string globalName = ensureSymbolNameIsUnique(moduleOp, symbolName);
  auto stringGlobal = builder.create<LLVM::GlobalOp>(
      loc, globalTy, /*constant=*/true, LLVM::Linkage::Private, globalName,
      initializer);
  builder.restoreInsertionPoint(callerInsertionPoint);

  // Take the address of the global and form the pointer passed to the
  // printer.
  auto globalAddr =
      builder.create<LLVM::AddressOfOp>(loc, addrTy, stringGlobal.getName());
  SmallVector<LLVM::GEPArg> gepIndices(1, 0);
  Value messagePtr =
      builder.create<LLVM::GEPOp>(loc, addrTy, globalTy, globalAddr, gepIndices);

  // Resolve the runtime print function and call it with the message pointer.
  FailureOr<LLVM::LLVMFuncOp> printFn =
      LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
  if (failed(printFn))
    return failure();
  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(*printFn),
                               messagePtr);
  return success();
}