//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements functions concerned with hoisting invariant operations // in the context of Linalg transformations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using llvm::dbgs; #define DEBUG_TYPE "linalg-hoisting" #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") using namespace mlir; using namespace mlir::linalg; void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) { IRRewriter rewriter(func->getContext()); // TODO: walking in some reverse / inside-out order would be more efficient // and would capture more cases. func.walk([&](scf::ForOp forOp) { hoistRedundantSubsetExtractInsert(rewriter, forOp); }); } static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop) { Value source = transferRead.getSource(); while (auto subView = source.getDefiningOp()) source = subView.getSource(); llvm::SmallVector users(source.getUsers().begin(), source.getUsers().end()); llvm::SmallDenseSet processed; while (!users.empty()) { Operation *user = users.pop_back_val(); // If the user has already been processed skip. if (!processed.insert(user).second) continue; if (auto subView = dyn_cast(user)) { users.append(subView->getUsers().begin(), subView->getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa(user)) continue; if (!loop->isAncestor(user)) continue; return false; } return true; } void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { bool changed = true; while (changed) { changed = false; // First move loop invariant ops outside of their loop. This needs to be // done before as we cannot move ops without interrupting the function walk. func.walk( [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); func.walk([&](vector::TransferReadOp transferRead) { if (!isa(transferRead.getShapedType())) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *transferRead.getOperation() << "\n"); auto loop = dyn_cast(transferRead->getParentOp()); LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() << "\n"); if (!isa_and_nonnull(loop)) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() << "\n"); SetVector forwardSlice; getForwardSlice(transferRead.getOperation(), &forwardSlice); // Look for the last TransferWriteOp in the forwardSlice of // `transferRead` that operates on the same memref. vector::TransferWriteOp transferWrite; for (auto *sliceOp : llvm::reverse(forwardSlice)) { auto candidateWrite = dyn_cast(sliceOp); if (!candidateWrite || candidateWrite.getSource() != transferRead.getSource()) continue; transferWrite = candidateWrite; } // All operands of the TransferRead must be defined outside of the loop. for (auto operand : transferRead.getOperands()) if (!loop.isDefinedOutsideOfLoop(operand)) return WalkResult::advance(); // Only hoist transfer_read / transfer_write pairs and singleton // transfer_reads for now. if (!transferWrite) { // Make sure there are no other accesses to the memref before // hoisting transfer_read. if (noAliasingUseInLoop(transferRead, loop)) loop.moveOutOfLoop(transferRead); return WalkResult::advance(); } LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() << "\n"); // Approximate aliasing by checking that: // 1. indices, vector type and permutation map are the same (i.e., the // transfer_read/transfer_write ops are matching), // 2. no other operations in the loop access the same memref except // for transfer_read/transfer_write accessing statically disjoint // slices. if (transferRead.getIndices() != transferWrite.getIndices() || transferRead.getVectorType() != transferWrite.getVectorType() || transferRead.getPermutationMap() != transferWrite.getPermutationMap()) return WalkResult::advance(); // TODO: may want to memoize this information for performance but it // likely gets invalidated often. DominanceInfo dom(loop); if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) return WalkResult::advance(); for (auto &use : transferRead.getSource().getUses()) { if (!loop->isAncestor(use.getOwner())) continue; if (use.getOwner() == transferRead.getOperation() || use.getOwner() == transferWrite.getOperation()) continue; if (auto transferWriteUse = dyn_cast(use.getOwner())) { if (!vector::isDisjointTransferSet( cast(transferWrite.getOperation()), cast( transferWriteUse.getOperation()))) return WalkResult::advance(); } else if (auto transferReadUse = dyn_cast(use.getOwner())) { if (!vector::isDisjointTransferSet( cast(transferWrite.getOperation()), cast( transferReadUse.getOperation()))) return WalkResult::advance(); } else { // Unknown use, we cannot prove that it doesn't alias with the // transferRead/transferWrite operations. return WalkResult::advance(); } } // Hoist read before. loop.moveOutOfLoop(transferRead); // Hoist write after. transferWrite->moveAfter(loop); // Rewrite `loop` with new yields by cloning and erase the original loop. OpBuilder b(transferRead); NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { return SmallVector{transferWrite.getVector()}; }; // Transfer write has been hoisted, need to update the written vector by // the value yielded by the newForOp. return TypeSwitch(loop) .Case([&](scf::ForOp scfForOp) { auto newForOp = replaceLoopWithNewYields( b, scfForOp, transferRead.getVector(), yieldFn); transferWrite.getVectorMutable().assign( newForOp.getResults().back()); changed = true; loop.erase(); // Need to interrupt and restart because erasing the loop messes up // the walk. return WalkResult::interrupt(); }) .Case([&](affine::AffineForOp affineForOp) { auto newForOp = replaceForOpWithNewYields( b, affineForOp, transferRead.getVector(), SmallVector{transferWrite.getVector()}, transferWrite.getVector()); // Replace all uses of the `transferRead` with the corresponding // basic block argument. transferRead.getVector().replaceUsesWithIf( newForOp.getLoopBody().getArguments().back(), [&](OpOperand &use) { Operation *user = use.getOwner(); return newForOp->isProperAncestor(user); }); transferWrite.getVectorMutable().assign( newForOp.getResults().back()); changed = true; loop.erase(); // Need to interrupt and restart because erasing the loop messes up // the walk. return WalkResult::interrupt(); }) .Default([](Operation *) { return WalkResult::interrupt(); }); }); } }