//===- Distibution.cpp - linalg named ops to generic ops --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the Linalg distibution pass. It updates `tiled_loop` // control variables depending on the distribution type. // //===----------------------------------------------------------------------===// // #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #define DEBUG_TYPE "linalg-distribution" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") using namespace mlir; using namespace mlir::linalg; namespace { struct DistributeTiledLoopPattern : public OpRewritePattern { DistributeTiledLoopPattern(MLIRContext *context, LinalgLoopDistributionOptions options, LinalgTransformationFilter marker) : OpRewritePattern(context), options(options), marker(marker) {} LogicalResult matchAndRewrite(linalg::TiledLoopOp op, PatternRewriter &rewriter) const override { if (failed(marker.checkAndNotify(rewriter, op))) return failure(); if (!op.distribution_types().hasValue()) return failure(); Location loc = op.getLoc(); SmallVector newLowerBounds = op.lowerBound(); SmallVector newUpperBounds = op.upperBound(); SmallVector newSteps = op.step(); // Update bounds and steps. auto distributionTypes = op.distribution_types().getValue(); for (int i = 0, e = op.getNumLoops(); i < e; ++i) { StringRef type = distributionTypes[i].cast().getValue(); auto procInfoCallback = options.procInfoMap.find(type); if (procInfoCallback == options.procInfoMap.end()) continue; if (!isParallelIterator(op.iterator_types()[i])) { op.emitOpError("only support for parallel loops is implemented"); return failure(); } ProcInfo info = procInfoCallback->second(rewriter, loc); updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs, newLowerBounds[i], newUpperBounds[i], newSteps[i]); } rewriter.updateRootInPlace(op, [&] { op.setLowerBounds(newLowerBounds); op.setUpperBounds(newUpperBounds); op.setSteps(newSteps); }); marker.replaceLinalgTransformationFilter(rewriter, op); return success(); } private: LinalgLoopDistributionOptions options; LinalgTransformationFilter marker; }; } // namespace void mlir::linalg::populateLinalgDistributeTiledLoopPattern( RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts, const LinalgTransformationFilter &marker) { patterns.add(patterns.getContext(), opts, marker); }