[mlir][sparse] Move a few routines to CodegenUtils.

Move a few supporting routines for generating function calls to CodegenUtils so
that they can be used by the codegen path for sparse tensor file input and
output.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135691
This commit is contained in:
bixia1
2022-10-11 13:22:14 -07:00
parent 51db96ad2b
commit 2d252a0f5c
3 changed files with 56 additions and 46 deletions

View File

@@ -20,8 +20,6 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -36,20 +34,10 @@ using namespace mlir::sparse_tensor;
namespace {
/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool { Off = false, On = true };
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(OpBuilder &builder) {
return LLVM::LLVMPointerType::get(builder.getI8Type());
}
/// Maps each sparse tensor type to an opaque pointer.
static Optional<Type> convertSparseTensorTypes(Type type) {
if (getSparseTensorEncoding(type) != nullptr)
@@ -57,40 +45,6 @@ static Optional<Type> convertSparseTensorTypes(Type type) {
return llvm::None;
}
/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
/// of ABI complications with passing in and returning MemRefs to C functions.
static FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name,
TypeRange resultType, ValueRange operands,
EmitCInterface emitCInterface) {
MLIRContext *context = module.getContext();
auto result = SymbolRefAttr::get(context, name);
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
if (!func) {
OpBuilder moduleBuilder(module.getBodyRegion());
func = moduleBuilder.create<func::FuncOp>(
module.getLoc(), name,
FunctionType::get(context, operands.getTypes(), resultType));
func.setPrivate();
if (static_cast<bool>(emitCInterface))
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
UnitAttr::get(context));
}
return result;
}
/// Creates a `CallOp` to the function reference returned by `getFunc()` in
/// the builder's module.
static func::CallOp createFuncCall(OpBuilder &builder, Location loc,
StringRef name, TypeRange resultType,
ValueRange operands,
EmitCInterface emitCInterface) {
auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
auto fn = getFunc(module, name, resultType, operands, emitCInterface);
return builder.create<func::CallOp>(loc, resultType, fn, operands);
}
/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,