//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

namespace {
enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
} // namespace
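// Illustrative sketch of the two strategies (hand-written, not taken from a
// test): collapsing a tensor<1x4xf32> operand to tensor<4xf32> is materialized
// either as
//
//   %c = tensor.collapse_shape %t [[0, 1]]
//       : tensor<1x4xf32> into tensor<4xf32>
//
// under `ReassociativeReshape`, or as a rank-reducing slice
//
//   %c = tensor.extract_slice %t[0, 0] [1, 4] [1, 1]
//       : tensor<1x4xf32> to tensor<4xf32>
//
// under `ExtractInsertSlice`.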
/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}
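// Illustrative sketch (hand-written, not from a test): with iteration dims
// (d0, d1, d2) and `unitDims` = {1}, the map
//   affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// is rewritten by substituting the unit dim with the constant 0 and
// renumbering the surviving dims, yielding
//   affine_map<(d0, d1) -> (0, d0, d1)>.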
/// Update the index accesses of linalg operations having index semantics.
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMapsArray();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    DenseSet<unsigned> unitDims;
    ArrayAttr iteratorTypes = genericOp.getIteratorTypes();
    for (const auto &expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1)
          unitDims.insert(expr.index());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (const auto &attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.setIndexingMapsAttr(newIndexingMapAttr);
    genericOp.setIteratorTypesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};
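// Illustrative sketch of the effect (hand-written, not from a test): a
// reduction such as
//   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                                    affine_map<(d0, d1) -> (d0)>],
//                   iterator_types = ["parallel", "reduction"]}
//     ins(%arg0 : tensor<4x1xf32>) outs(%init : tensor<4xf32>)
// has a unit-trip reduction loop, so it is updated in place to
//   indexing_maps = [affine_map<(d0) -> (d0, 0)>, affine_map<(d0) -> (d0)>]
//   iterator_types = ["parallel"]
// with any linalg.index ops renumbered accordingly.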
/// Pattern to move init operands to ins when all the loops are parallel and
/// the blockArgument corresponding to the init is used in the region. This is
/// a fix-up when unit reduction dimensions are all folded away. In this
/// context, it becomes an elementwise generic op. E.g., it converts
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0 : tensor<1x?x1x1xf32>)
///    outs(%1 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %out: f32):
///    %3 = arith.addf %in, %out : f32
///    linalg.yield %3 : f32
///  } -> tensor<1x1xf32>
///
/// into
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = tensor.empty() : tensor<1x1xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
///    outs(%2 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %in_0: f32, %out: f32):
///    %4 = arith.addf %in, %in_0 : f32
///    linalg.yield %4 : f32
///  } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    auto outputOperands = genericOp.getDpsInitOperands();
    SetVector<OpOperand *> candidates;
    for (OpOperand *op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(op).use_empty())
        continue;
      candidates.insert(op);
    }

    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputOperands();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands = outputOperands;
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = op->get().getType().cast<ShapedType>().getElementType();
      auto empty = rewriter.create<tensor::EmptyOp>(
          loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);

      auto [start, end] = genericOp.getDpsInitsPositionRange();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = rewriter.create<GenericOp>(
        loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
        newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    Region &region = newOp.getRegion();
    Block *block = new Block();
    region.push_back(block);
    BlockAndValueMapping mapper;
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(block);
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));

    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }
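    // Original inits that were promoted to inputs above still receive a fresh
    // (unused) trailing block argument so that the region signature matches
    // the new operand list; only the untouched inits keep a mapping to their
    // old block arguments.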
    for (OpOperand *op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      if (candidates.count(op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations()) {
      rewriter.clone(op, mapper);
    }
    rewriter.replaceOp(genericOp, newOp.getResults());
    return success();
  }
};

struct UnitExtentReplacementInfo {
  AffineMap indexMap;
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<int64_t> targetShape;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op, and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
static std::optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with non-identity layout maps to represent that
  // we will always leave them unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getLayout().isIdentity())
      return std::nullopt;
  }

  int64_t dim = 0;
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices reassociationGroup;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociationGroup.push_back(dim++);
  while (dim < origRank) {
    assert(!isUnitExtent(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    ++dim;
    // Fold all following dimensions that are unit-extent.
    while (dim < origRank && isUnitExtent(dim))
      reassociationGroup.push_back(dim++);
    reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }

  // Return if the rank was not reduced.
  if (origRank == static_cast<int64_t>(newShape.size()))
    return std::nullopt;

  UnitExtentReplacementInfo info = {
      /*indexMap=*/AffineMap::get(indexingMap.getNumDims(),
                                  indexingMap.getNumSymbols(), newIndexExprs,
                                  context),
      /*reassociation=*/reassociation, /*targetShape=*/newShape};
  return info;
}
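// Illustrative sketch (hand-written, not from a test): an operand of type
// tensor<1x4x1x8xf32> accessed through
//   affine_map<(d0, d1) -> (0, d0, 0, d1)>
// is folded to target shape [4, 8] with the modified index map
//   affine_map<(d0, d1) -> (d0, d1)>
// and reassociation [[0, 1, 2], [3]]: leading unit dims join the first kept
// dim's group, and each later unit dim joins the group of the kept dim that
// precedes it.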
namespace {

/// Pattern to replace tensor/buffer operands/results that are unit extents.
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  ReplaceUnitExtents(MLIRContext *ctx,
                     RankReductionStrategy rankReductionStrategy)
      : OpRewritePattern<GenericOp>(ctx),
        rankReductionStrategy(rankReductionStrategy) {}

  // Expand the given value back to the original (unit-extent) shape.
  Value expandValue(Value result, Value origOutput,
                    ArrayRef<ReassociationIndices> reassociation, Location loc,
                    PatternRewriter &rewriter) const {
    // There are no results for memref outputs.
    auto origResultType = origOutput.getType().cast<RankedTensorType>();
    if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
      unsigned rank = origResultType.getRank();
      SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, origOutput);
      SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
      return rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, result, origOutput, offsets, sizes, strides);
    }

    assert(rankReductionStrategy ==
               RankReductionStrategy::ReassociativeReshape &&
           "unknown rank reduction strategy");
    return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
                                                  reassociation);
  }

  // Collapse the given value to the rank-reduced `targetShape`.
  Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
                      ArrayRef<ReassociationIndices> reassociation,
                      Location loc, PatternRewriter &rewriter) const {
    if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
      if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
        FailureOr<Value> rankReducingExtract =
            memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                  targetShape);
        assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
        return *rankReducingExtract;
      }

      assert(rankReductionStrategy ==
                 RankReductionStrategy::ReassociativeReshape &&
             "unknown rank reduction strategy");
      MemRefLayoutAttrInterface layout;
      auto targetType =
          MemRefType::get(targetShape, memrefType.getElementType(), layout,
                          memrefType.getMemorySpace());
      return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
                                                      reassociation);
    }
    if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
      if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
        FailureOr<Value> rankReducingExtract =
            tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                       targetShape);
        assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
        return *rankReducingExtract;
      }

      assert(rankReductionStrategy ==
                 RankReductionStrategy::ReassociativeReshape &&
             "unknown rank reduction strategy");
      auto targetType =
          RankedTensorType::get(targetShape, tensorType.getElementType());
      return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
                                                      reassociation);
    }
    llvm_unreachable("unsupported operand type");
  }
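  // Taken together (a sketch for the reshape strategy, not from a test), a
  // tensor<1x4xf32> input is collapsed with
  //   tensor.collapse_shape %in [[0, 1]] : tensor<1x4xf32> into tensor<4xf32>
  // before the rank-reduced generic op, and a collapsed result is restored
  // with the matching
  //   tensor.expand_shape %res [[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
  // so that users of the original op see the original type.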
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();
    SmallVector<Value> oldOutputs(genericOp.getOutputs().begin(),
                                  genericOp.getOutputs().end());

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<SmallVector<ReassociationIndices>> reassociations;
    SmallVector<SmallVector<int64_t>> targetShapes;
    SmallVector<bool> collapsed;
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
      if (replacementInfo) {
        reassociations.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        targetShapes.push_back(replacementInfo->targetShape);
        collapsed.push_back(true);
      } else {
        // If replaceUnitExtents cannot handle this case (or no unit dim was
        // removed), maintain the same type, indexing map, and create a set of
        // mappings representing an identity matrix.
        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
        reassociations.emplace_back();
        targetShapes.emplace_back();
        collapsed.push_back(false);
      }
    }

    // Abort if the indexing maps of the result operation are not invertible
    // (i.e. not legal) or if no dimension was reduced.
    if (!llvm::any_of(collapsed, [](bool c) { return c; }) ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // Insert rank reductions.
    SmallVector<Value> newOperands;
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      int64_t idx = opOperand.getOperandNumber();
      if (!collapsed[idx]) {
        newOperands.push_back(opOperand.get());
        continue;
      }
      newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx],
                                          reassociations[idx], loc, rewriter));
    }

    // Create the replacement op with the rank-reduced operands; reshapes that
    // recover the original result types are added below.
    ArrayRef<Value> newInputs =
        ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
    ArrayRef<Value> newOutputs =
        ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
    SmallVector<Type> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newOutputs[i].getType());
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        genericOp.getIteratorTypesArray());
    rewriter.inlineRegionBefore(genericOp.getRegion(),
                                replacementOp.getRegion(),
                                replacementOp.getRegion().begin());

    // If any result tensor has a modified shape, then add reshape to recover
    // the original shape.
    SmallVector<Value> resultReplacements;
    for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumDpsInputs();
      Value origOutput = oldOutputs[result.index()];
      if (!collapsed[result.index() + genericOp.getNumDpsInputs()]) {
        resultReplacements.push_back(result.value());
        continue;
      }
      resultReplacements.push_back(expandValue(
          result.value(), origOutput, reassociations[index], loc, rewriter));
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }

private:
  RankReductionStrategy rankReductionStrategy;
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType =
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides)
            .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};
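// Illustrative sketch (hand-written, not from a test):
//
//   %0 = tensor.extract_slice %t[0, %o] [1, 16] [1, 1]
//       : tensor<8x64xf32> to tensor<1x16xf32>
//
// is rewritten to a rank-reduced slice plus an expand_shape that restores the
// original result type:
//
//   %0 = tensor.extract_slice %t[0, %o] [1, 16] [1, 1]
//       : tensor<8x64xf32> to tensor<16xf32>
//   %1 = tensor.expand_shape %0 [[0, 1]]
//       : tensor<16xf32> into tensor<1x16xf32>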
/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<ReplaceUnitExtents>(context,
                                   RankReductionStrategy::ReassociativeReshape);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  patterns.add<FoldUnitDimLoops, RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<ReplaceUnitExtents>(context,
                                   RankReductionStrategy::ExtractInsertSlice);
  patterns.add<FoldUnitDimLoops>(context);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}
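// Typical standalone use of these entry points (a sketch; the pass below does
// effectively the same):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
//   linalg::populateMoveInitOperandsToInputPattern(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));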
namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly) {
      patterns.add<FoldUnitDimLoops>(context);
    } else if (useRankReducingSlices) {
      populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
      populateMoveInitOperandsToInputPattern(patterns);
    } else {
      populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
      populateMoveInitOperandsToInputPattern(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}
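// Usage note: the pass is exposed through mlir-opt, e.g.
//
//   mlir-opt --linalg-fold-unit-extent-dims input.mlir
//
// with pass options corresponding to the `foldOneTripLoopsOnly` and
// `useRankReducingSlices` flags read in `runOnOperation` above (the exact
// command-line spellings are defined in Passes.td).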