[MLIR][mesh] moving shardinginterfaceimpl for tensor to tensor extension lib (#104913)
Follow-up to #102598 : as discussed, move tensor sharding implementation into separate tensor extension lib. @sogartar @yaochengji, could you take a look at this PR?
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
|
||||
//===- MeshShardingExtensions.h - -----------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
||||
30
mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h
Normal file
30
mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h
Normal file
@@ -0,0 +1,30 @@
|
||||
//===- AllExtensions.h - All Tensor Extensions ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines a common entry point for registering all extensions to the
|
||||
// Tensor dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
|
||||
#define MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace tensor {
|
||||
/// Register all extensions of the Tensor dialect. This should generally only be
|
||||
/// used by tools, or other use cases that really do want *all* extensions of
|
||||
/// the dialect. All other cases should prefer to instead register the specific
|
||||
/// extensions they intend to take advantage of.
|
||||
void registerAllExtensions(DialectRegistry ®istry);
|
||||
} // namespace tensor
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
|
||||
@@ -0,0 +1,23 @@
|
||||
//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
|
||||
#define MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
|
||||
namespace tensor {
|
||||
|
||||
void registerShardingInterfaceExternalModels(DialectRegistry ®istry);
|
||||
|
||||
} // namespace tensor
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
|
||||
@@ -58,7 +58,6 @@
|
||||
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
|
||||
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
|
||||
#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
|
||||
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
@@ -182,7 +181,6 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
||||
tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
|
||||
tensor::registerInferTypeOpInterfaceExternalModels(registry);
|
||||
tensor::registerShardingInterfaceExternalModels(registry);
|
||||
tensor::registerSubsetOpInterfaceExternalModels(registry);
|
||||
tensor::registerTilingInterfaceExternalModels(registry);
|
||||
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
|
||||
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
|
||||
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
|
||||
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
|
||||
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
|
||||
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
|
||||
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
|
||||
@@ -60,6 +61,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
|
||||
registerConvertComplexToLLVMInterface(registry);
|
||||
cf::registerConvertControlFlowToLLVMInterface(registry);
|
||||
func::registerAllExtensions(registry);
|
||||
tensor::registerAllExtensions(registry);
|
||||
registerConvertFuncToLLVMInterface(registry);
|
||||
index::registerConvertIndexToLLVMInterface(registry);
|
||||
registerConvertMathToLLVMInterface(registry);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
add_mlir_library(MLIRShardingInterface
|
||||
ShardingInterface.cpp
|
||||
TensorShardingInterfaceImpl.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(Extensions)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(TransformOps)
|
||||
|
||||
16
mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
Normal file
16
mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
//===- AllExtensions.cpp - All Tensor Dialect Extensions ------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
|
||||
#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
void mlir::tensor::registerAllExtensions(DialectRegistry ®istry) {
|
||||
registerShardingInterfaceExternalModels(registry);
|
||||
}
|
||||
26
mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
Normal file
26
mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
set(LLVM_OPTIONAL_SOURCES
|
||||
AllExtensions.cpp
|
||||
MeshShardingExtensions.cpp
|
||||
)
|
||||
|
||||
add_mlir_extension_library(MLIRTensorMeshShardingExtensions
|
||||
MeshShardingExtensions.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRTensorDialect
|
||||
MLIRIR
|
||||
MLIRShardingInterface
|
||||
)
|
||||
|
||||
add_mlir_extension_library(MLIRTensorAllExtensions
|
||||
AllExtensions.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRTensorMeshShardingExtensions
|
||||
)
|
||||
@@ -6,9 +6,9 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
|
||||
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
@@ -47,6 +47,7 @@ set(LIBS
|
||||
MLIRLspServerLib
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
MLIRTensorAllExtensions
|
||||
MLIRTransforms
|
||||
MLIRTransformUtils
|
||||
MLIRSupport
|
||||
|
||||
Reference in New Issue
Block a user