[mlir][bufferization] Add tensor-like and buffer-like interfaces (#134220)

Current one-shot bufferization infrastructure operates on top of
TensorType and BaseMemRefType. These are non-extensible base classes of
the respective builtins: tensor and memref. Thus, the infrastructure is
bound to work only with builtin tensor/memref types. At the same time,
there are customization points that allow one to provide custom logic to
control the bufferization behavior.

This patch introduces new type interfaces: tensor-like and buffer-like
that aim to supersede TensorType/BaseMemRefType within the bufferization
dialect and allow custom tensors / memrefs to be used. Additionally,
these new type interfaces are attached to the respective builtin types
so that the switch is seamless.

Note that this patch does very minimal initial work, it does NOT
refactor bufferization infrastructure.

See https://discourse.llvm.org/t/rfc-changing-base-types-for-tensors-and-memrefs-from-c-base-classes-to-type-interfaces/85509
This commit is contained in:
Andrei Golubev
2025-04-15 10:38:49 +01:00
committed by GitHub
parent 96e3876611
commit 00eaff3e9c
13 changed files with 290 additions and 5 deletions

View File

@@ -0,0 +1,18 @@
//===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
//===----------------------------------------------------------------------===//
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_

View File

@@ -0,0 +1,42 @@
//===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for type interfaces used in Bufferization.
//
//===----------------------------------------------------------------------===//
#ifndef BUFFERIZATION_TYPE_INTERFACES
#define BUFFERIZATION_TYPE_INTERFACES
include "mlir/IR/OpBase.td"
def Bufferization_TensorLikeTypeInterface
: TypeInterface<"TensorLikeType"> {
let cppNamespace = "::mlir::bufferization";
let description = [{
Indicates that this type is a tensor type (similarly to a MLIR builtin
tensor) for bufferization purposes.
The interface currently has no methods as it is used by types to opt into
being supported by the bufferization procedures.
}];
}
def Bufferization_BufferLikeTypeInterface
: TypeInterface<"BufferLikeType"> {
let cppNamespace = "::mlir::bufferization";
let description = [{
Indicates that this type is a buffer type (similarly to a MLIR builtin
memref) for bufferization purposes.
The interface currently has no methods as it is used by types to opt into
being supported by the bufferization procedures.
}];
}
#endif // BUFFERIZATION_TYPE_INTERFACES

View File

@@ -10,3 +10,9 @@ mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td)
mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen)
add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen)

View File

@@ -471,6 +471,10 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> {
Statistic<"numTensorOutOfPlace", "num-tensor-out-of-place",
"Number of out-of-place tensor OpOperands">,
];
let dependentDialects = [
"bufferization::BufferizationDialect", "memref::MemRefDialect"
];
}
def PromoteBuffersToStackPass

View File

@@ -9,8 +9,10 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -51,6 +53,16 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface {
return true;
}
};
template <typename Tensor>
struct BuiltinTensorExternalModel
: TensorLikeType::ExternalModel<BuiltinTensorExternalModel<Tensor>,
Tensor> {};
template <typename MemRef>
struct BuiltinMemRefExternalModel
: BufferLikeType::ExternalModel<BuiltinMemRefExternalModel<MemRef>,
MemRef> {};
} // namespace
//===----------------------------------------------------------------------===//
@@ -63,6 +75,20 @@ void mlir::bufferization::BufferizationDialect::initialize() {
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
>();
addInterfaces<BufferizationInlinerInterface>();
// Note: Unlike with other external models, declaring bufferization's
// "promised interfaces" in builtins for TensorLike and BufferLike type
// interfaces is not possible (due to builtins being independent of
// bufferization). Thus, the compromise is to attach these interfaces directly
// during dialect initialization.
RankedTensorType::attachInterface<
BuiltinTensorExternalModel<RankedTensorType>>(*getContext());
UnrankedTensorType::attachInterface<
BuiltinTensorExternalModel<UnrankedTensorType>>(*getContext());
MemRefType::attachInterface<BuiltinMemRefExternalModel<MemRefType>>(
*getContext());
UnrankedMemRefType::attachInterface<
BuiltinMemRefExternalModel<UnrankedMemRefType>>(*getContext());
}
LogicalResult BufferizationDialect::verifyRegionArgAttribute(

View File

@@ -57,11 +57,6 @@ struct OneShotBufferizePass
OneShotBufferizePass> {
using Base::Base;
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
}
void runOnOperation() override {
OneShotBufferizationOptions opt;
if (!options) {

View File

@@ -0,0 +1,37 @@
// RUN: mlir-opt %s -test-tensorlike-bufferlike -split-input-file | FileCheck %s
// CHECK: func.func @builtin_unranked
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}}
func.func @builtin_unranked(%t: tensor<*xf32>) -> (memref<*xf32>)
{
%0 = bufferization.to_memref %t : tensor<*xf32> to memref<*xf32>
return %0 : memref<*xf32>
}
// -----
// CHECK: func.func @builtin_ranked
// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_buffer_like"}}
func.func @builtin_ranked(%t: tensor<42xf32>) -> (memref<42xf32>)
{
%0 = bufferization.to_memref %t : tensor<42xf32> to memref<42xf32>
return %0 : memref<42xf32>
}
// -----
// CHECK: func.func @custom_tensor
// CHECK-SAME: {found = {operand_0 = "is_tensor_like"}}
func.func @custom_tensor(%t: !test.test_tensor<[42], f32>) -> ()
{
return
}
// -----
// CHECK: func.func @custom_memref
// CHECK-SAME: {found = {operand_0 = "is_buffer_like"}}
func.func @custom_memref(%t: !test.test_memref<[42], f32>) -> ()
{
return
}

View File

@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRBufferizationTestPasses
TestTensorCopyInsertion.cpp
TestTensorLikeAndBufferLike.cpp
EXCLUDE_FROM_LIBMLIR
)
@@ -9,4 +10,11 @@ mlir_target_link_libraries(MLIRBufferizationTestPasses PUBLIC
MLIRBufferizationTransforms
MLIRIR
MLIRPass
MLIRTestDialect
)
target_include_directories(MLIRBufferizationTestPasses
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
)

View File

@@ -0,0 +1,99 @@
//===- TestTensorLikeAndBufferLike.cpp - Bufferization Test -----*- c++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"
#include <string>
using namespace mlir;
namespace {
std::string getImplementationStatus(Type type) {
if (isa<bufferization::TensorLikeType>(type)) {
return "is_tensor_like";
}
if (isa<bufferization::BufferLikeType>(type)) {
return "is_buffer_like";
}
return {};
}
DictionaryAttr findAllImplementeesOfTensorOrBufferLike(func::FuncOp funcOp) {
llvm::SmallVector<NamedAttribute> attributes;
const auto funcType = funcOp.getFunctionType();
for (auto [index, inputType] : llvm::enumerate(funcType.getInputs())) {
const auto status = getImplementationStatus(inputType);
if (status.empty()) {
continue;
}
attributes.push_back(
NamedAttribute(StringAttr::get(funcOp.getContext(),
"operand_" + std::to_string(index)),
StringAttr::get(funcOp.getContext(), status)));
}
for (auto [index, resultType] : llvm::enumerate(funcType.getResults())) {
const auto status = getImplementationStatus(resultType);
if (status.empty()) {
continue;
}
attributes.push_back(NamedAttribute(
StringAttr::get(funcOp.getContext(), "result_" + std::to_string(index)),
StringAttr::get(funcOp.getContext(), status)));
}
return mlir::DictionaryAttr::get(funcOp.getContext(), attributes);
}
/// This pass tests whether specified types implement TensorLike and (or)
/// BufferLike type interfaces defined in bufferization.
///
/// The pass analyses operation signature. When the aforementioned interface
/// implementation found, an attribute is added to the operation, signifying the
/// associated operand / result.
struct TestTensorLikeAndBufferLikePass
: public PassWrapper<TestTensorLikeAndBufferLikePass,
OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndBufferLikePass)
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, test::TestDialect>();
}
StringRef getArgument() const final { return "test-tensorlike-bufferlike"; }
StringRef getDescription() const final {
return "Module pass to test custom types that implement TensorLike / "
"BufferLike interfaces";
}
void runOnOperation() override {
auto op = getOperation();
op.walk([](func::FuncOp funcOp) {
const auto dict = findAllImplementeesOfTensorOrBufferLike(funcOp);
if (!dict.empty()) {
funcOp->setAttr("found", dict);
}
});
}
};
} // namespace
namespace mlir::test {
void registerTestTensorLikeAndBufferLikePass() {
PassRegistration<TestTensorLikeAndBufferLikePass>();
}
} // namespace mlir::test

View File

@@ -93,6 +93,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRTransformUtils
MLIRTransforms
MLIRValueBoundsOpInterface
MLIRBufferizationDialect
)
add_mlir_translation_library(MLIRTestFromLLVMIRTranslation

View File

@@ -19,6 +19,7 @@ include "TestAttrDefs.td"
include "TestInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
// All of the types will extend this class.
class Test_Type<string name, list<Trait> traits = []>
@@ -403,4 +404,49 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
let mnemonic = "op_asm_type_interface";
}
def TestTensorType : Test_Type<"TestTensor",
[Bufferization_TensorLikeTypeInterface, ShapedTypeInterface]> {
let mnemonic = "test_tensor";
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"mlir::Type":$elementType
);
let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";
let extraClassDeclaration = [{
// ShapedTypeInterface:
bool hasRank() const {
return true;
}
test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
mlir::Type elementType) const {
return test::TestTensorType::get(
getContext(), shape.value_or(getShape()), elementType);
}
}];
}
def TestMemrefType : Test_Type<"TestMemref",
[Bufferization_BufferLikeTypeInterface, ShapedTypeInterface]> {
let mnemonic = "test_memref";
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"mlir::Type":$elementType,
DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace
);
let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`";
let extraClassDeclaration = [{
// ShapedTypeInterface:
bool hasRank() const {
return true;
}
test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
mlir::Type elementType) const {
return test::TestMemrefType::get(
getContext(), shape.value_or(getShape()), elementType, getMemSpace());
}
}];
}
#endif // TEST_TYPEDEFS

View File

@@ -18,6 +18,7 @@
#include <tuple>
#include "TestTraits.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"

View File

@@ -150,6 +150,7 @@ void registerTestSPIRVCPURunnerPipeline();
void registerTestSPIRVFuncSignatureConversion();
void registerTestSPIRVVectorUnrolling();
void registerTestTensorCopyInsertionPass();
void registerTestTensorLikeAndBufferLikePass();
void registerTestTensorTransforms();
void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
@@ -293,6 +294,7 @@ void registerTestPasses() {
mlir::test::registerTestSPIRVFuncSignatureConversion();
mlir::test::registerTestSPIRVVectorUnrolling();
mlir::test::registerTestTensorCopyInsertionPass();
mlir::test::registerTestTensorLikeAndBufferLikePass();
mlir::test::registerTestTensorTransforms();
mlir::test::registerTestTopologicalSortAnalysisPass();
mlir::test::registerTestTransformDialectEraseSchedulePass();