diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h index ed3c21d952a0..43d61832cafd 100644 --- a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h @@ -119,6 +119,14 @@ bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef shifts); /// any dependence component is negative along any of `loops`. bool isTilingValid(ArrayRef loops); +/// Returns true if the affine nest rooted at `root` has a cyclic dependence +/// among its affine memory accesses. The dependence could be through any +/// dependences carried by loops contained in `root` (inclusive of `root`) and +/// those carried by loop bodies (blocks) contained. Dependences carried by +/// loops outer to `root` aren't relevant. This method doesn't consider/account +/// for aliases. +bool hasCyclicDependence(AffineForOp root); + } // namespace affine } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp index 0d4b0ea1668e..411e79171b89 100644 --- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp @@ -16,7 +16,7 @@ #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "llvm/Support/MathExtras.h" @@ -28,10 +28,138 @@ #include #include +#define DEBUG_TYPE "affine-loop-analysis" + using namespace mlir; using namespace mlir::affine; -#define DEBUG_TYPE "affine-loop-analysis" +namespace { + +/// A directed graph to model relationships between MLIR Operations. +class DirectedOpGraph { +public: + /// Add a node to the graph. + void addNode(Operation *op) { + assert(!hasNode(op) && "node already added"); + nodes.emplace_back(op); + edges[op] = {}; + } + + /// Add an edge from `src` to `dest`. + void addEdge(Operation *src, Operation *dest) { + // This is a multi-graph. + assert(hasNode(src) && "src node does not exist in graph"); + assert(hasNode(dest) && "dest node does not exist in graph"); + edges[src].push_back(getNode(dest)); + } + + /// Returns true if there is a (directed) cycle in the graph. + bool hasCycle() { return dfs(/*cycleCheck=*/true); } + + void printEdges() { + for (auto &en : edges) { + llvm::dbgs() << *en.first << " (" << en.first << ")" + << " has " << en.second.size() << " edges:\n"; + for (auto *node : en.second) { + llvm::dbgs() << '\t' << *node->op << '\n'; + } + } + } + +private: + /// A node of a directed graph between MLIR Operations to model various + /// relationships. This is meant to be used internally. + struct DGNode { + DGNode(Operation *op) : op(op) {}; + Operation *op; + + // Start and finish visit numbers are standard in DFS to implement things + // like finding strongly connected components. These numbers are modified + // during analyses on the graph and so seemingly const API methods will be + // non-const. + + /// Start visit number. + int vn = -1; + + /// Finish visit number. + int fn = -1; + }; + + /// Get internal node corresponding to `op`. + DGNode *getNode(Operation *op) { + auto *value = + llvm::find_if(nodes, [&](const DGNode &node) { return node.op == op; }); + assert(value != nodes.end() && "node doesn't exist in graph"); + return &*value; + } + + /// Returns true if `key` is in the graph. + bool hasNode(Operation *key) const { + return llvm::find_if(nodes, [&](const DGNode &node) { + return node.op == key; + }) != nodes.end(); + } + + /// Perform a depth-first traversal of the graph setting visited and finished + /// numbers. If `cycleCheck` is set, detects cycles and returns true as soon + /// as the first cycle is detected, and false if there are no cycles. If + /// `cycleCheck` is not set, completes the DFS and the `return` value doesn't + /// have a meaning. + bool dfs(bool cycleCheck = false) { + for (DGNode &node : nodes) { + node.vn = 0; + node.fn = -1; + } + + unsigned time = 0; + for (DGNode &node : nodes) { + if (node.vn == 0) { + bool ret = dfsNode(node, cycleCheck, time); + // Check if a cycle was already found. + if (cycleCheck && ret) + return true; + } else if (cycleCheck && node.fn == -1) { + // We have encountered a node whose visit has started but it's not + // finished. So we have a cycle. + return true; + } + } + return false; + } + + /// Perform depth-first traversal starting at `node`. Return true + /// as soon as a cycle is found if `cycleCheck` was set. Update `time`. + bool dfsNode(DGNode &node, bool cycleCheck, unsigned &time) const { + auto nodeEdges = edges.find(node.op); + assert(nodeEdges != edges.end() && "missing node in graph"); + node.vn = ++time; + + for (auto &neighbour : nodeEdges->second) { + if (neighbour->vn == 0) { + bool ret = dfsNode(*neighbour, cycleCheck, time); + if (cycleCheck && ret) + return true; + } else if (cycleCheck && neighbour->fn == -1) { + // We have encountered a node whose visit has started but it's not + // finished. So we have a cycle. + return true; + } + } + + // Update finish time. + node.fn = ++time; + + return false; + } + + // The list of nodes. The storage is owned by this class. + SmallVector nodes; + + // Edges as an adjacency list. + DenseMap> edges; +}; + +} // namespace /// Returns the trip count of the loop as an affine expression if the latter is /// expressible as an affine expression, and nullptr otherwise. The trip count @@ -447,3 +575,33 @@ bool mlir::affine::isTilingValid(ArrayRef loops) { return true; } + +bool mlir::affine::hasCyclicDependence(AffineForOp root) { + // Collect all the memory accesses in the source nest grouped by their + // immediate parent block. + DirectedOpGraph graph; + SmallVector accesses; + root->walk([&](Operation *op) { + if (isa(op)) { + accesses.emplace_back(op); + graph.addNode(op); + } + }); + + // Construct the dependence graph for all the collected acccesses. + unsigned rootDepth = getNestingDepth(root); + for (const auto &accA : accesses) { + for (const auto &accB : accesses) { + if (accA.memref != accB.memref) + continue; + // Perform the dependence on all surrounding loops + the body. + unsigned numCommonLoops = + getNumCommonSurroundingLoops(*accA.opInst, *accB.opInst); + for (unsigned d = rootDepth + 1; d <= numCommonLoops + 1; ++d) { + if (!noDependence(checkMemrefAccessDependence(accA, accB, d))) + graph.addEdge(accA.opInst, accB.opInst); + } + } + } + return graph.hasCycle(); +} diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 5add7df84928..b97f11a96382 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -274,6 +274,58 @@ getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock, return firstAncestor; } +/// Returns the amount of additional (redundant) computation that will be done +/// as a fraction of the total computation if `srcForOp` is fused into +/// `dstForOp` at depth `depth`. The method returns the compute cost of the +/// slice and the fused nest's compute cost in the trailing output arguments. +static std::optional getAdditionalComputeFraction( + AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth, + ArrayRef depthSliceUnions, int64_t &sliceCost, + int64_t &fusedLoopNestComputeCost) { + LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";); + // Compute cost of sliced and unsliced src loop nest. + // Walk src loop nest and collect stats. + LoopNestStats srcLoopNestStats; + if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n"); + return std::nullopt; + } + + // Compute cost of dst loop nest. + LoopNestStats dstLoopNestStats; + if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n"); + return std::nullopt; + } + + // Compute op instance count for the src loop nest without iteration slicing. + uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats); + + // Compute op cost for the dst loop nest. + uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats); + + const ComputationSliceState &slice = depthSliceUnions[depth - 1]; + // Skip slice union if it wasn't computed for this depth. + if (slice.isEmpty()) { + LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n"); + return std::nullopt; + } + + if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp, + dstLoopNestStats, slice, + &fusedLoopNestComputeCost)) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n"); + return std::nullopt; + } + + double additionalComputeFraction = + fusedLoopNestComputeCost / + (static_cast(srcLoopNestCost) + dstLoopNestCost) - + 1; + + return additionalComputeFraction; +} + // Creates and returns a private (single-user) memref for fused loop rooted at // 'forOp', with (potentially reduced) memref size based on the memref region // written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock' @@ -384,20 +436,19 @@ static Value createPrivateMemRef(AffineForOp forOp, } // Checks the profitability of fusing a backwards slice of the loop nest -// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. -// The argument 'srcStoreOpInst' is used to calculate the storage reduction on -// the memref being produced and consumed, which is an input to the cost model. -// For producer-consumer fusion, 'srcStoreOpInst' will be the same as -// 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse -// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the -// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the -// unique store op in the src node, which will be used to check that the write -// region is the same after input-reuse fusion. Computation slices are provided -// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which -// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is -// profitable to fuse the candidate loop nests. Returns false otherwise. -// `dstLoopDepth` is set to the most profitable depth at which to materialize -// the source loop nest slice. +// `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument +// 'srcStoreOpInst' is used to calculate the storage reduction on the memref +// being produced and consumed, which is an input to the cost model. For +// producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst', +// as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst' +// will be the src loop nest LoadOp which reads from the same memref as dst loop +// nest load ops, and 'srcStoreOpInst' will be the unique store op in the src +// node, which will be used to check that the write region is the same after +// input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for +// each legal fusion depth. The maximal depth at which fusion is legal is +// provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse +// the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to +// the most profitable depth at which to materialize the source loop nest slice. // The profitability model executes the following steps: // *) Computes the backward computation slice at 'srcOpInst'. This // computation slice of the loop nest surrounding 'srcOpInst' is @@ -422,15 +473,16 @@ static Value createPrivateMemRef(AffineForOp forOp, // is lower. // TODO: Extend profitability analysis to support scenarios with multiple // stores. -static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, +static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst, AffineForOp dstForOp, ArrayRef depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold) { LLVM_DEBUG({ - llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; - llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n"; + llvm::dbgs() + << "Checking whether fusion is profitable between source nest:\n"; + llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n"; llvm::dbgs() << dstForOp << "\n"; }); @@ -440,12 +492,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, } // Compute cost of sliced and unsliced src loop nest. - SmallVector srcLoopIVs; - getAffineForIVs(*srcOpInst, &srcLoopIVs); // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; - if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats)) + if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) return false; // Compute cost of dst loop nest. @@ -467,7 +517,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, std::optional bestDstLoopDepth; // Compute op instance count for the src loop nest without iteration slicing. - uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats); + uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats); // Compute src loop nest write region size. MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); @@ -494,18 +544,21 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, if (slice.isEmpty()) continue; + // Compute cost of the slice separately, i.e, the compute cost of the slice + // if all outer trip counts are one. + int64_t sliceCost; + int64_t fusedLoopNestComputeCost; - if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp, - dstLoopNestStats, slice, - &fusedLoopNestComputeCost)) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n"); + + auto mayAdditionalComputeFraction = + getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions, + sliceCost, fusedLoopNestComputeCost); + if (!mayAdditionalComputeFraction) { + LLVM_DEBUG(llvm::dbgs() + << "Can't determine additional compute fraction.\n"); continue; } - - double additionalComputeFraction = - fusedLoopNestComputeCost / - (static_cast(srcLoopNestCost) + dstLoopNestCost) - - 1; + double additionalComputeFraction = *mayAdditionalComputeFraction; // Determine what the slice write MemRefRegion would be, if the src loop // nest slice 'slice' were to be inserted into the dst loop nest at loop @@ -530,14 +583,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, } int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes; - // If we are fusing for reuse, check that write regions remain the same. - // TODO: Write region check should check sizes and offsets in - // each dimension, so that we are sure they are covering the same memref - // region. Also, move this out to a isMemRefRegionSuperSet helper function. - if (srcOpInst != srcStoreOpInst && - sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes) - continue; - double storageReduction = static_cast(srcWriteRegionSizeBytes) / static_cast(sliceWriteRegionSizeBytes); @@ -560,7 +605,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // (as per computeToleranceThreshold), we will simply pick the one that // reduces the intermediary size the most. if ((storageReduction > maxStorageReduction) && - (additionalComputeFraction < computeToleranceThreshold)) { + (additionalComputeFraction <= computeToleranceThreshold)) { maxStorageReduction = storageReduction; bestDstLoopDepth = i; minFusedLoopNestComputeCost = fusedLoopNestComputeCost; @@ -595,7 +640,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, << minFusedLoopNestComputeCost << "\n"); auto dstMemSize = getMemoryFootprintBytes(dstForOp); - auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); + auto srcMemSize = getMemoryFootprintBytes(srcForOp); std::optional storageReduction; @@ -840,6 +885,8 @@ public: LLVM_DEBUG(llvm::dbgs() << "Trying to fuse producer loop nest " << srcId << " with consumer loop nest " << dstId << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: " + << computeToleranceThreshold << '\n'); LLVM_DEBUG(llvm::dbgs() << "Producer loop nest:\n" << *srcNode->op << "\n and consumer loop nest:\n" @@ -926,6 +973,46 @@ public: continue; } + LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: " + << maxLegalFusionDepth << '\n'); + + double computeToleranceThresholdToUse = computeToleranceThreshold; + + // Cyclic dependences in the source nest may be violated when performing + // slicing-based fusion. They aren't actually violated in cases where no + // redundant execution of the source happens (1:1 pointwise dep on the + // producer-consumer memref access for example). Check this and allow + // fusion accordingly. + if (hasCyclicDependence(srcAffineForOp)) { + LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n"); + // Maximal fusion does not check for compute tolerance threshold; so + // perform the maximal fusion only when the redundanation computation + // is zero. + if (maximalFusion) { + auto srcForOp = cast(srcNode->op); + auto dstForOp = cast(dstNode->op); + int64_t sliceCost; + int64_t fusedLoopNestComputeCost; + auto fraction = getAdditionalComputeFraction( + srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, + sliceCost, fusedLoopNestComputeCost); + if (!fraction || fraction > 0) { + LLVM_DEBUG( + llvm::dbgs() + << "Can't perform maximal fusion with a cyclic dependence " + "and non-zero additional compute.\n"); + return; + } + } else { + // Set redundant computation tolerance to zero regardless of what + // the user specified. Without this, fusion would be invalid. + LLVM_DEBUG(llvm::dbgs() + << "Setting compute tolerance to zero since " + "source has a cylic dependence.\n"); + computeToleranceThresholdToUse = 0; + } + } + // Check if fusion would be profitable. We skip profitability analysis // for maximal fusion since we already know the maximal legal depth to // fuse. @@ -948,10 +1035,10 @@ public: if (producerStores.size() > 1) LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not " "supported for this case\n"); - else if (!isFusionProfitable(producerStores[0], producerStores[0], + else if (!isFusionProfitable(srcAffineForOp, producerStores[0], dstAffineForOp, depthSliceUnions, maxLegalFusionDepth, &bestDstLoopDepth, - computeToleranceThreshold)) + computeToleranceThresholdToUse)) continue; } @@ -1163,15 +1250,51 @@ public: if (maxLegalFusionDepth == 0) continue; + double computeToleranceThresholdToUse = computeToleranceThreshold; + + // Cyclic dependences in the source nest may be violated when performing + // slicing-based fusion. They aren't actually violated in cases where no + // redundant execution of the source happens (1:1 pointwise dep on the + // producer-consumer memref access for example). Check this and allow + // fusion accordingly. + if (hasCyclicDependence(sibAffineForOp)) { + LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n"); + // Maximal fusion does not check for compute tolerance threshold; so + // perform the maximal fusion only when the redundanation computation is + // zero. + if (maximalFusion) { + auto dstForOp = cast(dstNode->op); + int64_t sliceCost; + int64_t fusedLoopNestComputeCost; + auto fraction = getAdditionalComputeFraction( + sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, + sliceCost, fusedLoopNestComputeCost); + if (!fraction || fraction > 0) { + LLVM_DEBUG( + llvm::dbgs() + << "Can't perform maximal fusion with a cyclic dependence " + "and non-zero additional compute.\n"); + return; + } + } else { + // Set redundant computation tolerance to zero regardless of what the + // user specified. Without this, fusion would be invalid. + LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since " + "source has a cyclic dependence.\n"); + computeToleranceThresholdToUse = 0.0; + } + } + unsigned bestDstLoopDepth = maxLegalFusionDepth; if (!maximalFusion) { // Check if fusion would be profitable. For sibling fusion, the sibling // load op is treated as the src "store" op for fusion profitability // purposes. The footprint of the load in the slice relative to the // unfused source's determines reuse. - if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp, + if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp, depthSliceUnions, maxLegalFusionDepth, - &bestDstLoopDepth, computeToleranceThreshold)) + &bestDstLoopDepth, + computeToleranceThresholdToUse)) continue; } diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir index 42d5ce632188..cf96a30a6e62 100644 --- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir @@ -495,3 +495,52 @@ func.func @test_add_slice_bounds() { } return } + +// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @producer_reduction_no_fusion +func.func @producer_reduction_no_fusion(%input : memref<10xf32>, %output : memref<10xf32>, %reduc : memref<1xf32>) { + %zero = arith.constant 0. : f32 + %one = arith.constant 1. : f32 + // This producer can't be fused into inside %i without a violation of + // semantics. + // PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 10 + affine.for %i = 0 to 10 { + %0 = affine.load %input[%i] : memref<10xf32> + %1 = affine.load %reduc[0] : memref<1xf32> + %2 = arith.addf %0, %1 : f32 + affine.store %2, %reduc[0] : memref<1xf32> + } + // PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 10 + affine.for %i = 0 to 10 { + %0 = affine.load %reduc[0] : memref<1xf32> + %2 = arith.addf %0, %one : f32 + affine.store %2, %output[%i] : memref<10xf32> + } + return +} + +// SIBLING-MAXIMAL-LABEL: func @sibling_reduction +func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>, %reduc : memref<10xf32>) { + %zero = arith.constant 0. : f32 + %one = arith.constant 1. : f32 + affine.for %i = 0 to 10 { + %0 = affine.load %input[%i] : memref<10xf32> + %2 = arith.addf %0, %one : f32 + affine.store %2, %output[%i] : memref<10xf32> + } + // Ensure that the fusion happens at the right depth. + affine.for %i = 0 to 10 { + %0 = affine.load %input[%i] : memref<10xf32> + %1 = affine.load %reduc[0] : memref<10xf32> + %2 = arith.addf %0, %1 : f32 + affine.store %2, %reduc[0] : memref<10xf32> + } + // SIBLING-MAXIMAL: affine.for %{{.*}} = 0 to 10 + // SIBLING-MAXIMAL-NEXT: affine.load + // SIBLING-MAXIMAL-NEXT: addf + // SIBLING-MAXIMAL-NEXT: affine.store + // SIBLING-MAXIMAL-NEXT: affine.load + // SIBLING-MAXIMAL-NEXT: affine.load + // SIBLING-MAXIMAL-NEXT: addf + // SIBLING-MAXIMAL-NEXT: affine.store + return +} diff --git a/mlir/test/Examples/mlir-opt/loop_fusion_options.mlir b/mlir/test/Examples/mlir-opt/loop_fusion_options.mlir index 556e58c522ae..0475e162c38e 100644 --- a/mlir/test/Examples/mlir-opt/loop_fusion_options.mlir +++ b/mlir/test/Examples/mlir-opt/loop_fusion_options.mlir @@ -1,12 +1,13 @@ // RUN: mlir-opt --pass-pipeline="builtin.module(affine-loop-fusion{compute-tolerance=0})" %s | FileCheck %s // CHECK-LABEL: @producer_consumer_fusion -// CHECK-COUNT-3: affine.for module { func.func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) { %0 = memref.alloc() : memref<10xf32> %1 = memref.alloc() : memref<10xf32> %cst = arith.constant 0.000000e+00 : f32 + // CHECK: affine.for + // CHECK-NOT: affine.for affine.for %arg2 = 0 to 10 { affine.store %cst, %0[%arg2] : memref<10xf32> affine.store %cst, %1[%arg2] : memref<10xf32>