From 00eaff3e9c897c263a879416d0f151d7ca7eeaff Mon Sep 17 00:00:00 2001 From: Andrei Golubev Date: Tue, 15 Apr 2025 10:38:49 +0100 Subject: [PATCH] [mlir][bufferization] Add tensor-like and buffer-like interfaces (#134220) Current one-shot bufferization infrastructure operates on top of TensorType and BaseMemRefType. These are non-extensible base classes of the respective builtins: tensor and memref. Thus, the infrastructure is bound to work only with builtin tensor/memref types. At the same time, there are customization points that allow one to provide custom logic to control the bufferization behavior. This patch introduces new type interfaces: tensor-like and buffer-like that aim to supersede TensorType/BaseMemRefType within the bufferization dialect and allow custom tensors / memrefs to be used. Additionally, these new type interfaces are attached to the respective builtin types so that the switch is seamless. Note that this patch does very minimal initial work, it does NOT refactor bufferization infrastructure. See https://discourse.llvm.org/t/rfc-changing-base-types-for-tensors-and-memrefs-from-c-base-classes-to-type-interfaces/85509 --- .../IR/BufferizationTypeInterfaces.h | 18 ++++ .../IR/BufferizationTypeInterfaces.td | 42 ++++++++ .../Dialect/Bufferization/IR/CMakeLists.txt | 6 ++ .../Bufferization/Transforms/Passes.td | 4 + .../Bufferization/IR/BufferizationDialect.cpp | 26 +++++ .../Bufferization/Transforms/Bufferize.cpp | 5 - .../Transforms/tensorlike-bufferlike.mlir | 37 +++++++ .../lib/Dialect/Bufferization/CMakeLists.txt | 8 ++ .../TestTensorLikeAndBufferLike.cpp | 99 +++++++++++++++++++ mlir/test/lib/Dialect/Test/CMakeLists.txt | 1 + mlir/test/lib/Dialect/Test/TestTypeDefs.td | 46 +++++++++ mlir/test/lib/Dialect/Test/TestTypes.h | 1 + mlir/tools/mlir-opt/mlir-opt.cpp | 2 + 13 files changed, 290 insertions(+), 5 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td create mode 100644 mlir/test/Dialect/Bufferization/Transforms/tensorlike-bufferlike.mlir create mode 100644 mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h new file mode 100644 index 000000000000..f6b296eccd74 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h @@ -0,0 +1,18 @@ +//===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ + +//===----------------------------------------------------------------------===// +// Bufferization Type Interfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc" + +#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td new file mode 100644 index 000000000000..f19224a29564 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td @@ -0,0 +1,42 @@ +//===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for type interfaces used in Bufferization. +// +//===----------------------------------------------------------------------===// + +#ifndef BUFFERIZATION_TYPE_INTERFACES +#define BUFFERIZATION_TYPE_INTERFACES + +include "mlir/IR/OpBase.td" + +def Bufferization_TensorLikeTypeInterface + : TypeInterface<"TensorLikeType"> { + let cppNamespace = "::mlir::bufferization"; + let description = [{ + Indicates that this type is a tensor type (similarly to a MLIR builtin + tensor) for bufferization purposes. + + The interface currently has no methods as it is used by types to opt into + being supported by the bufferization procedures. + }]; +} + +def Bufferization_BufferLikeTypeInterface + : TypeInterface<"BufferLikeType"> { + let cppNamespace = "::mlir::bufferization"; + let description = [{ + Indicates that this type is a buffer type (similarly to a MLIR builtin + memref) for bufferization purposes. + + The interface currently has no methods as it is used by types to opt into + being supported by the bufferization procedures. + }]; +} + +#endif // BUFFERIZATION_TYPE_INTERFACES diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt index 13a5bc370a4f..3ead52148c20 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt @@ -10,3 +10,9 @@ mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls) mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRBufferizationEnumsIncGen) add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td) +mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen) +add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index f53f569070f0..ee33476f441e 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -471,6 +471,10 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> { Statistic<"numTensorOutOfPlace", "num-tensor-out-of-place", "Number of out-of-place tensor OpOperands">, ]; + + let dependentDialects = [ + "bufferization::BufferizationDialect", "memref::MemRefDialect" + ]; } def PromoteBuffersToStackPass diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index e5a0c3c45b09..6b9253a5d71d 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -9,8 +9,10 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Transforms/InliningUtils.h" @@ -51,6 +53,16 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface { return true; } }; + +template +struct BuiltinTensorExternalModel + : TensorLikeType::ExternalModel, + Tensor> {}; + +template +struct BuiltinMemRefExternalModel + : BufferLikeType::ExternalModel, + MemRef> {}; } // namespace //===----------------------------------------------------------------------===// @@ -63,6 +75,20 @@ void mlir::bufferization::BufferizationDialect::initialize() { #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc" >(); addInterfaces(); + + // Note: Unlike with other external models, declaring bufferization's + // "promised interfaces" in builtins for TensorLike and BufferLike type + // interfaces is not possible (due to builtins being independent of + // bufferization). Thus, the compromise is to attach these interfaces directly + // during dialect initialization. + RankedTensorType::attachInterface< + BuiltinTensorExternalModel>(*getContext()); + UnrankedTensorType::attachInterface< + BuiltinTensorExternalModel>(*getContext()); + MemRefType::attachInterface>( + *getContext()); + UnrankedMemRefType::attachInterface< + BuiltinMemRefExternalModel>(*getContext()); } LogicalResult BufferizationDialect::verifyRegionArgAttribute( diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index e97b34b20ff7..0b60c44ece5f 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -57,11 +57,6 @@ struct OneShotBufferizePass OneShotBufferizePass> { using Base::Base; - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - void runOnOperation() override { OneShotBufferizationOptions opt; if (!options) { diff --git a/mlir/test/Dialect/Bufferization/Transforms/tensorlike-bufferlike.mlir b/mlir/test/Dialect/Bufferization/Transforms/tensorlike-bufferlike.mlir new file mode 100644 index 000000000000..f8691e110aad --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/tensorlike-bufferlike.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt %s -test-tensorlike-bufferlike -split-input-file | FileCheck %s + +// CHECK: func.func @builtin_unranked +// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}} +func.func @builtin_unranked(%t: tensor<*xf32>) -> (memref<*xf32>) +{ + %0 = bufferization.to_memref %t : tensor<*xf32> to memref<*xf32> + return %0 : memref<*xf32> +} + +// ----- + +// CHECK: func.func @builtin_ranked +// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}} +func.func @builtin_ranked(%t: tensor<42xf32>) -> (memref<42xf32>) +{ + %0 = bufferization.to_memref %t : tensor<42xf32> to memref<42xf32> + return %0 : memref<42xf32> +} + +// ----- + +// CHECK: func.func @custom_tensor +// CHECK-SAME: {found = {operand_0 = "is_tensor_like"}} +func.func @custom_tensor(%t: !test.test_tensor<[42], f32>) -> () +{ + return +} + +// ----- + +// CHECK: func.func @custom_memref +// CHECK-SAME: {found = {operand_0 = "is_buffer_like"}} +func.func @custom_memref(%t: !test.test_memref<[42], f32>) -> () +{ + return +} diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt index c14a9f2cc9bb..226e0bb97732 100644 --- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRBufferizationTestPasses TestTensorCopyInsertion.cpp + TestTensorLikeAndBufferLike.cpp EXCLUDE_FROM_LIBMLIR ) @@ -9,4 +10,11 @@ mlir_target_link_libraries(MLIRBufferizationTestPasses PUBLIC MLIRBufferizationTransforms MLIRIR MLIRPass + MLIRTestDialect ) + +target_include_directories(MLIRBufferizationTestPasses + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test + ) diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp new file mode 100644 index 000000000000..60e60849f3e6 --- /dev/null +++ b/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp @@ -0,0 +1,99 @@ +//===- TestTensorLikeAndBufferLike.cpp - Bufferization Test -----*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" + +#include + +using namespace mlir; + +namespace { +std::string getImplementationStatus(Type type) { + if (isa(type)) { + return "is_tensor_like"; + } + if (isa(type)) { + return "is_buffer_like"; + } + return {}; +} + +DictionaryAttr findAllImplementeesOfTensorOrBufferLike(func::FuncOp funcOp) { + llvm::SmallVector attributes; + + const auto funcType = funcOp.getFunctionType(); + for (auto [index, inputType] : llvm::enumerate(funcType.getInputs())) { + const auto status = getImplementationStatus(inputType); + if (status.empty()) { + continue; + } + + attributes.push_back( + NamedAttribute(StringAttr::get(funcOp.getContext(), + "operand_" + std::to_string(index)), + StringAttr::get(funcOp.getContext(), status))); + } + + for (auto [index, resultType] : llvm::enumerate(funcType.getResults())) { + const auto status = getImplementationStatus(resultType); + if (status.empty()) { + continue; + } + + attributes.push_back(NamedAttribute( + StringAttr::get(funcOp.getContext(), "result_" + std::to_string(index)), + StringAttr::get(funcOp.getContext(), status))); + } + + return mlir::DictionaryAttr::get(funcOp.getContext(), attributes); +} + +/// This pass tests whether specified types implement TensorLike and (or) +/// BufferLike type interfaces defined in bufferization. +/// +/// The pass analyses operation signature. When the aforementioned interface +/// implementation found, an attribute is added to the operation, signifying the +/// associated operand / result. +struct TestTensorLikeAndBufferLikePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndBufferLikePass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + StringRef getArgument() const final { return "test-tensorlike-bufferlike"; } + StringRef getDescription() const final { + return "Module pass to test custom types that implement TensorLike / " + "BufferLike interfaces"; + } + + void runOnOperation() override { + auto op = getOperation(); + + op.walk([](func::FuncOp funcOp) { + const auto dict = findAllImplementeesOfTensorOrBufferLike(funcOp); + if (!dict.empty()) { + funcOp->setAttr("found", dict); + } + }); + } +}; +} // namespace + +namespace mlir::test { +void registerTestTensorLikeAndBufferLikePass() { + PassRegistration(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index a48ac24ca056..6e608e477239 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -93,6 +93,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC MLIRTransformUtils MLIRTransforms MLIRValueBoundsOpInterface + MLIRBufferizationDialect ) add_mlir_translation_library(MLIRTestFromLLVMIRTranslation diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index f1c31658c13a..e9785594d333 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -19,6 +19,7 @@ include "TestAttrDefs.td" include "TestInterfaces.td" include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" // All of the types will extend this class. class Test_Type traits = []> @@ -403,4 +404,49 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", let mnemonic = "op_asm_type_interface"; } +def TestTensorType : Test_Type<"TestTensor", + [Bufferization_TensorLikeTypeInterface, ShapedTypeInterface]> { + let mnemonic = "test_tensor"; + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "mlir::Type":$elementType + ); + let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`"; + + let extraClassDeclaration = [{ + // ShapedTypeInterface: + bool hasRank() const { + return true; + } + test::TestTensorType cloneWith(std::optional> shape, + mlir::Type elementType) const { + return test::TestTensorType::get( + getContext(), shape.value_or(getShape()), elementType); + } + }]; +} + +def TestMemrefType : Test_Type<"TestMemref", + [Bufferization_BufferLikeTypeInterface, ShapedTypeInterface]> { + let mnemonic = "test_memref"; + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "mlir::Type":$elementType, + DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace + ); + let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`"; + + let extraClassDeclaration = [{ + // ShapedTypeInterface: + bool hasRank() const { + return true; + } + test::TestMemrefType cloneWith(std::optional> shape, + mlir::Type elementType) const { + return test::TestMemrefType::get( + getContext(), shape.value_or(getShape()), elementType, getMemSpace()); + } + }]; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h index cef3f056a798..6499a96f495d 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -18,6 +18,7 @@ #include #include "TestTraits.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index ca4706e96787..91bbfca88c3c 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -150,6 +150,7 @@ void registerTestSPIRVCPURunnerPipeline(); void registerTestSPIRVFuncSignatureConversion(); void registerTestSPIRVVectorUnrolling(); void registerTestTensorCopyInsertionPass(); +void registerTestTensorLikeAndBufferLikePass(); void registerTestTensorTransforms(); void registerTestTopologicalSortAnalysisPass(); void registerTestTransformDialectEraseSchedulePass(); @@ -293,6 +294,7 @@ void registerTestPasses() { mlir::test::registerTestSPIRVFuncSignatureConversion(); mlir::test::registerTestSPIRVVectorUnrolling(); mlir::test::registerTestTensorCopyInsertionPass(); + mlir::test::registerTestTensorLikeAndBufferLikePass(); mlir::test::registerTestTensorTransforms(); mlir::test::registerTestTopologicalSortAnalysisPass(); mlir::test::registerTestTransformDialectEraseSchedulePass();