diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 05e9fe9d5885..4a7ec6f2efe6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -33,40 +33,53 @@ class LLVMFuncOp; /// implemented separately (e.g. as part of a support runtime library or as part /// of the libc). /// Failure if an unexpected version of function is found. -FailureOr lookupOrCreatePrintI64Fn(Operation *moduleOp); -FailureOr lookupOrCreatePrintU64Fn(Operation *moduleOp); -FailureOr lookupOrCreatePrintF16Fn(Operation *moduleOp); -FailureOr lookupOrCreatePrintBF16Fn(Operation *moduleOp); -FailureOr lookupOrCreatePrintF32Fn(Operation *moduleOp); -FailureOr lookupOrCreatePrintF64Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintI64Fn(OpBuilder &b, + Operation *moduleOp); +FailureOr lookupOrCreatePrintU64Fn(OpBuilder &b, + Operation *moduleOp); +FailureOr lookupOrCreatePrintF16Fn(OpBuilder &b, + Operation *moduleOp); +FailureOr lookupOrCreatePrintBF16Fn(OpBuilder &b, + Operation *moduleOp); +FailureOr lookupOrCreatePrintF32Fn(OpBuilder &b, + Operation *moduleOp); +FailureOr lookupOrCreatePrintF64Fn(OpBuilder &b, + Operation *moduleOp); /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. FailureOr -lookupOrCreatePrintStringFn(Operation *moduleOp, +lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp, std::optional runtimeFunctionName = {}); -FailureOr lookupOrCreatePrintOpenFn(Operation *moduleOp); -FailureOr lookupOrCreatePrintCloseFn(Operation *moduleOp); -FailureOr lookupOrCreatePrintCommaFn(Operation *moduleOp); -FailureOr lookupOrCreatePrintNewlineFn(Operation *moduleOp); -FailureOr lookupOrCreateMallocFn(Operation *moduleOp, - Type indexType); -FailureOr lookupOrCreateAlignedAllocFn(Operation *moduleOp, - Type indexType); -FailureOr lookupOrCreateFreeFn(Operation *moduleOp); -FailureOr lookupOrCreateGenericAllocFn(Operation *moduleOp, - Type indexType); +FailureOr lookupOrCreatePrintOpenFn(OpBuilder &b, + Operation *moduleOp); +FailureOr lookupOrCreatePrintCloseFn(OpBuilder &b, + Operation *moduleOp); +FailureOr lookupOrCreatePrintCommaFn(OpBuilder &b, + Operation *moduleOp); +FailureOr lookupOrCreatePrintNewlineFn(OpBuilder &b, + Operation *moduleOp); FailureOr -lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType); -FailureOr lookupOrCreateGenericFreeFn(Operation *moduleOp); +lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType); FailureOr -lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, +lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType); +FailureOr lookupOrCreateFreeFn(OpBuilder &b, + Operation *moduleOp); +FailureOr +lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType); +FailureOr +lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, + Type indexType); +FailureOr lookupOrCreateGenericFreeFn(OpBuilder &b, + Operation *moduleOp); +FailureOr +lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType); /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`. /// Return a failure if the FuncOp found has unexpected signature. FailureOr -lookupOrCreateFn(Operation *moduleOp, StringRef name, +lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef paramTypes = {}, Type resultType = {}, bool isVarArg = false, bool isReserved = false); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 47d4474a5c28..c95e375ce9af 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -395,7 +395,7 @@ public: // Allocate memory for the coroutine frame. auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( - op->getParentOfType(), rewriter.getI64Type()); + rewriter, op->getParentOfType(), rewriter.getI64Type()); if (failed(allocFuncOp)) return failure(); auto coroAlloc = rewriter.create( @@ -432,7 +432,7 @@ public: // Free the memory. auto freeFuncOp = - LLVM::lookupOrCreateFreeFn(op->getParentOfType()); + LLVM::lookupOrCreateFreeFn(rewriter, op->getParentOfType()); if (failed(freeFuncOp)) return failure(); rewriter.replaceOpWithNewOp(op, freeFuncOp.value(), diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 1ae99561e9d1..0505214de201 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -278,12 +278,12 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( auto module = builder.getInsertionPoint()->getParentOfType(); FailureOr freeFunc, mallocFunc; if (toDynamic) { - mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); + mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType); if (failed(mallocFunc)) return failure(); } if (!toDynamic) { - freeFunc = LLVM::lookupOrCreateFreeFn(module); + freeFunc = LLVM::lookupOrCreateFreeFn(builder, module); if (failed(freeFunc)) return failure(); } diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 337c01f01a7c..2815e05b3e11 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -60,7 +60,7 @@ LogicalResult mlir::LLVM::createPrintStrCall( Value gep = builder.create(loc, ptrTy, arrayTy, msgAddr, indices); FailureOr printer = - LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); + LLVM::lookupOrCreatePrintStringFn(builder, moduleOp, runtimeFunctionName); if (failed(printer)) return failure(); builder.create(loc, TypeRange(), diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp index bad209a4ddec..e9b79983696a 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -15,24 +15,24 @@ using namespace mlir; static FailureOr -getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, - Type indexType) { +getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, + Operation *module, Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) - return LLVM::lookupOrCreateGenericAllocFn(module, indexType); + return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType); - return LLVM::lookupOrCreateMallocFn(module, indexType); + return LLVM::lookupOrCreateMallocFn(b, module, indexType); } static FailureOr -getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, - Type indexType) { +getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, + Operation *module, Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) - return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType); + return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType); - return LLVM::lookupOrCreateAlignedAllocFn(module, indexType); + return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType); } Value AllocationOpLLVMLowering::createAligned( @@ -75,8 +75,8 @@ std::tuple AllocationOpLLVMLowering::allocateBufferManuallyAlign( Type elementPtrType = this->getElementPtrType(memRefType); assert(elementPtrType && "could not compute element ptr type"); FailureOr allocFuncOp = getNotalignedAllocFn( - getTypeConverter(), op->getParentWithTrait(), - getIndexType()); + rewriter, getTypeConverter(), + op->getParentWithTrait(), getIndexType()); if (failed(allocFuncOp)) return std::make_tuple(Value(), Value()); auto results = @@ -144,8 +144,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign( Type elementPtrType = this->getElementPtrType(memRefType); FailureOr allocFuncOp = getAlignedAllocFn( - getTypeConverter(), op->getParentWithTrait(), - getIndexType()); + rewriter, getTypeConverter(), + op->getParentWithTrait(), getIndexType()); if (failed(allocFuncOp)) return Value(); auto results = rewriter.create( diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index cb4317ef1bce..9c219d8a3d8c 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -43,13 +43,14 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) { } static FailureOr -getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) { +getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, + ModuleOp module) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) - return LLVM::lookupOrCreateGenericFreeFn(module); + return LLVM::lookupOrCreateGenericFreeFn(b, module); - return LLVM::lookupOrCreateFreeFn(module); + return LLVM::lookupOrCreateFreeFn(b, module); } struct AllocOpLowering : public AllocLikeOpLLVMLowering { @@ -223,8 +224,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Insert the `free` declaration if it is not already present. - FailureOr freeFunc = - getFreeFn(getTypeConverter(), op->getParentOfType()); + FailureOr freeFunc = getFreeFn( + rewriter, getTypeConverter(), op->getParentOfType()); if (failed(freeFunc)) return failure(); Value allocatedPtr; @@ -834,7 +835,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { // potential alignment auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter); auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( - op->getParentOfType(), getIndexType(), sourcePtr.getType()); + rewriter, op->getParentOfType(), getIndexType(), + sourcePtr.getType()); if (failed(copyFn)) return failure(); rewriter.create(loc, copyFn.value(), diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 1a35d0819645..076e5512f375 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1570,13 +1570,13 @@ public: FailureOr op = [&]() { switch (punct) { case PrintPunctuation::Close: - return LLVM::lookupOrCreatePrintCloseFn(parent); + return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent); case PrintPunctuation::Open: - return LLVM::lookupOrCreatePrintOpenFn(parent); + return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent); case PrintPunctuation::Comma: - return LLVM::lookupOrCreatePrintCommaFn(parent); + return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent); case PrintPunctuation::NewLine: - return LLVM::lookupOrCreatePrintNewlineFn(parent); + return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent); default: llvm_unreachable("unexpected punctuation"); } @@ -1610,17 +1610,17 @@ private: PrintConversion conversion = PrintConversion::None; FailureOr printer; if (printType.isF32()) { - printer = LLVM::lookupOrCreatePrintF32Fn(parent); + printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent); } else if (printType.isF64()) { - printer = LLVM::lookupOrCreatePrintF64Fn(parent); + printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent); } else if (printType.isF16()) { conversion = PrintConversion::Bitcast16; // bits! - printer = LLVM::lookupOrCreatePrintF16Fn(parent); + printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent); } else if (printType.isBF16()) { conversion = PrintConversion::Bitcast16; // bits! - printer = LLVM::lookupOrCreatePrintBF16Fn(parent); + printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent); } else if (printType.isIndex()) { - printer = LLVM::lookupOrCreatePrintU64Fn(parent); + printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent); } else if (auto intTy = dyn_cast(printType)) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or @@ -1630,7 +1630,7 @@ private: if (width <= 64) { if (width < 64) conversion = PrintConversion::ZeroExt64; - printer = LLVM::lookupOrCreatePrintU64Fn(parent); + printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent); } else { return failure(); } @@ -1643,7 +1643,7 @@ private: conversion = PrintConversion::ZeroExt64; else if (width < 64) conversion = PrintConversion::SignExt64; - printer = LLVM::lookupOrCreatePrintI64Fn(parent); + printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent); } else { return failure(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 68d4426e6530..1b4a8f496d3d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -46,7 +46,7 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; /// Generic print function lookupOrCreate helper. FailureOr -mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, +mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef paramTypes, Type resultType, bool isVarArg, bool isReserved) { assert(moduleOp->hasTrait() && @@ -69,60 +69,63 @@ mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, } return func; } - OpBuilder b(moduleOp->getRegion(0)); + + OpBuilder::InsertionGuard g(b); + assert(!moduleOp->getRegion(0).empty() && "expected non-empty region"); + b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); return b.create( moduleOp->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); } static FailureOr -lookupOrCreateReservedFn(Operation *moduleOp, StringRef name, +lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef paramTypes, Type resultType) { - return lookupOrCreateFn(moduleOp, name, paramTypes, resultType, + return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType, /*isVarArg=*/false, /*isReserved=*/true); } FailureOr -mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), + b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), + b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintF16, + b, moduleOp, kPrintF16, IntegerType::get(moduleOp->getContext(), 16), // bits! LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintBF16, + b, moduleOp, kPrintBF16, IntegerType::get(moduleOp->getContext(), 16), // bits! LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), + b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), + b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), LLVM::LLVMVoidType::get(moduleOp->getContext())); } @@ -136,87 +139,91 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) { } FailureOr mlir::LLVM::lookupOrCreatePrintStringFn( - Operation *moduleOp, std::optional runtimeFunctionName) { + OpBuilder &b, Operation *moduleOp, + std::optional runtimeFunctionName) { return lookupOrCreateReservedFn( - moduleOp, runtimeFunctionName.value_or(kPrintString), + b, moduleOp, runtimeFunctionName.value_or(kPrintString), getCharPtr(moduleOp->getContext()), LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintOpen, {}, + b, moduleOp, kPrintOpen, {}, LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintClose, {}, + b, moduleOp, kPrintClose, {}, LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintComma, {}, + b, moduleOp, kPrintComma, {}, LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kPrintNewline, {}, + b, moduleOp, kPrintNewline, {}, LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) { - return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType, +mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, + Type indexType) { + return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType, getVoidPtr(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) { - return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc, +mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, + Type indexType) { + return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc, {indexType, indexType}, getVoidPtr(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kFree, getVoidPtr(moduleOp->getContext()), + b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()), LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) { - return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType, +mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, + Type indexType) { + return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType, getVoidPtr(moduleOp->getContext())); } -FailureOr -mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, - Type indexType) { - return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc, +FailureOr mlir::LLVM::lookupOrCreateGenericAlignedAllocFn( + OpBuilder &b, Operation *moduleOp, Type indexType) { + return lookupOrCreateReservedFn(b, moduleOp, kGenericAlignedAlloc, {indexType, indexType}, getVoidPtr(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) { +mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp) { return lookupOrCreateReservedFn( - moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), + b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), LLVM::LLVMVoidType::get(moduleOp->getContext())); } FailureOr -mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, +mlir::LLVM::lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, + Type indexType, Type unrankedDescriptorType) { return lookupOrCreateReservedFn( - moduleOp, kMemRefCopy, + b, moduleOp, kMemRefCopy, ArrayRef{indexType, unrankedDescriptorType, unrankedDescriptorType}, LLVM::LLVMVoidType::get(moduleOp->getContext())); }