[MLIR][mesh] moving shardinginterfaceimpl for tensor to tensor extension lib (#104913)

Follow-up to #102598 : as discussed, move tensor sharding implementation
into separate tensor extension lib.

@sogartar @yaochengji, could you take a look at this PR?
This commit is contained in:
Frank Schlimbach
2024-08-21 12:59:44 +02:00
committed by GitHub
parent 9d364286f3
commit 681ae09722
12 changed files with 101 additions and 5 deletions

View File

@@ -1,4 +1,4 @@
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
//===- MeshShardingExtensions.h - -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@@ -0,0 +1,30 @@
//===- AllExtensions.h - All Tensor Extensions ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a common entry point for registering all extensions to the
// Tensor dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
#define MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
namespace mlir {
class DialectRegistry;
namespace tensor {
/// Register all extensions of the Tensor dialect. This should generally only be
/// used by tools, or other use cases that really do want *all* extensions of
/// the dialect. All other cases should prefer to instead register the specific
/// extensions they intend to take advantage of.
void registerAllExtensions(DialectRegistry &registry);
} // namespace tensor
} // namespace mlir
#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H

View File

@@ -0,0 +1,23 @@
//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
#define MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
namespace mlir {
class DialectRegistry;
namespace tensor {
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
} // namespace tensor
} // namespace mlir
#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_

View File

@@ -58,7 +58,6 @@
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -182,7 +181,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
tensor::registerShardingInterfaceExternalModels(registry);
tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);

View File

@@ -34,6 +34,7 @@
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
@@ -60,6 +61,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertComplexToLLVMInterface(registry);
cf::registerConvertControlFlowToLLVMInterface(registry);
func::registerAllExtensions(registry);
tensor::registerAllExtensions(registry);
registerConvertFuncToLLVMInterface(registry);
index::registerConvertIndexToLLVMInterface(registry);
registerConvertMathToLLVMInterface(registry);

View File

@@ -1,6 +1,5 @@
add_mlir_library(MLIRShardingInterface
ShardingInterface.cpp
TensorShardingInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh

View File

@@ -1,3 +1,4 @@
add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)

View File

@@ -0,0 +1,16 @@
//===- AllExtensions.cpp - All Tensor Dialect Extensions ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
using namespace mlir;
void mlir::tensor::registerAllExtensions(DialectRegistry &registry) {
registerShardingInterfaceExternalModels(registry);
}

View File

@@ -0,0 +1,26 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
MeshShardingExtensions.cpp
)
add_mlir_extension_library(MLIRTensorMeshShardingExtensions
MeshShardingExtensions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
LINK_LIBS PUBLIC
MLIRTensorDialect
MLIRIR
MLIRShardingInterface
)
add_mlir_extension_library(MLIRTensorAllExtensions
AllExtensions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
LINK_LIBS PUBLIC
MLIRTensorMeshShardingExtensions
)

View File

@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

View File

@@ -47,6 +47,7 @@ set(LIBS
MLIRLspServerLib
MLIRParser
MLIRPass
MLIRTensorAllExtensions
MLIRTransforms
MLIRTransformUtils
MLIRSupport