//===- Transforms.cpp ---------------------------------------------- C++ --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include #include namespace mlir::mesh { namespace { /// Lower `mesh.process_multi_index` into expression using /// `mesh.process_linear_index` and `mesh.cluster_shape`. struct ProcessMultiIndexOpLowering : OpRewritePattern { template ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs) : OpRewritePattern( std::forward(opRewritePatternArgs)...), symbolTableCollection(symbolTableCollection) {} LogicalResult matchAndRewrite(ProcessMultiIndexOp op, PatternRewriter &rewriter) const override { ClusterOp mesh = symbolTableCollection.lookupNearestSymbolFrom( op.getOperation(), op.getMeshAttr()); if (!mesh) { return failure(); } ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); Value linearIndex = builder.create(mesh); ValueRange meshShape = builder.create(mesh).getResults(); SmallVector completeMultiIndex = builder.create(linearIndex, meshShape) .getMultiIndex(); SmallVector multiIndex; ArrayRef opMeshAxes = op.getAxes(); SmallVector opAxesIota; if (opMeshAxes.empty()) { opAxesIota.resize(mesh.getRank()); std::iota(opAxesIota.begin(), opAxesIota.end(), 0); opMeshAxes = opAxesIota; } llvm::transform(opMeshAxes, std::back_inserter(multiIndex), [&completeMultiIndex](MeshAxis meshAxis) { return completeMultiIndex[meshAxis]; }); rewriter.replaceAllUsesWith(op.getResults(), multiIndex); return success(); } private: SymbolTableCollection &symbolTableCollection; }; } // namespace void processMultiIndexOpLoweringPopulatePatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { patterns.add(symbolTableCollection, patterns.getContext()); } void processMultiIndexOpLoweringRegisterDialects(DialectRegistry ®istry) { registry.insert(); } } // namespace mlir::mesh