[mlir][sparse] Move a few routines to CodegenUtils.
Move a few supporting routines for generating function calls to CodegenUtils so that they can be used by the codegen path for sparse tensor file input and output. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135691
This commit is contained in:
@@ -20,8 +20,6 @@
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
@@ -36,20 +34,10 @@ using namespace mlir::sparse_tensor;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
|
||||
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
|
||||
enum class EmitCInterface : bool { Off = false, On = true };
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the equivalent of `void*` for opaque arguments to the
|
||||
/// execution engine.
|
||||
static Type getOpaquePointerType(OpBuilder &builder) {
|
||||
return LLVM::LLVMPointerType::get(builder.getI8Type());
|
||||
}
|
||||
|
||||
/// Maps each sparse tensor type to an opaque pointer.
|
||||
static Optional<Type> convertSparseTensorTypes(Type type) {
|
||||
if (getSparseTensorEncoding(type) != nullptr)
|
||||
@@ -57,40 +45,6 @@ static Optional<Type> convertSparseTensorTypes(Type type) {
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// Returns a function reference (first hit also inserts into module). Sets
|
||||
/// the "_emit_c_interface" on the function declaration when requested,
|
||||
/// so that LLVM lowering generates a wrapper function that takes care
|
||||
/// of ABI complications with passing in and returning MemRefs to C functions.
|
||||
static FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name,
|
||||
TypeRange resultType, ValueRange operands,
|
||||
EmitCInterface emitCInterface) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto result = SymbolRefAttr::get(context, name);
|
||||
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
|
||||
if (!func) {
|
||||
OpBuilder moduleBuilder(module.getBodyRegion());
|
||||
func = moduleBuilder.create<func::FuncOp>(
|
||||
module.getLoc(), name,
|
||||
FunctionType::get(context, operands.getTypes(), resultType));
|
||||
func.setPrivate();
|
||||
if (static_cast<bool>(emitCInterface))
|
||||
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
|
||||
UnitAttr::get(context));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Creates a `CallOp` to the function reference returned by `getFunc()` in
|
||||
/// the builder's module.
|
||||
static func::CallOp createFuncCall(OpBuilder &builder, Location loc,
|
||||
StringRef name, TypeRange resultType,
|
||||
ValueRange operands,
|
||||
EmitCInterface emitCInterface) {
|
||||
auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
auto fn = getFunc(module, name, resultType, operands, emitCInterface);
|
||||
return builder.create<func::CallOp>(loc, resultType, fn, operands);
|
||||
}
|
||||
|
||||
/// Replaces the `op` with a `CallOp` to the function reference returned
|
||||
/// by `getFunc()`.
|
||||
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
|
||||
|
||||
Reference in New Issue
Block a user