[mlir][Conversion] Add type converter parameter to ConvertToLLVMPatternInterface
Most `*-to-llvm` conversion patterns require a type converter. This revision adds a type converter to the `populateConvertToLLVMConversionPatterns` function and implements the interface for the MemRef dialect. Differential Revision: https://reviews.llvm.org/D157387
This commit is contained in:
@@ -40,13 +40,15 @@ public:
|
||||
/// Hook for derived dialect interface to provide conversion patterns
|
||||
/// and mark dialect legal for the conversion target.
|
||||
virtual void populateConvertToLLVMConversionPatterns(
|
||||
ConversionTarget &target, RewritePatternSet &patterns) const = 0;
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) const = 0;
|
||||
};
|
||||
|
||||
/// Recursively walk the IR and collect all dialects implementing the interface,
|
||||
/// and populate the conversion patterns.
|
||||
void populateConversionTargetFromOperation(Operation *op,
|
||||
ConversionTarget &target,
|
||||
LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
class Pass;
|
||||
class LLVMTypeConverter;
|
||||
class RewritePatternSet;
|
||||
@@ -23,6 +24,9 @@ class RewritePatternSet;
|
||||
/// MemRef dialect to the LLVM dialect.
|
||||
void populateFinalizeMemRefToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns);
|
||||
|
||||
void registerConvertMemRefToLLVMInterface(DialectRegistry ®istry);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_MEMREFTOLLVM_MEMREFTOLLVM_H
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#ifndef MLIR_INITALLEXTENSIONS_H_
|
||||
#define MLIR_INITALLEXTENSIONS_H_
|
||||
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
|
||||
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
|
||||
#include "mlir/Target/LLVM/NVVM/Target.h"
|
||||
@@ -29,6 +30,7 @@ namespace mlir {
|
||||
/// pipelines and transformations you are using.
|
||||
inline void registerAllExtensions(DialectRegistry ®istry) {
|
||||
func::registerAllExtensions(registry);
|
||||
registerConvertMemRefToLLVMInterface(registry);
|
||||
registerConvertNVVMToLLVMInterface(registry);
|
||||
registerNVVMTarget(registry);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMPass
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRConvertToLLVMInterface
|
||||
MLIRIR
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMDialect
|
||||
MLIRPass
|
||||
MLIRRewrite
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -62,6 +63,7 @@ class ConvertToLLVMPass
|
||||
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
std::shared_ptr<const ConversionTarget> target;
|
||||
std::shared_ptr<const LLVMTypeConverter> typeConverter;
|
||||
|
||||
public:
|
||||
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
|
||||
@@ -72,23 +74,26 @@ public:
|
||||
|
||||
ConvertToLLVMPass(const ConvertToLLVMPass &other)
|
||||
: ConvertToLLVMPassBase(other), patterns(other.patterns),
|
||||
target(other.target) {}
|
||||
target(other.target), typeConverter(other.typeConverter) {}
|
||||
|
||||
LogicalResult initialize(MLIRContext *context) final {
|
||||
RewritePatternSet tempPatterns(context);
|
||||
auto target = std::make_shared<ConversionTarget>(*context);
|
||||
target->addLegalDialect<LLVM::LLVMDialect>();
|
||||
auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
|
||||
for (Dialect *dialect : context->getLoadedDialects()) {
|
||||
// First time we encounter this dialect: if it implements the interface,
|
||||
// let's populate patterns !
|
||||
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
|
||||
if (!iface)
|
||||
continue;
|
||||
iface->populateConvertToLLVMConversionPatterns(*target, tempPatterns);
|
||||
iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
|
||||
tempPatterns);
|
||||
}
|
||||
patterns =
|
||||
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
|
||||
this->target = target;
|
||||
this->typeConverter = typeConverter;
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
void mlir::populateConversionTargetFromOperation(Operation *root,
|
||||
ConversionTarget &target,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::populateConversionTargetFromOperation(
|
||||
Operation *root, ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
DenseSet<Dialect *> dialects;
|
||||
root->walk([&](Operation *op) {
|
||||
Dialect *dialect = op->getDialect();
|
||||
@@ -26,6 +26,7 @@ void mlir::populateConversionTargetFromOperation(Operation *root,
|
||||
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
|
||||
if (!iface)
|
||||
return;
|
||||
iface->populateConvertToLLVMConversionPatterns(target, patterns);
|
||||
iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
|
||||
patterns);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
|
||||
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
@@ -1935,4 +1936,27 @@ struct FinalizeMemRefToLLVMConversionPass
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
/// Implement the interface to convert MemRef to LLVM.
|
||||
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
||||
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
||||
void loadDependentDialects(MLIRContext *context) const final {
|
||||
context->loadDialect<LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
/// Hook for derived dialect interface to provide conversion patterns
|
||||
/// and mark dialect legal for the conversion target.
|
||||
void populateConvertToLLVMConversionPatterns(
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) const final {
|
||||
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
|
||||
dialect->addInterfaces<MemRefToLLVMDialectInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -201,7 +201,8 @@ struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
||||
/// Hook for derived dialect interface to provide conversion patterns
|
||||
/// and mark dialect legal for the conversion target.
|
||||
void populateConvertToLLVMConversionPatterns(
|
||||
ConversionTarget &target, RewritePatternSet &patterns) const final {
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) const final {
|
||||
populateNVVMToLLVMConversionPatterns(patterns);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
// RUN: mlir-opt -finalize-memref-to-llvm='use-opaque-pointers=1' %s -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt -finalize-memref-to-llvm='index-bitwidth=32 use-opaque-pointers=1' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
|
||||
|
||||
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
|
||||
// and the generic `convert-to-llvm` pass. This produces slightly different IR
|
||||
// because the conversion target is set up differently. Only one test case is
|
||||
// checked.
|
||||
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck --check-prefix=CHECK-INTERFACE %s
|
||||
|
||||
// CHECK-LABEL: func @view(
|
||||
// CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index
|
||||
func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
@@ -88,6 +94,10 @@ func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
// CHECK-LABEL: func @view_empty_memref(
|
||||
// CHECK: %[[ARG0:.*]]: index,
|
||||
// CHECK: %[[ARG1:.*]]: memref<0xi8>)
|
||||
|
||||
// CHECK-INTERFACE-LABEL: func @view_empty_memref(
|
||||
// CHECK-INTERFACE: %[[ARG0:.*]]: index,
|
||||
// CHECK-INTERFACE: %[[ARG1:.*]]: memref<0xi8>)
|
||||
func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
|
||||
|
||||
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
@@ -101,6 +111,18 @@ func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.mlir.constant(4 : index) : i64
|
||||
// CHECK: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
|
||||
// CHECK-INTERFACE: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK-INTERFACE: llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK-INTERFACE: llvm.mlir.constant(4 : index) : i64
|
||||
// CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK-INTERFACE: llvm.mlir.constant(1 : index) : i64
|
||||
// CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK-INTERFACE: llvm.mlir.constant(0 : index) : i64
|
||||
// CHECK-INTERFACE: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK-INTERFACE: llvm.mlir.constant(4 : index) : i64
|
||||
// CHECK-INTERFACE: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||
%0 = memref.view %mem[%offset][] : memref<0xi8> to memref<0x4xf32>
|
||||
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user